From aeb99be984185f842585980f7210e897c39d12bc Mon Sep 17 00:00:00 2001 From: Pauli Virtanen <pauli.t.virtanen@jyu.fi> Date: Wed, 14 Sep 2022 13:32:59 +0300 Subject: [PATCH] solver: better omega ordering + don't recompute jac if too slow for now --- usadelndsoc/solver.py | 41 +++++++++++++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/usadelndsoc/solver.py b/usadelndsoc/solver.py index 1b4eebe..c065b1b 100644 --- a/usadelndsoc/solver.py +++ b/usadelndsoc/solver.py @@ -9,6 +9,7 @@ import collections import hashlib import tempfile import pickle +import time from scipy.sparse import coo_matrix from scipy.sparse.linalg import ( splu, @@ -110,6 +111,8 @@ class Solver: self._core = Core() self.set_shape(nx, ny) self._M = None + self._t_jac = 0.0 + self._t_solve = 0.0 def set_shape(self, nx, ny=1): self._core.set_shape(nx, ny) @@ -196,6 +199,10 @@ class Solver: Phi = np.zeros_like(Phi) M = self._M = None + if M is None: + self._t_jac = 0.0 + self._t_solve = 0.0 + self._core.fix_terminals(Phi) for j in range(maxiter): @@ -205,6 +212,8 @@ class Solver: elif M is None and preconditioner != "none": # Update preconditioner _log_solve.debug(" Update Jacobian") + first_jac = not self._core.hess_computed + t0 = time.time() J = eval_jac(Phi) _log_solve.debug(" Form preconditioner") @@ -213,6 +222,8 @@ class Solver: J = None _log_solve.debug(" Preconditioner formed") frozen_M = False + self._t_jac = time.time() - t0 if not first_jac else 0.0 + self._t_solve = 0.0 else: # Don't update preconditioner frozen_M = True @@ -254,6 +265,8 @@ class Solver: l_inner_m = 15 l_outer_k = 5 + t0 = time.time() + if solver == "cg": dx, info = cg( A, @@ -309,7 +322,9 @@ class Solver: else: raise ValueError(f"Unknown {solver=}") - if count >= l_maxiter * l_inner_m // 2: + self._t_solve += time.time() - t0 + + if count >= l_maxiter * l_inner_m // 2 and self._t_solve > 2 * self._t_jac: M = None # Update @@ -529,23 +544,29 @@ class Solver: return M def solve_many(self, omega, **kw): - omega = np.asarray(omega) + omega0 = np.asarray(omega) + omega = omega0.ravel() Phi = np.zeros(omega.shape + self.Phi.shape, dtype=self.Phi.dtype) + js = np.lexsort((-abs(omega), np.sign(omega.real))) + + self._t_jac = 0.0 + self._t_solve = 0.0 self._M = None self._Phi[...] = 0 self._core.fix_terminals(Phi) - with np.nditer(omega, ["multi_index"]) as it: - for v in it: - self.solve(omega=v, **kw) - Phi_prev = self._Phi.copy() - Phi[it.multi_index + (Ellipsis,)] = self._Phi - if np.isnan(self._Phi).any(): - self._Phi[...] = Phi_prev + for j in js: + self.solve(omega=omega[j], **kw) + Phi_prev = self._Phi.copy() + Phi[j] = self._Phi + if np.isnan(self._Phi).any(): + self._Phi[...] = Phi_prev + + Phi = Phi.reshape(omega0.shape + Phi.shape[1:]) - return Result(self, omega=omega, Phi=Phi) + return Result(self, omega=omega0, Phi=Phi) def self_consistency(self, T, T_c0, continuation=False, **kw): nx, ny = self.shape -- GitLab