# -*- coding:utf-8; eval: (blacken-mode) -*-
import os
import io
import logging
import numpy as np
import concurrent.futures
import functools
import collections
import hashlib
import tempfile
import pickle
import time
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import (
    splu,
    spilu,
    gcrotmk,
    LinearOperator,
    lgmres,
    gmres,
    cg,
    tfqmr,
)
from scipy.special import zeta
from scipy.linalg import blas, signm

try:
    from scikits import umfpack
except ImportError:
    umfpack = None

try:
    import pyamg
except ImportError:
    pyamg = None

from .core import Core, MASK_NONE, MASK_TERMINAL, MASK_VACUUM, _DeltaArray
from .matsubara import get_matsubara_sum

# Names exported by ``from <this module> import *``.  Besides the solver
# entry points this re-exports the ``MASK_*`` constants imported from
# ``.core`` and the module-level direction / 2x2 matrix constants.
# NOTE(review): ``cpr`` and ``tr`` are expected to be defined further down in
# this module -- confirm they exist, otherwise a star-import fails.
__all__ = [
    "Solver",
    "Result",
    "cpr",
    "MASK_NONE",
    "MASK_TERMINAL",
    "MASK_VACUUM",
    "tr",
    "LEFT",
    "UP",
    "RIGHT",
    "DOWN",
    "S_x",
    "S_y",
    "S_z",
    "S_0",
]

# Module logger, plus a child logger ("<module>.solve") used for the
# per-iteration output of the nonlinear solver.
_log = logging.getLogger(__name__)
_log_solve = logging.getLogger(f"{__name__}.solve")


# 2x2 identity and the three Pauli matrices.
S_0 = np.array([[1, 0], [0, 1]])
S_x = np.array([[0, 1], [1, 0]])
S_y = np.array([[0, -1j], [1j, 0]])
S_z = np.array([[1, 0], [0, -1]])

# Integer direction codes, numbered counter-clockwise starting from RIGHT.
RIGHT, UP, LEFT, DOWN = range(4)


def _core_property(name):
    def fget(self):
        return getattr(self._core, name)

    def fset(self, value):
        setattr(self._core, name, value)

    return property(fget, fset)


def _array_property(name):
    def fget(self):
        return getattr(self, name)

    return property(fget)


def _norm(z):
    if z.dtype.char == "D":
        if z.flags.f_contiguous:
            return blas.dznrm2(z.ravel("F"))
        else:
            return blas.dznrm2(z.ravel())
    elif z.dtype.char == "d":
        if z.flags.f_contiguous:
            return blas.dnrm2(z.ravel("F"))
        else:
            return blas.dnrm2(z.ravel())
    else:
        raise ValueError("Invalid data type")


