From ec75e78cd5198be445fb2acce27cb3ce2d861c21 Mon Sep 17 00:00:00 2001
From: Pauli Virtanen <pauli.t.virtanen@jyu.fi>
Date: Mon, 29 Apr 2024 15:32:47 +0300
Subject: [PATCH] Change some solver/result attributes to interpolators

---
 doc/conf.py           |   4 +
 tests/test_basic.py   |  22 ++--
 tests/test_solver.py  |   4 +-
 usadelndsoc/solver.py | 261 ++++++++++++++++++++++++++++++++++--------
 4 files changed, 231 insertions(+), 60 deletions(-)

diff --git a/doc/conf.py b/doc/conf.py
index 2dc91c3..a38e42c 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -9,3 +9,7 @@ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
 html_theme = 'nature'
 
 autodoc_member_order = "groupwise"
+
+autodoc_default_options = {
+    'special-members': '__init__, __call__'
+}
diff --git a/tests/test_basic.py b/tests/test_basic.py
index a5a1c8a..117ff57 100644
--- a/tests/test_basic.py
+++ b/tests/test_basic.py
@@ -38,7 +38,7 @@ def test_solve_dos():
     s = basic_setup()
     E = np.linspace(-3, 3, 101) + 0.05j
     res = s.solve_many(omega=-1j * E)
-    g = res.G[:, 10, 2, 0, 0]
+    g = res.G[10, 2, :, 0, 0]
     g_an = -1j * E / np.sqrt(1 - E**2)
     assert_allclose(g, g_an, rtol=1e-4)
 
@@ -59,10 +59,10 @@ def test_result_S_J0(omega, h, terminals):
 
     vol = s.Lx * s.Ly / (s.shape[0] * s.shape[1]) * 4j
 
-    assert_allclose(S[:, :, :2, :2], vol * res.G)
-    assert_allclose(S[:, :, :2, 2:], vol * res.F)
-    assert_allclose(S[:, :, 2:, :2], vol * res.Fc)
-    assert_allclose(S[:, :, 2:, 2:], vol * res.Gc)
+    assert_allclose(S[:, :, :2, :2], vol * res.G.A)
+    assert_allclose(S[:, :, :2, 2:], vol * res.F.A)
+    assert_allclose(S[:, :, 2:, :2], vol * res.Fc.A)
+    assert_allclose(S[:, :, 2:, 2:], vol * res.Gc.A)
 
     # Check vs. analytic
     print(omega, h)
@@ -73,15 +73,15 @@ def test_result_S_J0(omega, h, terminals):
     Fcan = np.diag([1 / np.sqrt(wp**2 + 1), 1 / np.sqrt(wm**2 + 1)])
     Gcan = -np.diag([wp / np.sqrt(wp**2 + 1), wm / np.sqrt(wm**2 + 1)])
 
-    shp = res.G.shape
+    shp = res.G.A.shape
 
-    assert_allclose(res.G, np.broadcast_to(Gan, shp), rtol=1e-5)
-    assert_allclose(res.F, np.broadcast_to(Fan, shp), rtol=1e-5)
-    assert_allclose(res.Fc, np.broadcast_to(Fcan, shp), rtol=1e-5)
-    assert_allclose(res.Gc, np.broadcast_to(Gcan, shp), rtol=1e-5)
+    assert_allclose(res.G.A, np.broadcast_to(Gan, shp), rtol=1e-5)
+    assert_allclose(res.F.A, np.broadcast_to(Fan, shp), rtol=1e-5)
+    assert_allclose(res.Fc.A, np.broadcast_to(Fcan, shp), rtol=1e-5)
+    assert_allclose(res.Gc.A, np.broadcast_to(Gcan, shp), rtol=1e-5)
 
     # There are no gradients, hence no currents either
-    J = res.J
+    J = res.J.A
     assert_allclose(J, 0, atol=1e-5)
 
 
diff --git a/tests/test_solver.py b/tests/test_solver.py
index 52d00dc..3e258d3 100644
--- a/tests/test_solver.py
+++ b/tests/test_solver.py
@@ -341,8 +341,8 @@ def test_soc_analytic(T, n, alpha_soc):
 
         assert_allclose(Jy, Jy_an, rtol=1e-2)
 
