From aeb99be984185f842585980f7210e897c39d12bc Mon Sep 17 00:00:00 2001
From: Pauli Virtanen <pauli.t.virtanen@jyu.fi>
Date: Wed, 14 Sep 2022 13:32:59 +0300
Subject: [PATCH] solver: better omega ordering + don't recompute jac if too
 slow for now

---
 usadelndsoc/solver.py | 41 +++++++++++++++++++++++++++++++----------
 1 file changed, 31 insertions(+), 10 deletions(-)

diff --git a/usadelndsoc/solver.py b/usadelndsoc/solver.py
index 1b4eebe..c065b1b 100644
--- a/usadelndsoc/solver.py
+++ b/usadelndsoc/solver.py
@@ -9,6 +9,7 @@ import collections
 import hashlib
 import tempfile
 import pickle
+import time
 from scipy.sparse import coo_matrix
 from scipy.sparse.linalg import (
     splu,
@@ -110,6 +111,8 @@ class Solver:
         self._core = Core()
         self.set_shape(nx, ny)
         self._M = None
+        self._t_jac = 0.0
+        self._t_solve = 0.0
 
     def set_shape(self, nx, ny=1):
         self._core.set_shape(nx, ny)
@@ -196,6 +199,10 @@ class Solver:
             Phi = np.zeros_like(Phi)
             M = self._M = None
 
+        if M is None:
+            self._t_jac = 0.0
+            self._t_solve = 0.0
+
         self._core.fix_terminals(Phi)
 
         for j in range(maxiter):
@@ -205,6 +212,8 @@ class Solver:
             elif M is None and preconditioner != "none":
                 # Update preconditioner
                 _log_solve.debug("    Update Jacobian")
+                first_jac = not self._core.hess_computed
+                t0 = time.time()
                 J = eval_jac(Phi)
 
                 _log_solve.debug("    Form preconditioner")
@@ -213,6 +222,8 @@ class Solver:
                 J = None
                 _log_solve.debug("    Preconditioner formed")
                 frozen_M = False
+                self._t_jac = time.time() - t0 if not first_jac else 0.0
+                self._t_solve = 0.0
             else:
                 # Don't update preconditioner
                 frozen_M = True
@@ -254,6 +265,8 @@ class Solver:
             l_inner_m = 15
             l_outer_k = 5
 
+            t0 = time.time()
+
             if solver == "cg":
                 dx, info = cg(
                     A,
@@ -309,7 +322,9 @@ class Solver:
             else:
                 raise ValueError(f"Unknown {solver=}")
 
-            if count >= l_maxiter * l_inner_m // 2:
+            self._t_solve += time.time() - t0
+
+            if count >= l_maxiter * l_inner_m // 2 and self._t_solve > 2 * self._t_jac:
                 M = None
 
             # Update
@@ -529,23 +544,29 @@ class Solver:
         return M
 
     def solve_many(self, omega, **kw):
-        omega = np.asarray(omega)
+        omega0 = np.asarray(omega)
+        omega = omega0.ravel()
 
         Phi = np.zeros(omega.shape + self.Phi.shape, dtype=self.Phi.dtype)
 
+        js = np.lexsort((-abs(omega), np.sign(omega.real)))
+
+        self._t_jac = 0.0
+        self._t_solve = 0.0
         self._M = None
         self._Phi[...] = 0
         self._core.fix_terminals(Phi)
 
-        with np.nditer(omega, ["multi_index"]) as it:
-            for v in it:
-                self.solve(omega=v, **kw)
-                Phi_prev = self._Phi.copy()
-                Phi[it.multi_index + (Ellipsis,)] = self._Phi
-                if np.isnan(self._Phi).any():
-                    self._Phi[...] = Phi_prev
+        for j in js:
+            self.solve(omega=omega[j], **kw)
+            Phi_prev = self._Phi.copy()
+            Phi[j] = self._Phi
+            if np.isnan(self._Phi).any():
+                self._Phi[...] = Phi_prev
+
+        Phi = Phi.reshape(omega0.shape + Phi.shape[1:])
 
-        return Result(self, omega=omega, Phi=Phi)
+        return Result(self, omega=omega0, Phi=Phi)
 
     def self_consistency(self, T, T_c0, continuation=False, **kw):
         nx, ny = self.shape
-- 
GitLab