class Solver:
    """
    Newton-Krylov driver built on top of a :class:`Core` instance.

    The object owns a ``Core`` (``self._core``) which supplies the value of
    the functional (``eval``), its gradient (``grad``), its Hessian
    (``hess``) and Hessian-vector products (``hess_mul``).  The solver keeps
    the current iterate in ``Phi`` and an accumulated ``J_tot`` array, plus
    a cached preconditioner (``_M``) and timing statistics used to decide
    when that preconditioner is refreshed.

    Most public attributes (``shape``, ``omega``, ``Delta``, ``mask``, ...)
    are read/write proxies to the attribute of the same name on the core.

    Each instance carries an ``_id`` that is unique within the process and
    is renewed by :meth:`reset`.
    """

    __global_id = 0

    def __init__(self, nx=1, ny=1):
        """Create a solver on an ``nx`` x ``ny`` grid with a fresh ``Core``."""
        Solver.__global_id += 1
        self._id = Solver.__global_id
        self._core = Core()
        self.set_shape(nx, ny)
        self._M = None  # cached preconditioner (LinearOperator) or None
        self._t_jac = 0.0  # time spent forming the last preconditioner
        self._t_solve = 0.0  # accumulated "excess" linear-solve time
        self._t_solve_min = np.inf  # fastest linear solve seen so far

    def set_shape(self, nx, ny=1):
        """Resize the core to ``(nx, ny)`` and zero ``Phi`` and ``J_tot``."""
        self._core.set_shape(nx, ny)
        # np.complex128 rather than the np.complex_ alias (removed in NumPy 2.0)
        self._Phi = np.zeros((nx, ny, 2, 2, 2), dtype=np.complex128)
        self._J_tot = np.zeros((nx, ny, 4, 4, 4), dtype=np.complex128)

    shape = _core_property("shape")
    Lx = _core_property("Lx")
    Ly = _core_property("Ly")
    x = _core_property("x")
    y = _core_property("y")
    mask = _core_property("mask")
    Omega = _core_property("Omega")
    omega = _core_property("omega")
    eta = _core_property("eta")
    D = _core_property("D")
    Delta = _core_property("Delta")
    U = _core_property("U")
    Ux = _core_property("Ux")
    Uy = _core_property("Uy")

    Phi = _array_property("_Phi")
    J_tot = _array_property("_J_tot")

    def reset(self):
        """Assign a new unique ``_id`` and reset the underlying core."""
        Solver.__global_id += 1
        self._id = Solver.__global_id
        self._core.reset()

    def _solve(
        self,
        eval_rhs,
        eval_jac_mul,
        eval_jac,
        check_step,
        check_final,
        Phi0,
        Phi00,
        tol=1e-6,
        maxiter=30,
        preconditioner=None,
        solver=None,
        restart=0,
        finite_difference=True,
    ):
        """
        Solve ``eval_rhs(Phi) = 0`` by a preconditioned Newton-Krylov iteration.

        Parameters
        ----------
        eval_rhs : callable(Phi) -> real 1-D array
            Residual, as a real view of the complex gradient.
        eval_jac_mul : callable(Phi, v) -> real 1-D array
            Jacobian-vector product; only used if `finite_difference` is
            false.
        eval_jac : callable(Phi) -> sparse matrix
            Full real Jacobian, used only to build the preconditioner.
        check_step : callable(Phi, dx) -> bool
            Whether the proposed update `dx` is acceptable.
        check_final : callable(Phi) -> bool
            Whether a converged `Phi` is an acceptable solution.
        Phi0 : complex array
            Initial guess; **modified in place** and (usually) returned.
        Phi00 : complex array
            Reference state; the residual norm at this state (after fixing
            terminals) normalises the convergence test.
        tol : float
            Relative residual tolerance.
        maxiter : int
            Maximum number of outer (Newton) iterations.
        preconditioner : str or None
            Passed to :meth:`_get_preconditioner`; ``"none"`` disables
            preconditioning.
        solver : {"lgmres", "cg", "tfqmr", "gcrotmk", "gmres"} or None
            Krylov method for the linearised problem (default ``"lgmres"``).
        restart : int
            ``0`` permits the "restart from Ansatz" fallback (zero initial
            guess, optionally preceded by continuation in ``omega``); the
            recursive calls made by that fallback pass ``1`` so that it is
            not entered again.
        finite_difference : bool
            Approximate Jacobian-vector products by a forward difference of
            `eval_rhs` instead of calling `eval_jac_mul`.

        Returns
        -------
        Phi : complex array
            The solution, or an array filled with NaN on failure.

        Raises
        ------
        ValueError
            For an unknown `solver` or `preconditioner` name.
        """
        # The complex unknowns are handled as twice as many real ones.
        A_size = 2 * Phi0.size
        A_shape = (A_size, A_size)
        A_dtype = np.dtype(float)

        Phi00 = Phi00.copy()
        self._core.fix_terminals(Phi00)

        F0 = eval_rhs(Phi00)
        F0_norm = _norm(F0)

        if F0_norm == 0:
            # Avoid a zero reference norm in the relative convergence test.
            F0_norm = tol

        # Krylov vectors recycled by lgmres between outer iterations.
        outer_v = []

        # Reuse the preconditioner from a previous call if it still fits.
        if self._M is not None and self._M.shape == A_shape:
            M = self._M
        else:
            M = None
        self._M = None

        Phi = Phi0

        update_skipped = False
        tiny_count = 0
        bad_count = 0
        skip_precond = 0

        if solver is None:
            solver = "lgmres"

        if A_size > 40000:
            _log_solve.debug(
                f"    Problem too large ({A_size}): cannot compute hessian"
            )
            preconditioner = "none"

        if not np.isfinite(Phi).all():
            # A previous failure left NaN/inf behind: start from scratch.
            Phi = np.zeros_like(Phi)
            M = self._M = None

        if M is None:
            self._t_jac = 0.0
            self._t_solve = 0.0
            self._t_solve_min = np.inf

        self._core.fix_terminals(Phi)

        # NOTE(review): the Krylov solvers below are called with ``tol=``;
        # recent SciPy releases renamed this keyword to ``rtol`` -- confirm
        # against the SciPy version this project targets.
        for j in range(maxiter):
            if skip_precond > 0:
                M = None
                skip_precond -= 1
            elif M is None and preconditioner != "none":
                # Update preconditioner
                _log_solve.debug("    Update Jacobian")
                first_jac = not self._core.hess_computed
                t0 = time.time()
                J = eval_jac(Phi)

                _log_solve.debug("    Form preconditioner")
                M = None
                M = self._get_preconditioner(J, preconditioner=preconditioner)
                J = None  # release the Jacobian before the linear solve
                _log_solve.debug("    Preconditioner formed")
                frozen_M = False
                # The very first Jacobian evaluation is not timed.
                self._t_jac = time.time() - t0 if not first_jac else 0.0
                self._t_solve = 0.0
                self._t_solve_min = np.inf
            else:
                # Don't update preconditioner
                frozen_M = True

            F = eval_rhs(Phi)
            F_norm = _norm(F)
            S_value = self._core.eval(Phi).real
            _log_solve.info(
                f"    #{j}: residual = {F_norm / F0_norm}, Re S = {S_value}"
            )

            if F_norm <= tol * F0_norm:
                if check_final(Phi):
                    break
                else:
                    # Converged to an unacceptable solution: perturb the
                    # iterate and keep iterating.
                    Phi[:, :, 0] *= -0.5
                    Phi[:, :, 1] *= 0.5
                    self._core.fix_terminals(Phi)
                    _log_solve.info("    Force restart: wrong branch")

            # Solve linearized problem
            count = 0

            Phi_norm = _norm(Phi)

            def op(v, rdiff=1e-4):
                # Jacobian-vector product; `count` tracks the matvec number.
                nonlocal count
                count += 1
                v = v.view(np.complex128).reshape(Phi.shape)
                if finite_difference:
                    scale = rdiff * max(1, Phi_norm) / max(1, _norm(v))
                    return (eval_rhs(Phi + scale * v) - F) / scale
                else:
                    return eval_jac_mul(Phi, v)

            A = LinearOperator(matvec=op, shape=A_shape, dtype=A_dtype)

            l_maxiter = 2
            l_inner_m = 15
            l_outer_k = 5

            t0 = time.time()

            if solver == "cg":
                dx, info = cg(
                    A,
                    -F,
                    M=M,
                    tol=1e-3,
                    atol=0,
                    maxiter=l_maxiter,
                )
            elif solver == "tfqmr":
                dx, info = tfqmr(
                    A,
                    -F,
                    M=M,
                    tol=1e-3,
                    atol=0,
                    maxiter=l_maxiter,
                )
            elif solver == "lgmres":
                dx, info = lgmres(
                    A,
                    -F,
                    M=M,
                    tol=1e-3,
                    atol=0,
                    maxiter=l_maxiter,
                    inner_m=l_inner_m,
                    outer_v=outer_v,
                    outer_k=l_outer_k,
                    store_outer_Av=False,
                    prepend_outer_v=True,
                )
            elif solver == "gcrotmk":
                dx, info = gcrotmk(
                    A,
                    -F,
                    M=M,
                    tol=1e-3,
                    atol=0,
                    maxiter=l_maxiter,
                    m=l_inner_m,
                )
            elif solver == "gmres":
                dx, info = gmres(
                    A,
                    -F,
                    M=M,
                    tol=1e-3,
                    atol=0,
                    maxiter=l_maxiter,
                    restart=l_inner_m,
                )
            else:
                raise ValueError(f"Unknown {solver=}")

            # Accumulate how much slower than the best solve we have been;
            # once that exceeds the cost of a Jacobian (and the Krylov
            # solver needed many matvecs), refresh the preconditioner.
            dt = time.time() - t0
            self._t_solve_min = min(dt, self._t_solve_min)
            self._t_solve += dt - self._t_solve_min

            if count >= l_maxiter * l_inner_m // 2 and self._t_solve > self._t_jac:
                M = None

            # Update
            dx = dx.view(np.complex128).reshape(Phi.shape)
            s = abs(dx).max()

            _log_solve.debug(f"    {count} matvec, step {s}, norm {Phi_norm}")

            do_restart = False

            if not check_step(Phi, dx):
                M = None
                if frozen_M and not update_skipped:
                    # The stale preconditioner may be to blame: retry once
                    # with a fresh Jacobian before anything more drastic.
                    _log_solve.info("    Bad step: force Jacobian update")
                    update_skipped = True
                    continue
                elif restart == 0:
                    do_restart = True
                else:
                    # Try successively shorter steps (1/5, 1/25, ...).
                    for k in range(1, 5):
                        if check_step(Phi, dx / 5**k):
                            dx /= 5**k
                            s /= 5**k
                            _log_solve.debug(f"    bad step, reduced by {5**k}")
                            break
                    else:
                        bad_count += 1
                        if bad_count > 3:
                            _log_solve.debug("    bad step, giving up")
                            Phi[...] = np.nan
                            break
                        else:
                            _log_solve.debug("    bad step, nothing helped")
            else:
                bad_count = 0

            if s >= 0.75:
                # Limit the largest component of the update to 0.75.
                M = None
                dx *= 0.75 / s
                tiny_count = 0
            elif s == 0:
                if restart == 0:
                    do_restart = True
                else:
                    skip_precond = tiny_count + 2
            elif s < 1e-10:
                tiny_count += 1
                if tiny_count > 5:
                    if restart == 0:
                        do_restart = True
                    else:
                        skip_precond = tiny_count + 2

            if do_restart:
                _log_solve.info("  Force restart from Ansatz")
                Phi[...] = 0
                self._core.fix_terminals(Phi)
                old_omega = self._core.omega
                abs_re_w = abs(np.real(old_omega))
                if 0 < abs_re_w < 1:
                    # Continue from larger |Re omega|
                    steps = [0.75**k for k in range(50) if 0.75**k > abs_re_w / 5]
                    try:
                        for domega in steps:
                            _log_solve.debug(f"  Continuation step {domega}")
                            self._core.omega = old_omega + domega * (
                                1 if np.real(old_omega) >= 0 else -1
                            )
                            self._M = M = None
                            self._Phi[...] = self._solve(
                                eval_rhs,
                                eval_jac_mul,
                                eval_jac,
                                check_step,
                                check_final,
                                Phi,
                                Phi00,
                                tol=tol,
                                maxiter=maxiter,
                                preconditioner=preconditioner,
                                restart=1,
                                solver=solver,
                                finite_difference=finite_difference,
                            )
                            if np.isnan(self._Phi).any():
                                break
                    finally:
                        # Always restore the frequency requested by the caller.
                        self._core.omega = old_omega
                    if np.isnan(self._Phi).any():
                        break
                    _log_solve.debug("    continuation step 0")
                self._M = M = None
                return self._solve(
                    eval_rhs,
                    eval_jac_mul,
                    eval_jac,
                    check_step,
                    check_final,
                    Phi,
                    Phi00,
                    tol=tol,
                    maxiter=maxiter,
                    preconditioner=preconditioner,
                    restart=1,
                    solver=solver,
                    finite_difference=finite_difference,
                )

            Phi += dx
            update_skipped = False
        else:
            # Fail to solve
            Phi[...] = np.nan

        if np.isnan(Phi).any():
            _log_solve.error("Failed to solve")
            Phi[...] = np.nan
            M = None

        # Keep the preconditioner for the next call.
        self._M = M

        return Phi

    def _check_step(self, Phi, dx):
        """Return whether ``Phi + dx`` passes :meth:`_check_final`."""
        return self._check_final(Phi + dx)

    def _check_final(self, Phi):
        """
        Return whether `Phi` lies on the accepted solution branch.

        Forms ``G = (1 + g gt)^-1 (1 - g gt)`` from ``g = Phi[:, :, 0]`` and
        ``gt = Phi[:, :, 1]`` and rejects the state if the real part of
        either diagonal element of ``G`` drops below ``-0.05`` anywhere.
        """
        # Check branch
        g = Phi[:, :, 0]
        gt = Phi[:, :, 1]
        I = np.eye(2)

        G = np.linalg.solve(I + g @ gt, I - g @ gt)
        if G[:, :, 0, 0].real.min() < -0.05 or G[:, :, 1, 1].real.min() < -0.05:
            return False
        return True

    def _eval_rhs(self, Phi):
        """Return the core's gradient at `Phi` as a flat float64 view."""
        return self._core.grad(Phi).ravel().view(np.float64)

    def _eval_jac_mul(self, Phi, dPhi):
        """Return the core's Hessian-vector product as a flat float64 view."""
        return self._core.hess_mul(Phi, dPhi).ravel().view(np.float64)

    def _eval_jac(self, Phi):
        """
        Return the Hessian at `Phi` as a real CSC matrix of doubled size.

        Every complex entry ``h`` at ``(r, c)`` of the core's sparse Hessian
        becomes the real 2x2 block ``[[Re h, -Im h], [-Im h, -Re h]]`` at
        rows ``2r, 2r+1`` and columns ``2c, 2c+1``.
        """
        if not self._core.hess_computed:
            est = 1.3e-4 * (self.shape[0] * self.shape[1]) ** 2
            _log_solve.info(
                f"Forming Jacobian (probably takes around {est/60:.0f} min)..."
            )

        # complex jacobian of S
        H = self._core.hess(Phi)
        # map to real jacobian of Re S
        i2 = np.r_[2 * H.row, 2 * H.row + 1, 2 * H.row, 2 * H.row + 1]
        j2 = np.r_[2 * H.col, 2 * H.col, 2 * H.col + 1, 2 * H.col + 1]
        d2 = np.r_[H.data.real, -H.data.imag, -H.data.imag, -H.data.real]
        return coo_matrix(
            (d2, (i2, j2)),
            dtype=np.float64,
            shape=(H.shape[0] * 2, H.shape[1] * 2),
            copy=False,
        ).tocsc()

    @functools.wraps(_solve)
    def solve(self, omega, **kw):
        # NOTE: functools.wraps copies the name and docstring of ``_solve``
        # onto this method, so the description lives in this comment.
        #
        # Set the core's ``omega``, run ``_solve`` starting from the current
        # ``Phi`` (keyword arguments are forwarded to ``_solve``), store the
        # result as the new ``Phi`` and wrap it in a ``Result``.
        _log_solve.info(f"omega = {omega}")
        Phi00 = np.zeros_like(self._Phi)
        self._core.omega = omega
        self._core.fix_terminals(Phi00)
        self._Phi = self._solve(
            self._eval_rhs,
            self._eval_jac_mul,
            self._eval_jac,
            self._check_step,
            self._check_final,
            self._Phi,
            Phi00,
            **kw,
        )
        return Result(self, omega=omega, Phi=self._Phi)

    def _get_preconditioner(self, J, preconditioner=None):
        """
        Build a preconditioner (approximate inverse of `J`) as a LinearOperator.

        Parameters
        ----------
        J : sparse matrix
            Real Jacobian.
        preconditioner : {"splu", "spilu", "pyamg", "umfpack", "umfpack-ilu"} or None
            ``None`` picks one from the grid size and from which optional
            packages (``pyamg``, ``scikits.umfpack``) could be imported.

        Returns
        -------
        M : LinearOperator or None
            ``None`` if the SuperLU factorisation failed.

        Raises
        ------
        ValueError
            If the name is unknown or the required package is not installed.
        """
        if preconditioner is None:
            sz = self.shape[0] * self.shape[1] / sum(self.shape)
            if sz > 1000 and pyamg is not None:
                preconditioner = "pyamg"
            elif umfpack is not None:
                preconditioner = "umfpack" if sz < 100 else "umfpack-ilu"
            else:
                preconditioner = "splu"
            _log_solve.debug(f"    preconditioner: {preconditioner}")

        if preconditioner == "splu":
            try:
                J_lu = splu(J.tocsc(copy=True))
            except RuntimeError:
                _log_solve.error("Preconditioner failed: using none")
                return None
            M = LinearOperator(matvec=J_lu.solve, shape=J.shape, dtype=J.dtype)
        elif preconditioner == "spilu":
            try:
                J_lu = spilu(J.tocsc(copy=True), drop_tol=1e-4, fill_factor=2.0)
            except RuntimeError:
                _log_solve.error("Preconditioner failed: using none")
                return None
            M = LinearOperator(matvec=J_lu.solve, shape=J.shape, dtype=J.dtype)
        elif preconditioner == "pyamg" and pyamg is not None:
            M = pyamg.ruge_stuben_solver(J.tocsr(copy=True)).aspreconditioner()
        elif preconditioner == "umfpack" and umfpack is not None:
            M = UmfpackILU(J.tocsc(copy=True), drop_tol=0)
        elif preconditioner == "umfpack-ilu" and umfpack is not None:
            M = UmfpackILU(J.tocsc(copy=True), drop_tol=1e-3)
        else:
            raise ValueError(f"Unknown {preconditioner=}")

        return M

    def solve_many(self, omega, **kw):
        """
        Solve for every frequency in `omega`, warm-starting between them.

        The frequencies are processed grouped by the sign of their real part
        (negative first) and, within a group, in order of decreasing
        ``|omega|``; each solve starts from the previous solution.  If a
        solve fails (NaN result), the NaN is recorded for that frequency and
        the last good state is restored as the initial guess for the next.

        Parameters
        ----------
        omega : array_like
            Frequencies, of any shape.
        **kw
            Forwarded to :meth:`solve`.

        Returns
        -------
        Result
            With ``Phi`` of shape ``omega.shape + self.Phi.shape``.
        """
        omega0 = np.asarray(omega)
        omega = omega0.ravel()

        Phi = np.zeros(omega.shape + self.Phi.shape, dtype=self.Phi.dtype)

        # lexsort: the last key (sign of Re omega) is the primary one.
        js = np.lexsort((-abs(omega), np.sign(omega.real)))

        self._t_jac = 0.0
        self._t_solve = 0.0
        self._t_solve_min = np.inf
        self._M = None
        self._Phi[...] = 0

        for j in js:
            # Snapshot *before* solving: the solve works in place, so a
            # copy taken afterwards would only preserve the failed result.
            Phi_prev = self._Phi.copy()
            self.solve(omega=omega[j], **kw)
            Phi[j] = self._Phi
            if np.isnan(self._Phi).any():
                self._Phi[...] = Phi_prev

        Phi = Phi.reshape(omega0.shape + Phi.shape[1:])

        return Result(self, omega=omega0, Phi=Phi)

    def self_consistency(self, T, T_c0, continuation=False, **kw):
        """
        Run the self-consistency iteration and store its result.

        Parameters
        ----------
        T : float
            Temperature.
        T_c0 : float or array broadcastable to ``(nx, ny)``
            Passed on to ``_self_consistency``.
        continuation : bool
            If false, ``Delta`` is first initialised to ``2 * T_c0`` times
            the 2x2 identity on all ``MASK_NONE`` cells; if true the current
            ``Delta`` is used as the starting point.
        **kw
            Forwarded to ``_self_consistency``.

        Returns
        -------
        Delta, J_tot, cx, success
            The updated ``Delta`` and ``J_tot`` of this solver, the final
            constraint variables and a convergence flag.
        """
        nx, ny = self.shape
        T_c0 = np.broadcast_to(T_c0, (nx, ny))
        self.Delta[self.mask == MASK_VACUUM] = 0
        if not continuation:
            mask = self.mask == MASK_NONE
            self.Delta[mask] = 2 * T_c0[mask, None, None] * np.eye(2)
        res = _self_consistency(self, T, T_c0, **kw)
        self.Delta[...], self.J_tot[...], cx, success = res
        return self.Delta, self.J_tot, cx, success


