From 1fb536546ab702b56c894812ff4d77c7f03035f8 Mon Sep 17 00:00:00 2001
From: Pauli Virtanen <pauli.t.virtanen@jyu.fi>
Date: Wed, 21 Sep 2022 14:28:47 +0300
Subject: [PATCH] solver: set omega when evaluating currents in Result

---
 usadelndsoc/solver.py | 38 +++++++++++++++++++++++++++++---------
 1 file changed, 29 insertions(+), 9 deletions(-)

diff --git a/usadelndsoc/solver.py b/usadelndsoc/solver.py
index 820f4bf..2d75d1e 100644
--- a/usadelndsoc/solver.py
+++ b/usadelndsoc/solver.py
@@ -498,8 +498,8 @@ class Solver:
         _log_solve.info(f"omega = {omega}")
         nx, ny = self.shape
         Phi00 = np.zeros_like(self._Phi)
-        self._core.fix_terminals(Phi00)
         self._core.omega = omega
+        self._core.fix_terminals(Phi00)
         self._Phi = self._solve(
             self._eval_rhs,
             self._eval_jac_mul,
@@ -651,15 +651,39 @@ class Result:
             G = self.G
             return s * np.linalg.solve(I + gt @ g, I - gt @ g)
 
-    def _get_J(self, Phi):
-        return self._core.grad_A(Phi).transpose(0, 1, 2, 4, 3) * (1 / 16)
+    def _get_J(self, omega, Phi):
+        old_omega = self._core.omega
+        self._core.omega = omega
+        try:
+            return self._core.grad_A(Phi).transpose(0, 1, 2, 4, 3) * (1 / 16)
+        finally:
+            self._core.omega = old_omega
+
+    def _get_S(self, omega, Phi):
+        old_omega = self._core.omega
+        self._core.omega = omega
+        try:
+            return self._core.grad_Omega(Phi).transpose(0, 1, 3, 2)
+        finally:
+            self._core.omega = old_omega
+
+    @property
+    def S(self):
+        if np.asarray(self.omega).ndim == 0:
+            return self._get_J(self.omega, self.Phi)
+        else:
+            return np.array(
+                [self._get_S(w, P) for w, P in zip(self.omega, self.Phi)], dtype=complex
+            )
 
     @property
     def J(self):
         if np.asarray(self.omega).ndim == 0:
-            return self._get_J(self.Phi)
+            return self._get_J(self.omega, self.Phi)
         else:
-            return np.array([self._get_J(P) for P in self.Phi])
+            return np.array(
+                [self._get_J(w, P) for w, P in zip(self.omega, self.Phi)], dtype=complex
+            )
 
     @property
     def J_c(self):
@@ -673,10 +697,6 @@ class Result:
         jz = tr(self.J @ np.kron(s0, sz))
         return np.array([jx, jy, jz]).transpose(1, 2, 3, 0)
 
-    @property
-    def S(self):
-        return self._core.grad_Omega(self.Phi).transpose(0, 1, 3, 2)
-
 
 class UmfpackILU(LinearOperator):
     def __init__(self, M, drop_tol=1e-4):
-- 
GitLab