From 0a28bb92009e7a30da84edc69ec20be18d476790 Mon Sep 17 00:00:00 2001
From: Pauli Virtanen <pauli.t.virtanen@jyu.fi>
Date: Wed, 21 Sep 2022 17:04:16 +0300
Subject: [PATCH] examples: fixup cpr_sns

---
 examples/cpr_sns.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/examples/cpr_sns.py b/examples/cpr_sns.py
index c57acd5..7624062 100644
--- a/examples/cpr_sns.py
+++ b/examples/cpr_sns.py
@@ -73,7 +73,7 @@ def j(T, h, phi, n=15, eta=0.1, alpha_soc=0.1, L=10, perturbative=False):
     Jx = (J[:, :, 0] + J[:, :, 2]).real / 2
     Jy = (J[:, :, 1] + J[:, :, 3]).real / 2
 
-    return Res(sol.x, sol.y, J, Jx, Jy)
+    return (sol.x, sol.y, J, Jx, Jy)
 
 
 def main():
@@ -86,9 +86,11 @@ def main():
         T, h[:, None], phi[None, :], alpha_soc=alpha_soc, perturbative=False, mem=mem
     )
 
-    Jx_mean = np.asarray([x.Jx[1:-1].sum(axis=1).mean(axis=0) for x in res])
+    Jx_mean = np.asarray(
+        [Res(*x).Jx[1:-1].sum(axis=1).mean(axis=0) for x in res.flat]
+    ).reshape(res.shape)
 
-    plt.plot(phi / pi, Jx_mean)
+    plt.plot(phi / pi, Jx_mean.T)
     plt.xlabel(r"$\varphi / \pi$")
     plt.ylabel(r"$I$")
     plt.legend()
-- 
GitLab