def _bcast_left(x, shape):
    x = np.asarray(x)
    return x[(Ellipsis,) + (None,) * (len(shape) - x.ndim)]


class Result:
    """
    Solution ``Phi`` at one or several frequencies, plus derived quantities.

    The last three axes of ``Phi`` are ``(2, 2, 2)``: index 0 / 1 of the
    first of them selects the two 2x2 matrices called ``g`` and ``gt``
    below.  All derived quantities are computed on access.

    Parameters
    ----------
    parent : Solver
        Source of ``shape``, ``Lx``, ``Ly``, ``Omega``, ``mask`` and, unless
        `core` is given, of the core used to evaluate currents.
    omega : scalar or array
        Frequency or frequencies the solution belongs to.
    Phi : complex array
        Solution; carries leading axes matching `omega` if that is an array.
    core : optional
        Core object to use instead of ``parent._core``.
    """

    def __init__(self, parent, omega, Phi, core=None):
        self._core = core if core is not None else parent._core
        self.shape = parent.shape
        self.Lx = parent.Lx
        self.Ly = parent.Ly
        self.omega = omega
        self.Omega = parent.Omega
        self.Phi = Phi
        self.mask = parent.mask

    @property
    def Delta(self):
        """``_DeltaArray`` wrapper around ``Omega``."""
        return _DeltaArray(self.Omega)

    @property
    def G(self):
        """``sign(omega) * (1 + g gt)^-1 (1 - g gt)``."""
        g = self.Phi[..., 0, :, :]
        gt = self.Phi[..., 1, :, :]
        I = np.eye(2)
        s = _bcast_left(np.sign(self.omega), g.shape)
        with np.errstate(invalid="ignore"):
            return s * np.linalg.solve(I + g @ gt, I - g @ gt)

    @property
    def Gc(self):
        """``-sign(omega) * (1 + gt g)^-1 (1 - gt g)``."""
        g = self.Phi[..., 0, :, :]
        gt = self.Phi[..., 1, :, :]
        I = np.eye(2)
        s = _bcast_left(np.sign(self.omega), g.shape)
        with np.errstate(invalid="ignore"):
            return -s * np.linalg.solve(I + gt @ g, I - gt @ g)

    @property
    def F(self):
        """``2 * sign(omega) * (1 + g gt)^-1 g``."""
        g = self.Phi[..., 0, :, :]
        gt = self.Phi[..., 1, :, :]
        I = np.eye(2)
        s = _bcast_left(np.sign(self.omega), g.shape)
        with np.errstate(invalid="ignore"):
            return 2 * s * np.linalg.solve(I + g @ gt, g)

    @property
    def Fc(self):
        """``2 * sign(omega) * (1 + gt g)^-1 gt``."""
        g = self.Phi[..., 0, :, :]
        gt = self.Phi[..., 1, :, :]
        I = np.eye(2)
        s = _bcast_left(np.sign(self.omega), g.shape)
        with np.errstate(invalid="ignore"):
            return 2 * s * np.linalg.solve(I + gt @ g, gt)

    @property
    def g(self):
        """``sign(omega) * (1 + g g)^-1 (1 - g g)`` with ``g = Phi[..., 0, :, :]``."""
        g = self.Phi[..., 0, :, :]
        # NOTE(review): both factors read index 0 here, whereas every other
        # property takes ``gt`` from index 1 -- possibly a copy/paste slip.
        # Behaviour is left unchanged until the intent is confirmed.
        gt = self.Phi[..., 0, :, :]
        I = np.eye(2)
        s = _bcast_left(np.sign(self.omega), g.shape)
        with np.errstate(invalid="ignore"):
            return s * np.linalg.solve(I + gt @ g, I - gt @ g)

    def _get_J(self, omega, Phi):
        """
        Evaluate ``core.grad_A(Phi)`` at frequency `omega`.

        The last two axes are swapped and the result scaled by 1/16.  The
        core's ``omega`` is changed only for the duration of the call.
        """
        old_omega = self._core.omega
        self._core.omega = omega
        try:
            return self._core.grad_A(Phi).transpose(0, 1, 2, 4, 3) * (1 / 16)
        finally:
            self._core.omega = old_omega

    def _get_S(self, omega, Phi):
        """
        Evaluate ``core.grad_Omega(Phi)`` at frequency `omega`.

        The last two axes are swapped.  The core's ``omega`` is changed only
        for the duration of the call.
        """
        old_omega = self._core.omega
        self._core.omega = omega
        try:
            return self._core.grad_Omega(Phi).transpose(0, 1, 3, 2)
        finally:
            self._core.omega = old_omega

    @property
    def S(self):
        """``_get_S`` for a scalar `omega`, or stacked over all frequencies."""
        if np.asarray(self.omega).ndim == 0:
            # Previously this branch called _get_J, returning the current
            # instead of S for scalar frequencies.
            return self._get_S(self.omega, self.Phi)
        else:
            return np.array(
                [self._get_S(w, P) for w, P in zip(self.omega, self.Phi)], dtype=complex
            )

    @property
    def J(self):
        """``_get_J`` for a scalar `omega`, or stacked over all frequencies."""
        if np.asarray(self.omega).ndim == 0:
            return self._get_J(self.omega, self.Phi)
        else:
            return np.array(
                [self._get_J(w, P) for w, P in zip(self.omega, self.Phi)], dtype=complex
            )

    @property
    def J_c(self):
        """``tr(J @ diag(1, 1, -1, -1))``."""
        t3 = np.diag([1, 1, -1, -1])
        return tr(self.J @ t3)

    @property
    def J_s(self):
        """``tr(J @ kron(s0, s_k))`` for k = x, y, z, stacked on the last axis."""
        # NOTE(review): ``s0``/``sx``/``sy``/``sz`` (lowercase) must be
        # defined elsewhere in this module; only ``S_0``/``S_x``/``S_y``/
        # ``S_z`` are defined near the top -- confirm the names resolve.
        jx = tr(self.J @ np.kron(s0, sx))
        jy = tr(self.J @ np.kron(s0, sy))
        jz = tr(self.J @ np.kron(s0, sz))
        return np.array([jx, jy, jz]).transpose(1, 2, 3, 0)


