diff --git a/examples/cpr_sns.py b/examples/cpr_sns.py
index 7624062e68aace4a542dddba5bb00e1ca5784d64..4e51eb8d33eabea8c06e289421d6d7fbcee106b1 100644
--- a/examples/cpr_sns.py
+++ b/examples/cpr_sns.py
@@ -56,7 +56,7 @@ Res = collections.namedtuple("Res", ["x", "y", "J", "Jx", "Jy"])
 
 @vectorize_parallel(returns_object=True, noarray=True)
 @usadelndsoc.with_log_level(logging.WARNING)
-def j(T, h, phi, n=15, eta=0.1, alpha_soc=0.1, L=10, perturbative=False):
+def j(T, h, phi, n=15, eta=0.1, alpha_soc=0.1, L=10):
 
     sol = get_solver(soc_alpha=alpha_soc, eta=eta, L=L, n=n, phi=phi, h=h)
 
@@ -78,13 +78,11 @@ def j(T, h, phi, n=15, eta=0.1, alpha_soc=0.1, L=10, perturbative=False):
 
 def main():
     T = 0.1
-    alpha_soc = 0.1
+    alpha_soc = 0.3
     h = np.r_[0.0, 0.25, 0.5, 1.0, 1.5]
     phi = np.linspace(-pi, pi, 37)
 
-    res = j(
-        T, h[:, None], phi[None, :], alpha_soc=alpha_soc, perturbative=False, mem=mem
-    )
+    res = j(T, h[:, None], phi[None, :], alpha_soc=alpha_soc, L=3, mem=mem)
 
     Jx_mean = np.asarray(
         [Res(*x).Jx[1:-1].sum(axis=1).mean(axis=0) for x in res.flat]