From 41806c7b3ea4a0558b6806d61e1edc074406166e Mon Sep 17 00:00:00 2001
From: Pauli Virtanen <pauli.t.virtanen@jyu.fi>
Date: Wed, 21 Sep 2022 16:32:56 +0300
Subject: [PATCH] examples: cpr_sns: update

---
 examples/cpr_sns.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/examples/cpr_sns.py b/examples/cpr_sns.py
index 5c1a543..c57acd5 100644
--- a/examples/cpr_sns.py
+++ b/examples/cpr_sns.py
@@ -46,7 +46,7 @@ def get_solver(soc_alpha, eta, phi, L=10, h=0.5, D=1.0, n=5):
     sol.Ux[...] = expm(1j * dx * Ax)
     sol.Uy[...] = expm(1j * dy * Ay)
 
-    sol.Omega[...] += h * np.kron(S_z, S_y)
+    sol.Omega[1:-1] += h * np.kron(S_z, S_y)
 
     return sol
 
@@ -67,7 +67,7 @@ def j(T, h, phi, n=15, eta=0.1, alpha_soc=0.1, L=10, perturbative=False):
     res = sol.solve_many(omega=w)
 
     t3 = np.diag([1, 1, -1, -1])
-    J = tr(res.J @ t3) * (-1 / 16)
+    J = tr(res.J @ t3)
     J = -1j * pi * (J * a[:, None, None, None]).sum(axis=0)
 
     Jx = (J[:, :, 0] + J[:, :, 2]).real / 2
@@ -79,10 +79,12 @@ def j(T, h, phi, n=15, eta=0.1, alpha_soc=0.1, L=10, perturbative=False):
 def main():
     T = 0.1
     alpha_soc = 0.1
-    h = 0.5
+    h = np.r_[0.0, 0.25, 0.5, 1.0, 1.5]
     phi = np.linspace(-pi, pi, 37)
 
-    res = j(T, h, phi, alpha_soc=alpha_soc, perturbative=False, mem=mem)
+    res = j(
+        T, h[:, None], phi[None, :], alpha_soc=alpha_soc, perturbative=False, mem=mem
+    )
 
     Jx_mean = np.asarray([x.Jx[1:-1].sum(axis=1).mean(axis=0) for x in res])
 
-- 
GitLab