class UmfpackILU(LinearOperator):
    """
    Linear operator applying the inverse of `M` through an UMFPACK factorisation.

    Parameters
    ----------
    M : sparse matrix
        Square ``float64`` matrix in CSC format.
    drop_tol : float
        If non-zero, written to the ``UMFPACK_DROPTOL`` control entry before
        the numeric factorisation; a falsy value leaves the control setting
        untouched.

    Raises
    ------
    ValueError
        If `M` is not a square CSC ``float64`` matrix.
    """

    def __init__(self, M, drop_tol=1e-4):
        n_rows, n_cols = M.shape[0], M.shape[1]
        supported = M.dtype == np.float64 and n_rows == n_cols and M.format == "csc"
        if not supported:
            raise ValueError("Only square, csc, float64 supported")

        super().__init__(dtype=M.dtype, shape=M.shape)

        self.M = M

        # Select the UMFPACK family matching the index width of the matrix.
        if M.indptr.dtype == np.int32:
            family = "di"
        else:
            family = "dl"
        self.ctx = umfpack.UmfpackContext(family=family)

        if drop_tol:
            self.ctx.control[umfpack.UMFPACK_DROPTOL] = drop_tol
        self.ctx.numeric(self.M)

    def _matvec(self, v):
        """Solve ``M x = v`` with the stored factorisation."""
        return self.ctx.solve(umfpack.UMFPACK_A, self.M, v)