-        Fs.append(res0.F.copy())
-        Gs.append(res0.G.copy())
+        Fs.append(res0.F.A.copy())
+        Gs.append(res0.G.A.copy())
         Phis.append(res0.Phi.copy())
         Js.append(J)
         Jans.append(Jy_an)
diff --git a/usadelndsoc/solver.py b/usadelndsoc/solver.py
index 664b909..a667c3d 100644
--- a/usadelndsoc/solver.py
+++ b/usadelndsoc/solver.py
@@ -56,6 +56,9 @@ __all__ = [
     "S_z",
     "S_0",
     "interpolate_J",
+    "interpolate_S",
+    "Density",
+    "Current",
 ]
 
 _log = logging.getLogger(__name__)
@@ -110,13 +113,26 @@ def _core_property(name, doc=None):
     def fset(self, value):
         setattr(self._core, name, value)
 
-    return property(fget, fset, doc=doc or getattr(Core, name).__doc__)
+    fget.__name__ = name
+    fset.__name__ = f"set_{name}"
+    doc = doc or getattr(Core, name).__doc__
+    return property(fget, fset, doc=doc)
 
 
-def _array_property(name, doc=None):
+def _array_property(name, doc=None, cls=None):
     def fget(self):
-        return getattr(self, name)
-
+        v = getattr(self, name)
+        if cls is not None:
+            return cls(self, v)
+        else:
+            return v
+
+    fget.__name__ = name
+    if doc is None:
+        doc = ""
+    doc = textwrap.dedent(doc)
+    if cls is not None:
+        doc += f"\n\n:Type: :any:`{cls.__name__}`"
     return property(fget, doc=doc)
 
 
@@ -135,6 +151,94 @@ def _norm(z):
         raise ValueError("Invalid data type")
 
 
+class Density:
+    """
+    Density interpolator.
+    """
+
+    def __init__(self, parent, data):
+        self.__parent = parent
+        self.__data = data
+
+    def __getitem__(self, ix):
+        return self.__data[ix]
+
+    @property
+    def A(self):
+        """Raw data array. Shape (nx, ny, ...)"""
+        return self.__data
+
+    def __call__(self, x, y, method="bilinear"):
+        """
+        Interpolate density at given coordinates.
+
+        Parameters
+        ----------
+        x : array, shape (...)
+            x-coordinates of the points to interpolate at
+        y : array, shape x.shape
+            y-coordinates of the points to interpolate at
+        method : {'bilinear'}
+            Interpolation method
+
+        Returns
+        -------
+        S : array, shape x.shape + (...)
+            Interpolated density.
+
+        See Also
+        --------
+        interpolate_S
+        """
+        return interpolate_S(
+            self.__parent.x, self.__parent.y, self.__data, x, y, method
+        )
+
+
+class Current:
+    """
+    Current interpolator.
+    """
+
+    def __init__(self, parent, data):
+        self.__parent = parent
+        self.__data = data
+
+    def __getitem__(self, ix):
+        return self.__data[ix]
+
+    @property
+    def A(self):
+        """Raw data array. Shape (nx, ny, 4(direction), ...)"""
+        return self.__data
+
+    def __call__(self, x, y, method="bilinear"):
+        """
+        Interpolate current at given coordinates.
+
+        Parameters
+        ----------
+        x : array, shape (...)
+            x-coordinates of the points to interpolate at
+        y : array, shape x.shape
+            y-coordinates of the points to interpolate at
+        method : {'bilinear'}
+            Interpolation method
+
+        Returns
+        -------
+        J : array, shape x.shape + (2, ...)
+            Interpolated current x and y components.
+
+        See Also
+        --------
+        interpolate_J
+        """
+        return interpolate_S(
+            self.__parent.x, self.__parent.y, self.__data, x, y, method
+        )
+
+
 class Solver:
     __global_id = 0
 
