diff --git a/examples/cpr_sns.py b/examples/cpr_sns.py
index 5f5b0fda4fbc03ce555168b4396c57ccc36436d5..f985c55d449e477becca44771b20096cc46277ad 100644
--- a/examples/cpr_sns.py
+++ b/examples/cpr_sns.py
@@ -95,70 +95,60 @@ def Gamma_to_alpha(Gamma_DP, Gamma_ST):
     return alpha_soc, eta
 
 
-def do(W_xi=6, multin=True):
+def do(W_xi=12, multin=True):
     T = 0.1
     Gamma_DP = 10
     Gamma_ST = 1
     h = -10.0
     phi = np.linspace(0, 2 * pi, 37)
     xi = 1 / np.sqrt(2 * pi)
-    L = np.array([0.1, 1.5, 2, 3]) * xi
+    L = np.array([0.1, 1.5, 2, 2.25, 2.5, 2.75, 3]) * xi
     W = W_xi * xi
 
-    alpha, eta = Gamma_to_alpha(Gamma_DP, Gamma_ST)
+    L0 = np.array([0.1, 1.5, 2, 3]) * xi
+    Lsel = np.array([np.isclose(z, L0).any() for z in L])
 
-    res = j(
-        T, h, phi[None, :], alpha_soc=alpha, eta=eta, L=L[:, None], W=W, n=10, mem=mem
-    )
+    alpha, eta = Gamma_to_alpha(Gamma_DP, Gamma_ST)
 
     if multin:
+        ns = (20, 10, 25)
         mphi = phi[::2]
-        res0 = j(
-            T,
-            h,
-            mphi[None, :],
-            alpha_soc=alpha,
-            eta=eta,
-            L=L[:, None],
-            W=W,
-            n=5,
-            mem=mem,
-        )
-        res1 = j(
-            T,
-            h,
-            mphi[None, :],
-            alpha_soc=alpha,
-            eta=eta,
-            L=L[:, None],
-            W=W,
-            n=20,
-            mem=mem,
+    else:
+        ns = (20,)
+
+    ress = []
+    Jxs = []
+
+    for n in ns:
+        if multin and n > ns[0]:
+            p = mphi
+        else:
+            p = phi
+        res = j(
+            T, h, p[None, :], alpha_soc=alpha, eta=eta, L=L[:, None], W=W, n=n, mem=mem
         )
+        ress.append(res)
 
-    Jx_mean = np.asarray(
-        [Res(*x).Jx[1:-1].sum(axis=1).mean(axis=0) for x in res.flat]
-    ).reshape(res.shape)
-
-    if multin:
-        Jx0_mean = np.asarray(
-            [Res(*x).Jx[1:-1].sum(axis=1).mean(axis=0) for x in res0.flat]
-        ).reshape(res0.shape)
-
-        Jx1_mean = np.asarray(
-            [Res(*x).Jx[1:-1].sum(axis=1).mean(axis=0) for x in res1.flat]
-        ).reshape(res1.shape)
+        Jx = np.asarray(
+            [Res(*x).Jx[1:-1].sum(axis=1).mean(axis=0) for x in res.flat]
+        ).reshape(res.shape)
+        Jxs.append(Jx)
 
     fig, axs = plt.subplots(1, 2, layout="compressed")
 
     ax = axs[0]
-    ax.plot(phi / pi, Jx_mean.T / abs(Jx_mean).max(axis=1))
+    ax.plot(phi / pi, Jxs[0][Lsel, :].T / abs(Jxs[0][Lsel]).max(axis=1))
     if multin:
-        # ax.plot(mphi / pi, Jx0_mean.T / abs(Jx0_mean).max(axis=1), "k:", alpha=0.25)
-        ax.plot(mphi / pi, Jx1_mean.T / abs(Jx1_mean).max(axis=1), "k:")
+        ax.plot(
+            mphi / pi,
+            Jxs[1][Lsel, :].T / abs(Jxs[1][Lsel]).max(axis=1),
+            "k:",
+            alpha=0.25,
+        )
+        ax.plot(mphi / pi, Jxs[2][Lsel, :].T / abs(Jxs[2][Lsel]).max(axis=1), "k:")
     ax.set_xlabel(r"$\varphi / \pi$")
     ax.set_ylabel(r"$I / I_{\mathrm{max}}$")
-    ax.legend(L / xi, title=r"$L/\xi$", loc="lower right")
+    ax.legend(L[Lsel] / xi, title=r"$L/\xi$", loc="lower right")
 
     def eff(Jx):
         Jm = Jx.min(axis=1)
@@ -166,12 +156,14 @@ def do(W_xi=6, multin=True):
         return (abs(Jp) - abs(Jm)) / (abs(Jp) + abs(Jm))
 
     ax = axs[1]
-    ax.plot(L / xi, 100 * eff(Jx_mean))
+    ax.plot(L / xi, 100 * eff(Jxs[0]))
     if multin:
-        # ax.plot(L / xi, 100 * eff(Jx0_mean), "k:", alpha=0.25)
-        ax.plot(L / xi, 100 * eff(Jx1_mean), "k:")
+        ax.plot(L / xi, 100 * eff(Jxs[1]), "k:", alpha=0.25)
+        ax.plot(L / xi, 100 * eff(Jxs[2]), "k:")
     ax.set_xlabel(r"$L / \xi$")
     ax.set_ylabel(r"$\eta$  [%]")
+    if multin:
+        ax.legend(ns)
 
     fig.suptitle(
         rf"$h = {h} \Delta_0$, $T = {T} \Delta_0$, $W = {W/xi} \xi_0$ $\Gamma_{{DP}} = {Gamma_DP} \Delta_0$, $\Gamma_{{ST}} = {Gamma_ST} \Delta_0$   ($\tilde{{\eta}} = {eta:.3g}$, $\tilde{{\alpha}} = {alpha:.3g}$)"