def _get_workers(workers):
    if workers is None or isinstance(workers, concurrent.futures.Executor):
        return workers, False

    mpi_env_vars = ("PMI_RANK", "PMIX_RANK", "OMP_COMM_WORLD_RANK")

    if workers == -1 and any(vn in os.environ for vn in mpi_env_vars):
        workers = "mpi"

    if workers == "mpi":
        _log.info("Trying to use MPI")
        import mpi4py.futures
        import mpi4py.MPI

        pool = mpi4py.futures.MPIPoolExecutor()
        pool.__num_workers = mpi4py.futures._lib.get_max_workers()

        if mpi4py.MPI.COMM_WORLD.Get_size() != 1:
            _log.warning(
                "Running with multiple MPI processes. Set universe size instead: mpirun -usize 6 ..."
            )
        if pool.__num_workers == 1:
            _log.warning(
                "Running with MPI pool of size 1. Try setting a universe size: mpirun -usize 6 ..."
            )

        return pool, True

    workers = int(workers)
    if workers < 0:
        ncpu = os.cpu_count()
        workers = min(32, ncpu + 1 + workers)

    if workers > 1:
        pool = concurrent.futures.ProcessPoolExecutor(max_workers=workers)
        pool.__num_workers = workers
        return pool, True
    else:
        return None, False


def _with_workers(func):
    """
    Decorator resolving the ``workers`` keyword into an executor.

    The wrapped function receives ``workers`` as the pool returned by
    ``_get_workers`` (possibly ``None``); when the keyword is omitted the
    specification defaults to ``1``.  A pool created for the call is shut
    down afterwards, also if the function raises; a pool supplied by the
    caller is left open.
    """

    @functools.wraps(func)
    def wrapper(*a, **kw):
        requested = kw.pop("workers", 1)
        pool, own_pool = _get_workers(requested)
        kw["workers"] = pool
        try:
            return func(*a, **kw)
        finally:
            if own_pool and pool is not None:
                pool.shutdown()

    return wrapper