@@ -175,6 +279,7 @@ class Solver:
     Phi = _array_property(
         "_Phi",
         doc=r"""Current saddle-point solution in Riccati parametrization. Array of complex, shape (nx, ny, 2, 2, 2).""",
+        cls=Density,
     )
     J_tot = _array_property(
         "_J_tot",
@@ -182,12 +287,14 @@ class Solver:
         Matsubara-summed matrix current between cells, :math:`J = -i \pi T\sum_{\omega_n} \mathcal{J}(\omega_n)`.
         Array of complex, shape (nx, ny, 4(direction), 4, 4).
         """,
+        cls=Current,
     )
     S_tot = _array_property(
         "_S_tot",
         doc=r"""Matsubara-summed density, :math:`S = \pi T \sum_{\omega_n} [Q(\omega_n) - \tau_3]`.
         Array of complex, shape (nx, ny, 4, 4).
         """,
+        cls=Density,
     )
 
     def reset(self):
@@ -647,7 +754,7 @@ class Solver:
         omega0 = np.asarray(omega)
         omega = omega0.ravel()
 
-        Phi = np.zeros(omega.shape + self.Phi.shape, dtype=self.Phi.dtype)
+        Phi = np.zeros(omega.shape + self.Phi.A.shape, dtype=self.Phi.A.dtype)
 
         js = np.lexsort((-abs(omega), np.sign(omega.real)))
 
@@ -723,7 +830,7 @@ class Solver:
             mask = self.mask == MASK_NONE
             self.Delta[mask] = 2 * T_c0[mask, None, None] * np.eye(2)
         res = _self_consistency(self, T, T_c0, **kw)
-        self.Delta[...], self.J_tot[...], self.S_tot[...], cx, success = res
+        self.Delta[...], self._J_tot[...], self._S_tot[...], cx, success = res
         return self.Delta, self.J_tot, self.S_tot, cx, success
 
 
@@ -751,72 +858,113 @@ class Result:
         self.Phi = Phi
         self.mask = parent.mask
 
+    def _density(self, v):
+        if np.asarray(self.omega).ndim == 0:
+            r = v
+        else:
+            r = np.moveaxis(v, 0, 2)
+        return Density(self, r)
+
+    def _current(self, v):
+        if np.asarray(self.omega).ndim == 0:
+            r = v
+        else:
+            r = np.moveaxis(v, 0, 3)
+        return Current(self, r)
+
     @property
     def Delta(self):
         """
-        Order parameter. Shape (nx, ny, 2, 2).
+        Order parameter.
+
+        :Type: :any:`Density`
+        :Shape: (nx, ny, 2, 2)
         """
-        return _DeltaArray(self.Omega)
+        return Density(self, _DeltaArray(self.Omega).A)
 
     @property
     def G(self):
         """
-        Green function Nambu 11-block. Shape ([nw,] nx, ny, 2, 2).
+        Green function Nambu 11-block.
+
+        :Type: :any:`Density`
+        :Shape: (nx, ny, [nw,] 2, 2)
         """
         g = self.Phi[..., 0, :, :]
         gt = self.Phi[..., 1, :, :]
         I = np.eye(2)
         s = _bcast_left(np.sign(self.omega), g.shape)
         with np.errstate(invalid="ignore"):
-            return s * np.linalg.solve(I + g @ gt, I - g @ gt)
+            r = s * np.linalg.solve(I + g @ gt, I - g @ gt)
+        return self._density(r)
 
     @property
     def Gc(self):
         """
-        Green function Nambu 22-block. Shape ([nw,] nx, ny, 2, 2).
+        Green function Nambu 22-block.
+
+        :Type: :any:`Density`
+        :Shape: (nx, ny, [nw,] 2, 2)
         """
         g = self.Phi[..., 0, :, :]
         gt = self.Phi[..., 1, :, :]
         I = np.eye(2)
         s = _bcast_left(np.sign(self.omega), g.shape)
         with np.errstate(invalid="ignore"):
-            return -s * np.linalg.solve(I + gt @ g, I - gt @ g)
+            r = -s * np.linalg.solve(I + gt @ g, I - gt @ g)
+        return self._density(r)
 
     @property
     def F(self):
         """
-        Green function Nambu 12-block. Shape ([nw,] nx, ny, 2, 2).
+        Green function Nambu 12-block.
+
+        :Type: :any:`Density`
+        :Shape: (nx, ny, [nw,] 2, 2)
         """
         g = self.Phi[..., 0, :, :]
         gt = self.Phi[..., 1, :, :]
         I = np.eye(2)
         s = _bcast_left(np.sign(self.omega), g.shape)
         with np.errstate(invalid="ignore"):
-            return 2 * s * np.linalg.solve(I + g @ gt, g)
+            r = 2 * s * np.linalg.solve(I + g @ gt, g)
+        return self._density(r)
 
     @property
     def Fc(self):
         """
-        Green function Nambu 21-block. Shape ([nw,] nx, ny, 2, 2).
+        Green function Nambu 21-block.
+
+        :Type: :any:`Density`
+        :Shape: (nx, ny, [nw,] 2, 2)
         """
         g = self.Phi[..., 0, :, :]
         gt = self.Phi[..., 1, :, :]
         I = np.eye(2)
         s = _bcast_left(np.sign(self.omega), g.shape)
         with np.errstate(invalid="ignore"):
-            return 2 * s * np.linalg.solve(I + gt @ g, gt)
+            r = 2 * s * np.linalg.solve(I + gt @ g, gt)
+        return self._density(r)
 
     @property
     def Q(self):
         """
-        Full Green function / Q-matrix. Shape ([nw,] nx, ny, 4, 4).
+        Full Green function / Q-matrix.
+
+        :Type: :any:`Density`
+        :Shape: (nx, ny, [nw,] 4, 4)
         """
-        r = np.zeros(self.shape + (4, 4), dtype=complex)
-        r[..., :2, :2] = self.G
-        r[..., 2:, 2:] = self.Gc
-        r[..., :2, 2:] = self.F
-        r[..., 2:, :2] = self.Fc
-        return r
+        w = np.asarray(self.omega)
+        if w.ndim == 0:
+            r = np.zeros(self.shape + (4, 4), dtype=complex)
+        else:
+            r = np.zeros(self.shape + w.shape + (4, 4), dtype=complex)
+
+        r[..., :2, :2] = self.G.A
+        r[..., 2:, 2:] = self.Gc.A
+        r[..., :2, 2:] = self.F.A
+        r[..., 2:, :2] = self.Fc.A
+        return Density(self, r)
 
     def _get_J(self, omega, Phi):
         old_omega = self._core.omega
@@ -837,19 +985,23 @@ class Result:
     @property
     def S(self):
         """
-        Matrix density, dS/dOmega. Shape ([nw,] nx, ny, 4, 4)
+        Matrix density, dS/dOmega.
+
+        :Type: :any:`Density`
+        :Shape: (nx, ny, [nw,] 4, 4)
         """
         if np.asarray(self.omega).ndim == 0:
-            return self._get_S(self.omega, self.Phi)
+            r = self._get_S(self.omega, self.Phi)
         else:
-            return np.array(
+            r = np.array(
                 [self._get_S(w, P) for w, P in zip(self.omega, self.Phi)], dtype=complex
             )
+        return self._density(r)
 
     @property
     def J(self):
         """
-        Matrix current, dS/dA. Shape ([nw,] nx, ny, 4(direction), 4, 4).
+        Matrix current, dS/dA.
 
         The meaning of the direction index is:
 
@@ -857,37 +1009,52 @@ class Result:
         - UP (1): current exiting cell (x,y) to cell (x,y+1)
         - LEFT (2): current entering cell (x,y) from cell (x-1,y)
         - DOWN (3): current entering cell (x,y) from cell (x,y-1)
+
+        :Type: :any:`Current`
+        :Shape: (nx, ny, 4(direction), [nw,] 4, 4)
         """
         if np.asarray(self.omega).ndim == 0:
-            return self._get_J(self.omega, self.Phi)
+            r = self._get_J(self.omega, self.Phi)
         else:
-            return np.array(
+            r = np.array(
                 [self._get_J(w, P) for w, P in zip(self.omega, self.Phi)], dtype=complex
             )
+        return self._current(r)
 
     @property
     def J_c(self):
         """
-        Spectral charge current. Shape ([nw,] nx, ny, 4(direction)).
+        Spectral charge current.
+
+        Direction index has same meaning as for :any:`J`.
 
-        Direction index has same meaning as for :ref:`J`.
+        :Type: :any:`Current`
+        :Shape: ([nw,] nx, ny, 4(direction))
         """
         t3 = np.diag([1, 1, -1, -1])
-        return tr(self.J @ t3)
+        r = tr(self.J.A @ t3)
+        return Current(self, r)
 
     @property
     def J_s(self):
         """
-        Spectral spin current.  Shape ([nw,] nx, ny, 4(direction), 3(spin-direction)).
+        Spectral spin current.
 
-        The current direction index has same meaning as for :ref:`J`.
+        The current direction index has same meaning as for :any:`J`.
 
         The spin direction index means (x, y, z).
+
+        :Type: :any:`Current`
+        :Shape: (nx, ny, 4(direction), [nw,] 3(spin-direction))
         """
-        jx = tr(self.J @ np.kron(s0, sx))
-        jy = tr(self.J @ np.kron(s0, sy))
-        jz = tr(self.J @ np.kron(s0, sz))
-        return np.array([jx, jy, jz]).transpose(1, 2, 3, 0)
+        jx = tr(self.J.A @ np.kron(s0, sx))
+        jy = tr(self.J.A @ np.kron(s0, sy))
+        jz = tr(self.J.A @ np.kron(s0, sz))
+        if np.asarray(self.omega).ndim == 0:
+            r = np.array([jx, jy, jz]).transpose(1, 2, 3, 0)
+        else:
+            r = np.array([jx, jy, jz]).transpose(1, 2, 3, 4, 0)
+        return Current(self, r)
 
 
 class UmfpackILU(LinearOperator):
@@ -1240,10 +1407,10 @@ def _mp_one(args, solver=None, solver_kw=None):
 
         for wx, ax in zip(w, a):
             res = solver.solve(omega=wx, **solver_kw)
-            r = singlet(solver.Delta) / abs(wx) - singlet(res.F)
+            r = singlet(solver.Delta) / abs(wx) - singlet(res.F.A)
             rtot += (2 * np.pi * ax) * r
-            Jtot += (-2j * np.pi * ax) * res.J
-            Stot += (2 * np.pi * ax) * (res.Q - tau_3)
+            Jtot += (-2j * np.pi * ax) * res.J.A
+            Stot += (2 * np.pi * ax) * (res.Q.A - tau_3)
     finally:
         _log_solve.setLevel(old_level)
 
@@ -1404,7 +1571,7 @@ def cpr(
         state.Deltas = [solver.Delta.copy()]
         state.phis = [0]
         state.dss = [0]
-        state.Js = [solver.J_tot.copy()]
+        state.Js = [solver.J_tot.A.copy()]
 
     phi = state.phis[-1] + ds / 4
 
@@ -1417,7 +1584,7 @@ def cpr(
         )
         state.Deltas.append(solver.Delta.copy())
         state.phis.append(float(phi))
-        state.Js.append(solver.J_tot.copy())
+        state.Js.append(solver.J_tot.A.copy())
 
     for j in range(len(state.phis), max_points):
         if auto_stop and len(state.phis) > 2:
@@ -1469,7 +1636,7 @@ def cpr(
                 return np.asarray(dr)
 
             constraint_x = np.asarray(phi).ravel()
-            Delta, J_tot, cx, success = solver.self_consistency(
+            Delta, J_tot, S_tot, cx, success = solver.self_consistency(
                 T=T,
                 T_c0=T_c0,
                 constraint_x=constraint_x,
@@ -1507,10 +1674,10 @@ def cpr(
         state.dss.append(float(np.squeeze(ds_cur)))
         state.Deltas.append(Delta.copy())
         state.phis.append(float(np.squeeze(phi)))
-        state.Js.append(solver.J_tot.copy())
+        state.Js.append(solver.J_tot.A.copy())
 
         t3 = np.diag([1, 1, -1, -1])
-        J_avg = np.mean(tr(solver.J_tot[:, :, 0].real @ t3))
+        J_avg = np.mean(tr(solver.J_tot.A[:, :, 0].real @ t3))
         _log.info(f"CPR #{j}: phi = {phi}, J = {J_avg}")
 
         if filename is not None:
@@ -1538,7 +1705,7 @@ def interpolate_J(x0, y0, J, x, y, method="bilinear"):
         Cell x-coordinates
     y0 : array, shape (ny,)
         Cell y-coordinates
-    J : array, shape (nx, ny, 4, ...)
+    J : array, shape (nx, ny, 4(direction), ...)
         Matrix current, defined on links between cells.
     x : array
         x-coordinate to interpolate J at.
-- 
GitLab