diff --git a/examples/cpr_sns.py b/examples/cpr_sns.py
index 5c1a543c751dadd12a809164023c8b231da9f763..c57acd5deeaf34094a35eeed942f403f27a8df95 100644
--- a/examples/cpr_sns.py
+++ b/examples/cpr_sns.py
@@ -46,7 +46,7 @@ def get_solver(soc_alpha, eta, phi, L=10, h=0.5, D=1.0, n=5):
     sol.Ux[...] = expm(1j * dx * Ax)
     sol.Uy[...] = expm(1j * dy * Ay)
 
-    sol.Omega[...] += h * np.kron(S_z, S_y)
+    sol.Omega[1:-1] += h * np.kron(S_z, S_y)
 
     return sol
 
@@ -67,7 +67,7 @@ def j(T, h, phi, n=15, eta=0.1, alpha_soc=0.1, L=10, perturbative=False):
     res = sol.solve_many(omega=w)
 
     t3 = np.diag([1, 1, -1, -1])
-    J = tr(res.J @ t3) * (-1 / 16)
+    J = tr(res.J @ t3)
     J = -1j * pi * (J * a[:, None, None, None]).sum(axis=0)
 
     Jx = (J[:, :, 0] + J[:, :, 2]).real / 2
@@ -79,10 +79,12 @@ def j(T, h, phi, n=15, eta=0.1, alpha_soc=0.1, L=10, perturbative=False):
 def main():
     T = 0.1
     alpha_soc = 0.1
-    h = 0.5
+    h = np.r_[0.0, 0.25, 0.5, 1.0, 1.5]
     phi = np.linspace(-pi, pi, 37)
 
-    res = j(T, h, phi, alpha_soc=alpha_soc, perturbative=False, mem=mem)
+    res = j(
+        T, h[:, None], phi[None, :], alpha_soc=alpha_soc, perturbative=False, mem=mem
+    )
 
     Jx_mean = np.asarray([x.Jx[1:-1].sum(axis=1).mean(axis=0) for x in res])