@_with_workers
def _self_consistency(
    solver,
    T,
    T_c0,
    maxiter=100,
    tol=1e-5,
    xtol=1e-2,
    workers=1,
    constraint_fun=None,
    constraint_x=None,
    plot=False,
    **solver_kw,
):
    """
    Solve self-consistency conditions, with additional constraint variables.

    The unknowns are the real and imaginary parts of ``solver.Delta`` on
    the cells where ``mask == MASK_NONE``, ``T_c0 > 0`` and ``T < T_c0``,
    optionally preceded by the constraint variables.  Each outer iteration
    performs one Jacobian-free LGMRES step, with matrix-vector products
    obtained from forward differences of the residual, and the relative
    update size is limited to 10 %.

    Parameters
    ----------
    solver : Solver
        Solver whose ``Delta`` is the starting point.  It is modified during
        the iteration.
    T : float
        Temperature.
    T_c0 : array
        Must satisfy ``T_c0.shape + (2, 2) == solver.Delta.shape``.
    maxiter : int
        Maximum number of outer iterations.
    tol : float
        Relative residual tolerance; ``0.1 * tol`` is used as the tolerance
        of the inner frequency solves and of LGMRES.
    xtol : float
        Relative tolerance on the size of the last update.
    workers : see ``_get_workers``
        Parallelisation over frequencies.
    constraint_fun : function(x, Delta) -> residual
        Function returning additional components of the self-consistency
        residual, where `x` are additional constraint variables.
        This can be used e.g. for pseudoarclength continuation.
        Its ``doplot`` attribute is set to True around the residual
        evaluation of each outer iteration and to False afterwards.
    constraint_x : array
        Initial values of the additional constraint variables.
    plot : bool or "circle"
        Show the progress with matplotlib after every iteration.
    **solver_kw
        Forwarded to ``_self_consistent_Delta_f``.

    Returns
    -------
    Delta : array
        Final order parameter.
    J : array
        Current returned by the most recent residual evaluation.
    cx : array
        Final values of the constraint variables (empty without constraint).
    success : bool
        False if the iteration did not converge or produced NaN.

    Raises
    ------
    ValueError
        If `T_c0` has an incompatible shape.
    """
    T = float(T)
    T_c0 = np.asarray(T_c0)

    if T_c0.shape + (2, 2) != solver.Delta.shape:
        raise ValueError("T_c0 must have compatible shape with Delta")

    Delta_mask = (solver.mask == MASK_NONE) & (T_c0 > 0) & (T < T_c0)

    # Unknown vector: [constraint variables, Re/Im parts of masked Delta]
    z = solver.Delta[Delta_mask].copy().ravel().view(float)
    if constraint_fun is not None:
        constraint_x = np.asarray(constraint_x)
        z = np.r_[constraint_x.astype(float), z]
    else:
        constraint_x = np.array([])

    A_shape = (z.size, z.size)
    # np.float64 rather than the np.float_ alias (removed in NumPy 2.0)
    A_dtype = np.float64

    outer_v = []
    count = 0

    if Delta_mask.any():
        F0_norm = len(z) * max(T, abs(z).max(), abs(T_c0).max())
    else:
        # Nothing to solve
        F0_norm = 1.0
    dx_norm = 0

    success = True
    J = None

    solver_kw["tol"] = 0.1 * tol

    # Small cache of recent residual evaluations, keyed by the exact z.
    cache = collections.deque([], 5)

    for j in range(maxiter):
        # Solve linearized problem

        def ev(z):
            # Residual at z; also updates the enclosing J as a side effect.
            nonlocal J

            for k, v in cache:
                if (z == k).all():
                    return v.copy()

            cx = z[: constraint_x.size].reshape(constraint_x.shape)
            Delta1 = solver.Delta.copy()
            Delta1[Delta_mask] = z[constraint_x.size :].view(complex).reshape(-1, 2, 2)

            if constraint_fun is not None:
                res_cons = constraint_fun(cx, Delta1).ravel()
                # Constraint function updates terminal phases
                solver.Delta[...] = Delta1

            res, J, m = _self_consistent_Delta_f(
                solver, Delta1, T, T_c0, workers=workers, **solver_kw
            )
            res = res[m].ravel().view(float)
            if constraint_fun is not None:
                res = np.r_[res_cons, res]
            cache.append((z.copy(), res.copy()))
            return res

        def op(v, rdiff=1e-4):
            # Forward-difference approximation of the Jacobian times v.
            nonlocal count
            count += 1
            scale = rdiff * max(1, _norm(z)) / max(1, _norm(v))
            return (ev(z + scale * v) - F) / scale

        A = LinearOperator(matvec=op, shape=A_shape, dtype=A_dtype)

        if constraint_fun is not None:
            constraint_fun.doplot = True
        F = ev(z)
        if F.size == 0:
            # Nothing to solve
            break

        F_norm = _norm(F)
        if constraint_fun is not None:
            constraint_fun.doplot = False

        z_norm = _norm(z)

        _log.info(
            f"self-cons. #{j}: residual = {F_norm/F0_norm}, dx = {dx_norm/z_norm}"
        )

        # At least one update is always performed (j > 0).
        if F_norm < tol * F0_norm and dx_norm < xtol * z_norm and j > 0:
            break

        l_inner_m = 3
        l_maxiter = 1

        dx, info = lgmres(
            A,
            -F,
            tol=0.1 * tol,
            atol=0,
            maxiter=l_maxiter,
            inner_m=l_inner_m,
            outer_v=outer_v,
            outer_k=2,
            store_outer_Av=False,
            prepend_outer_v=True,
        )
        assert not np.isnan(dx).any()

        # Limit the relative (max-norm) update to 10 %.
        s = np.linalg.norm(dx, np.inf) / np.linalg.norm(z, np.inf)
        if s >= 0.1:
            s = 0.1 / s
            dx *= s

        z += dx
        dx_norm = _norm(dx)

        if plot:
            import matplotlib.pyplot as plt
            from .plotting import pcolormesh_complex

            x = np.linspace(-solver.Lx / 2, solver.Lx / 2, solver.shape[0])
            y = np.linspace(-solver.Ly / 2, solver.Ly / 2, solver.shape[1])

            plt.clf()
            ddd = solver.Delta.copy()
            ddd[Delta_mask] = z[constraint_x.size :].view(complex).reshape(-1, 2, 2)

            ddd = ddd[..., 0, 0]

            if plot == "circle" or min(ddd.shape) == 1:
                ddd = ddd.squeeze()
                plt.plot(np.real(ddd).ravel(), np.imag(ddd).ravel(), ".")
                plt.xlabel(r"$\mathrm{Re} \Delta$")
                plt.ylabel(r"$\mathrm{Im} \Delta$")
                dmax = max(1, abs(ddd).max())
                plt.xlim(-dmax, dmax)
                plt.ylim(-dmax, dmax)
            else:
                from .plotting import plot_Delta_J_2d

                plot_Delta_J_2d(x, y, ddd, J)

            plt.gca().set_aspect("equal")
            plt.pause(0.001)
    else:
        _log.error(f"Self-consistent iteration didn't converge (with {count} ops)")
        success = False

    Delta = solver.Delta.copy()
    Delta[Delta_mask] = z[constraint_x.size :].view(complex).reshape(-1, 2, 2)
    cx = z[: constraint_x.size]

    if success and (np.isnan(J).any() or np.isnan(Delta).any()):
        _log.error(f"Self-consistent iteration failed: nan result (with {count} ops)")
        success = False

    if success:
        _log.info(f"Self-consistent iteration converged with {count} ops")

    return Delta, J, cx, success


@_with_workers
def _self_consistent_Delta_f(
    solver, Delta, T, T_c0, E_typical=None, workers=1, **solver_kw
):
    r"""
    Evaluate f = \sum_(Delta/omega - F) - Delta log(Tc/T).
    Self-consistency is achieved when f = 0.

    Parameters
    ----------
    solver : Solver
        Its ``Delta`` is overwritten with `Delta` on the active cells.
    Delta : array
        Trial order parameter, same shape as ``solver.Delta``.
    T : float
        Temperature.
    T_c0 : array
        Per-cell value; cells with ``T_c0 > 0`` and ``T < T_c0`` (and
        ``mask == MASK_NONE``) are the active ones.
    E_typical : float, optional
        Energy scale passed to ``get_matsubara_sum``.  Defaults to
        ``10 * max|T_c0| + 10 * T``.
    workers : see ``_get_workers``
        If an executor, the frequencies are split into chunks evaluated
        through ``workers.map``; if ``None`` they are evaluated serially.
    **solver_kw
        Forwarded to ``Solver.solve`` for each frequency.

    Returns
    -------
    rtot : complex array, shape of ``solver.Delta``
        Residual.  On ``MASK_NONE`` cells that are not active it is simply
        `Delta`, so that the residual vanishes there only for zero Delta.
    Jtot : complex array, shape ``solver.shape + (4, 4, 4)``
        Frequency-summed current.
    mask : bool array
        The active cells.
    """
    # Sum over positive frequencies (reverse order)
    if E_typical is None:
        # Previously a caller-supplied value was silently overwritten.
        E_typical = 10 * abs(T_c0).max() + 10 * T
    w, a = get_matsubara_sum(T, E_typical)
    w = w[len(w) // 2 :][::-1]
    a = a[len(a) // 2 :][::-1]

    rtot = np.zeros(solver.Delta.shape, dtype=complex)
    Jtot = np.zeros(tuple(solver.shape) + (4, 4, 4), dtype=complex)

    mask1 = solver.mask == MASK_NONE
    mask2 = (T_c0 > 0) & (T < T_c0)
    rtot[mask1 & ~mask2] = Delta[mask1 & ~mask2]
    mask = mask1 & mask2

    solver.Delta[mask] = Delta[mask]

    if workers is not None:
        # One job per contiguous chunk of frequencies; pools created by
        # _get_workers advertise their size via ``__num_workers``.
        nworkers = getattr(workers, "__num_workers", None)
        if nworkers is not None:
            chunk = len(w) // nworkers + 1
        else:
            chunk = 30
        jobs = [
            (w[j : j + chunk], a[j : j + chunk], solver, solver_kw)
            for j in range(0, len(w), chunk)
        ]
        results = workers.map(_mp_one, jobs, chunksize=1)
    else:
        results = (_mp_one((w, a, solver, solver_kw)),)

    for rtotx, Jtotx in results:
        rtot[mask] += rtotx[mask]
        Jtot += Jtotx

    rtot[mask] -= np.log(T_c0[mask, None, None] / T) * solver.Delta[mask]

    return rtot, Jtot, mask


# Solver used by the previous ``_mp_one`` call in this process.
_prev_solver = None


def _mp_one(args, solver=None, solver_kw=None):
    """
    Evaluate the weighted frequency sum for one chunk of frequencies.

    Parameters
    ----------
    args : tuple
        ``(w, a, solver, solver_kw)``: frequencies, their weights, the
        solver and the keyword arguments for ``solver.solve``.
    solver, solver_kw
        Ignored; both are overwritten by the contents of `args`.

    Returns
    -------
    rtot, Jtot
        Sums over the chunk of ``2 pi a (Delta / |w| - F)`` and of
        ``-2j pi a J``.  Both are the integer 0 if the chunk is empty.
    """
    global _prev_solver

    w, a, solver, solver_kw = args

    cached = _prev_solver
    same_problem = (
        cached is not None and cached is not solver and solver._id == cached._id
    )
    if same_problem:
        # Only dynamic variables may have changed
        cached.Omega[...] = solver.Omega
        solver = cached
    else:
        _prev_solver = solver

    saved_level = _log_solve.level
    try:
        # Make the per-frequency solver output one level quieter than the
        # main logger while the chunk is processed.
        _log_solve.setLevel(_log.getEffectiveLevel() + 10)

        rtot = Jtot = 0

        for wx, ax in zip(w, a):
            res = solver.solve(omega=wx, **solver_kw)
            residual = solver.Delta.A / abs(wx) - res.F
            rtot += (2 * np.pi * ax) * residual
            Jtot += (-2j * np.pi * ax) * res.J
    finally:
        _log_solve.setLevel(saved_level)

    return rtot, Jtot


class CPRState:
    """
    Accumulated results of a ``cpr`` calculation, with save / resume support.

    The lists ``phis``, ``Deltas``, ``Js`` and ``dss`` hold one entry per
    computed point.  Files written by :meth:`save` contain a hash of the
    problem parameters, and :meth:`load` refuses files whose hash differs.
    """

    def __init__(self, solver, extra_params):
        self.phis = []
        self.Deltas = []
        self.Js = []
        self.dss = []

        self.solver = solver
        self.extra_params = extra_params

    @classmethod
    def _get_hash(cls, solver, extra_params):
        """
        Return a SHA-512 digest of the problem parameters as a uint8 array.

        The digest covers the pickled tuple of ``solver.shape``,
        ``solver.mask``, ``solver.Lx``, ``solver.Ly`` and `extra_params`.
        """
        data = (
            solver.shape,
            solver.mask,
            solver.Lx,
            solver.Ly,
            extra_params,
        )
        h = hashlib.sha512()
        h.update(pickle.dumps(data))
        return np.asarray(list(h.digest()), dtype=np.uint8)

    def save(self, filename):
        """
        Write the state to `filename` in ``.npz`` format.

        The data is first written to a temporary file in the same directory
        and then moved over `filename`, so an interrupted save does not
        leave a truncated file under the final name.
        """
        directory, basename = os.path.split(filename)
        # The prefix must be a bare file name: a prefix that contains the
        # directory part would be joined onto `dir` a second time.
        fd, fn = tempfile.mkstemp(
            dir=directory or ".", prefix=basename, suffix=".new.npz"
        )
        try:
            # Close (and thereby flush) the file before it is renamed.
            with io.open(fd, mode="w+b", closefd=True) as stream:
                np.savez(
                    stream,
                    params_hash=self._get_hash(self.solver, self.extra_params),
                    phis=np.asarray(self.phis),
                    Deltas=np.asarray(self.Deltas),
                    Js=np.asarray(self.Js),
                    dss=np.asarray(self.dss),
                    mask=self.solver.mask,
                    Lx=self.solver.Lx,
                    Ly=self.solver.Ly,
                    x=self.solver.x,
                    y=self.solver.y,
                )
            # os.replace also overwrites an existing file on Windows.
            os.replace(fn, filename)
        except BaseException:
            # Remove the temporary file on any failure, then re-raise.
            os.unlink(fn)
            raise

    def load(self, filename):
        """
        Load a state written by :meth:`save`.

        Returns
        -------
        bool
            True if the data was loaded.  False, leaving the current lists
            untouched, if the stored parameter hash differs from the current
            one or the file lacks the hash; if a later key is missing, the
            lists read before it have already been replaced.
        """
        with np.load(filename) as f:
            try:
                h1 = self._get_hash(self.solver, self.extra_params)
                h2 = f["params_hash"]
                if h1.shape != h2.shape or not (h1 == h2).all():
                    return False

                self.phis = list(f["phis"])
                self.Deltas = list(f["Deltas"])
                self.Js = list(f["Js"])
                self.dss = list(f["dss"])
                return True
            except KeyError:
                return False


@_with_workers
def cpr(
    solver,
    T,
    T_c0,
    phase_mask,
    max_points=1000,
    auto_stop=True,
    ds=0.2,
    maxiter=10,
    phi0=0,
    filename=None,
    **selfcons_kw,
):
    """
    Trace a current-phase relation by pseudoarclength continuation.

    On the sites selected by `phase_mask` the order parameter is fixed to
    ``abs(Delta) * exp(1j * phi)``, where ``abs(Delta)`` is taken from
    ``solver.Delta`` at call time. Starting from ``phi = phi0``, each step
    predicts the next point along the secant through the two latest points
    and corrects it with ``solver.self_consistency`` under a pseudoarclength
    constraint, with `phi` as the extra unknown.

    Parameters
    ----------
    solver
        Object providing ``shape``, ``mask``, ``Delta``, ``J_tot`` and
        ``self_consistency`` (presumably a `Solver`).
    T, T_c0
        Passed on to ``solver.self_consistency``.
    phase_mask : array
        Converted to a boolean mask; all selected sites must have
        ``MASK_TERMINAL`` in ``solver.mask``.
    max_points : int
        Upper limit for the total number of stored points.
    auto_stop : bool
        Stop when `phi` crosses pi backwards, or crosses pi forwards while
        the overlap of ``tr(J @ t3)`` between the two latest points is
        non-positive.
    ds : float
        Maximum pseudoarclength step. The first step uses ``ds / 4`` in
        `phi`, and later steps grow by at most 10% per point.
    maxiter : int
        Passed on to ``solver.self_consistency`` in the corrector steps.
    phi0 : float
        Phase of the first point.
    filename : str, optional
        Checkpoint file. If it exists and matches the parameters, the
        calculation continues from it; it is rewritten after each new point.
    **selfcons_kw
        Further keyword arguments for ``solver.self_consistency``.

    Returns
    -------
    phis, Deltas, Js : ndarray
        Phase, order parameter and ``solver.J_tot`` at each computed point.

    Raises
    ------
    ValueError
        If `phase_mask` selects sites that are not ``MASK_TERMINAL``.
    RuntimeError
        If a step is not accepted within 10 corrector attempts.
    """
    nx, ny = solver.shape

    phase_mask = phase_mask.astype(bool)

    theta = 1 / (1 + nx * ny / 100)  # Pseudoarclength weight
    ds_scale = np.sqrt(1 + nx * ny)

    # Used both in the stop condition and in the progress log
    t3 = np.diag([1, 1, -1, -1])

    def get_ds(Delta_1, Delta_0, phi_1, phi_0):
        # Pseudoarclength size
        m = solver.mask == MASK_NONE
        return np.sqrt(
            theta * _norm(Delta_1[m] - Delta_0[m]) ** 2
            + (1 - theta) * abs(phi_1 - phi_0) ** 2
        )

    if (solver.mask[phase_mask] != MASK_TERMINAL).any():
        raise ValueError("Phase mask has points that don't have MASK_TERMINAL")

    # Magnitude of the order parameter; only its phase is varied below
    Delta_00 = abs(solver.Delta.A)

    state = CPRState(solver=solver, extra_params=(T, T_c0, phase_mask))

    if filename is not None and os.path.isfile(filename):
        if not state.load(filename):
            _log.warning(f"CPR: parameters changed, not loading '{filename}'")
        else:
            _log.warning(f"CPR: loading data from '{filename}' and continuing")

    if len(state.phis) == 0:
        # First point: plain self-consistent solution at the initial phase
        phi = phi0

        solver.Delta[...] = Delta_00
        solver.Delta[phase_mask] = Delta_00[phase_mask] * np.exp(1j * phi)
        solver.self_consistency(T=T, T_c0=T_c0, **selfcons_kw)

        state.Deltas = [solver.Delta.copy()]
        # Record the phase that was actually imposed, so that the next
        # points continue from phi0 instead of from zero.
        state.phis = [phi]
        state.dss = [0]
        state.Js = [solver.J_tot.copy()]

    phi = state.phis[-1] + ds / 4

    if len(state.phis) == 1:
        # Second point: small fixed phase step, needed to define a secant
        solver.Delta[phase_mask] = Delta_00[phase_mask] * np.exp(1j * phi)
        solver.self_consistency(T=T, T_c0=T_c0, continuation=True, **selfcons_kw)

        state.dss.append(
            float(get_ds(solver.Delta, state.Deltas[-1], phi, state.phis[-1]))
        )
        state.Deltas.append(solver.Delta.copy())
        state.phis.append(float(phi))
        state.Js.append(solver.J_tot.copy())

    for j in range(len(state.phis), max_points):
        if auto_stop and len(state.phis) > 2:
            if state.phis[-2] > np.pi and state.phis[-1] <= np.pi:
                _log.info("CPR full cycle obtained: stop (pi reverse cross)")
                break
            I1 = tr(state.Js[-1] @ t3)
            I2 = tr(state.Js[-2] @ t3)
            if (
                state.phis[-2] < np.pi
                and state.phis[-1] >= np.pi
                and np.vdot(I1.ravel(), I2.ravel()).real <= 0
            ):
                _log.info("CPR full cycle obtained: stop (pi cross, J sign change)")
                break

        # Let the step grow by at most 10% per point, up to ds
        ds_cur = min(ds, 1.1 * state.dss[-1])

        # Secant direction through the two latest points
        dDelta = (state.Deltas[-1] - state.Deltas[-2]) / state.dss[-1]
        dphi = (state.phis[-1] - state.phis[-2]) / state.dss[-1]

        turnaround = False
        turnaround_large = False

        for k in range(10):
            # Predictor
            phi = state.phis[-1] + dphi * ds_cur
            Delta = state.Deltas[-1] + dDelta * ds_cur

            m = solver.mask == MASK_NONE
            solver.Delta[m] = Delta[m]

            solver.Delta[phase_mask] = Delta_00[phase_mask] * np.exp(1j * phi)

            # Corrector
            def constraint_fun(phi, Delta):
                # Pseudoarclength constraint; reads the current value of
                # ds_cur from the enclosing scope each time it is called.
                Delta[phase_mask] = Delta_00[phase_mask] * np.exp(1j * phi)

                m = solver.mask == MASK_NONE
                dsa = theta * np.real(
                    np.vdot(dDelta[m], Delta[m] - state.Deltas[-1][m])
                )
                dsb = (1 - theta) * dphi * (phi - state.phis[-1])
                dstot = dsa + dsb
                dr = (dstot - ds_cur) * ds_scale

                return np.asarray(dr)

            constraint_x = np.asarray(phi).ravel()
            Delta, J_tot, cx, success = solver.self_consistency(
                T=T,
                T_c0=T_c0,
                constraint_x=constraint_x,
                constraint_fun=constraint_fun,
                continuation=True,
                maxiter=maxiter,
                **selfcons_kw,
            )

            if not success:
                ds_cur /= 2.0
                _log.info(f"CPR #{j}: retry with step {ds_cur}")
                continue

            # The phase moved against the secant direction: accept this only
            # after repeating the step once smaller and once larger.
            if np.sign(cx - state.phis[-1]) != np.sign(dphi):
                if not turnaround:
                    # Calculate at smaller stepsize to verify turnaround
                    turnaround = True
                    ds_cur /= 1.5
                    _log.info(f"CPR #{j}: turning around: verifying with step {ds_cur}")
                    continue
                if not turnaround_large:
                    # Calculate at larger stepsize to verify turnaround
                    turnaround_large = True
                    ds_cur *= 2
                    _log.info(f"CPR #{j}: turning around: verifying with step {ds_cur}")
                    continue

            phi = cx
            break
        else:
            raise RuntimeError("Failed to converge")

        # Store the step length that was actually taken
        ds_cur = get_ds(Delta, state.Deltas[-1], phi, state.phis[-1])
        state.dss.append(float(ds_cur))
        state.Deltas.append(Delta.copy())
        state.phis.append(float(phi))
        state.Js.append(solver.J_tot.copy())

        J_avg = np.mean(tr(solver.J_tot[:, :, 0].real @ t3))
        _log.info(f"CPR #{j}: phi = {phi}, J = {J_avg}")

        if filename is not None:
            state.save(filename)

    return np.asarray(state.phis), np.asarray(state.Deltas), np.asarray(state.Js)


def tr(M):
    """
    Sum the diagonal of the matrices held in the two trailing axes of `M`.
    """
    row_axis, col_axis = -2, -1
    return np.trace(M, axis1=row_axis, axis2=col_axis)