Skip to content
Snippets Groups Projects
Commit aeb99be9 authored by patavirt's avatar patavirt
Browse files

solver: better omega ordering + don't recompute jac if too slow for now

parent 9ebe4053
No related branches found
No related tags found
No related merge requests found
......@@ -9,6 +9,7 @@ import collections
import hashlib
import tempfile
import pickle
import time
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import (
splu,
......@@ -110,6 +111,8 @@ class Solver:
self._core = Core()
self.set_shape(nx, ny)
self._M = None
self._t_jac = 0.0
self._t_solve = 0.0
def set_shape(self, nx, ny=1):
self._core.set_shape(nx, ny)
......@@ -196,6 +199,10 @@ class Solver:
Phi = np.zeros_like(Phi)
M = self._M = None
if M is None:
self._t_jac = 0.0
self._t_solve = 0.0
self._core.fix_terminals(Phi)
for j in range(maxiter):
......@@ -205,6 +212,8 @@ class Solver:
elif M is None and preconditioner != "none":
# Update preconditioner
_log_solve.debug(" Update Jacobian")
first_jac = not self._core.hess_computed
t0 = time.time()
J = eval_jac(Phi)
_log_solve.debug(" Form preconditioner")
......@@ -213,6 +222,8 @@ class Solver:
J = None
_log_solve.debug(" Preconditioner formed")
frozen_M = False
self._t_jac = time.time() - t0 if not first_jac else 0.0
self._t_solve = 0.0
else:
# Don't update preconditioner
frozen_M = True
......@@ -254,6 +265,8 @@ class Solver:
l_inner_m = 15
l_outer_k = 5
t0 = time.time()
if solver == "cg":
dx, info = cg(
A,
......@@ -309,7 +322,9 @@ class Solver:
else:
raise ValueError(f"Unknown {solver=}")
if count >= l_maxiter * l_inner_m // 2:
self._t_solve += time.time() - t0
if count >= l_maxiter * l_inner_m // 2 and self._t_solve > 2 * self._t_jac:
M = None
# Update
......@@ -529,23 +544,29 @@ class Solver:
return M
def solve_many(self, omega, **kw):
omega = np.asarray(omega)
omega0 = np.asarray(omega)
omega = omega0.ravel()
Phi = np.zeros(omega.shape + self.Phi.shape, dtype=self.Phi.dtype)
js = np.lexsort((-abs(omega), np.sign(omega.real)))
self._t_jac = 0.0
self._t_solve = 0.0
self._M = None
self._Phi[...] = 0
self._core.fix_terminals(Phi)
with np.nditer(omega, ["multi_index"]) as it:
for v in it:
self.solve(omega=v, **kw)
Phi_prev = self._Phi.copy()
Phi[it.multi_index + (Ellipsis,)] = self._Phi
if np.isnan(self._Phi).any():
self._Phi[...] = Phi_prev
for j in js:
self.solve(omega=omega[j], **kw)
Phi_prev = self._Phi.copy()
Phi[j] = self._Phi
if np.isnan(self._Phi).any():
self._Phi[...] = Phi_prev
Phi = Phi.reshape(omega0.shape + Phi.shape[1:])
return Result(self, omega=omega, Phi=Phi)
return Result(self, omega=omega0, Phi=Phi)
def self_consistency(self, T, T_c0, continuation=False, **kw):
nx, ny = self.shape
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment