# -*- eval: (blacken-mode) -*-
import logging
import pytest
import numpy as np
from numpy.testing import assert_allclose
from scipy.linalg import expm

import usadelndsoc
import usadelndsoc.solver

# Enable DEBUG-level output from the usadelndsoc package logger for the whole
# test module, so solver diagnostics are visible when a test fails.
usadelndsoc.logger.setLevel(logging.DEBUG)


def basic_setup(terminals=False, nx=15, ny=5, h=0):
    """Return a small, spatially uniform Solver used by the tests below.

    The solver has ``Delta`` set to the 2x2 identity everywhere, ``eta = 0``
    and ``Lx = Ly = 10``.

    Parameters
    ----------
    terminals : bool
        If true, the first and last index along the first grid axis of
        ``mask`` are set to ``MASK_TERMINAL``.
    nx, ny : int
        Grid dimensions forwarded to ``Solver``.
    h : float
        Coefficient of ``diag(1, -1, -1, 1)`` added to ``Omega``
        (a spin-dependent shift; see the analytic check in
        ``test_result_S_J0``).
    """
    solver = usadelndsoc.solver.Solver(nx=nx, ny=ny)
    solver.Delta[...] = np.eye(2)

    h_term = h * np.diag([1, -1, -1, 1])
    solver.Omega[...] += h_term

    solver.eta = 0.0
    solver.Lx = solver.Ly = 10

    if terminals:
        # Mark both boundary slices (first and last) as terminals.
        for edge in (0, -1):
            solver.mask[edge, :] = usadelndsoc.solver.MASK_TERMINAL

    return solver


# s.solve(omega=150, maxiter=300, preconditioner="none")


def test_solve_dos():
    """Compare the solver's G[0, 0] against the analytic uniform result.

    The energies lie on the real axis shifted by a small imaginary part
    (0.05j) and are passed to the solver as ``omega = -1j * E``. At the
    interior site (10, 2) of the default 15x5 grid, ``G[0, 0]`` must match
    ``-1j * E / sqrt(1 - E**2)``. That is the closed form for the unit
    ``Delta`` set in ``basic_setup``.
    """
    solver = basic_setup()
    energies = 0.05j + np.linspace(-3, 3, 101)

    result = solver.solve_many(omega=-1j * energies)

    expected = -1j * energies / np.sqrt(1 - energies**2)
    assert_allclose(result.G[:, 10, 2, 0, 0], expected, rtol=1e-4)


@pytest.mark.parametrize(
    "omega,h,terminals",
    [
        (w, h, terminals)
        for w in (0.5 + 0.8j, -0.5 + 0.8j)
        for h in (0, -0.5, 0.5)
        for terminals in (True, False)
    ],
)
def test_result_S_J0(omega, h, terminals):
    """Check the uniform solution's S blocks, analytic values, and zero current.

    1. The 4x4 ``S`` array is made of the ``G``, ``F``, ``Fc`` and ``Gc``
       2x2 blocks, each scaled by ``vol``.
    2. The Green's functions equal the closed-form uniform solution with
       spin-resolved frequencies ``omega -/+ 1j * h``.
    3. With no gradients in the setup, the current ``J`` vanishes.
    """
    s = basic_setup(h=h, terminals=terminals)
    res = s.solve_many(omega=omega)
    S = res.S

    # Area per grid site times the 4j factor that relates S to the Green's
    # function blocks (as asserted below).
    vol = s.Lx * s.Ly / (s.shape[0] * s.shape[1]) * 4j

    assert_allclose(S[:, :, :2, :2], vol * res.G)
    assert_allclose(S[:, :, :2, 2:], vol * res.F)
    assert_allclose(S[:, :, 2:, :2], vol * res.Fc)
    assert_allclose(S[:, :, 2:, 2:], vol * res.Gc)

    # Check vs. analytic uniform solution.
    wp = omega - 1j * h
    wm = omega + 1j * h
    sqrt_p = np.sqrt(wp**2 + 1)
    sqrt_m = np.sqrt(wm**2 + 1)

    Gan = np.diag([wp / sqrt_p, wm / sqrt_m])
    Fan = np.diag([1 / sqrt_p, 1 / sqrt_m])
    Fcan = Fan  # Fc has the same closed form as F here
    Gcan = -Gan  # Gc is the negative of G here

    shp = res.G.shape

    assert_allclose(res.G, np.broadcast_to(Gan, shp), rtol=1e-5)
    assert_allclose(res.F, np.broadcast_to(Fan, shp), rtol=1e-5)
    assert_allclose(res.Fc, np.broadcast_to(Fcan, shp), rtol=1e-5)
    assert_allclose(res.Gc, np.broadcast_to(Gcan, shp), rtol=1e-5)

    # There are no gradients, hence no currents either
    assert_allclose(res.J, 0, atol=1e-5)


def test_gauge_invariance():
    """Check invariance of ``_core.eval`` under a local gauge transform.

    The test first evaluates a random link configuration ``U`` and a random
    field ``Phi``. It then applies a random block-diagonal matrix ``W`` at a
    single site (i, j):

    - ``iW`` on the right of the four links stored at (i, j);
    - ``W`` on the left of the neighbouring links that refer back to (i, j);
    - the matching similarity transform on ``Phi[i, j]``.

    The value of ``_core.eval`` must be unchanged, and the current at (i, j)
    must have vanishing divergence.
    """
    nx = 10
    ny = 7

    # Legacy global seeding: the fill order below depends on this stream, so
    # the sequence of np.random calls must not be reordered.
    np.random.seed(1)

    s = usadelndsoc.solver.Solver(nx=nx, ny=ny)
    s.Omega[...] = 0

    # Random complex diagonal blocks of U for link indices 0 and 1.
    # Order: upper block (k=0, k=1), then lower block (k=0, k=1); real part is
    # drawn before imaginary part.
    for blk in (slice(None, 2), slice(2, None)):
        for k in (0, 1):
            s.U[:, :, k, blk, blk] = np.random.randn(
                nx, ny, 2, 2
            ) + 1j * np.random.randn(nx, ny, 2, 2)

    s.omega = 0.987
    s.eta = 0.0  # NOTE: a nonzero eta (0.789) is intentionally disabled here
    s.D = 1.0
    s.Lx = 10
    s.Ly = 10

    s._core.init_U()

    Phi = np.random.randn(nx, ny, 2, 2, 2) + 1j * np.random.randn(nx, ny, 2, 2, 2)
    S0 = s._core.eval(Phi)

    s.reset()

    # Gauge transform at a single interior site
    i = 3
    j = 4

    W = np.zeros((4, 4), dtype=complex)
    W[:2, :2] = np.random.randn(2, 2)
    W[2:, 2:] = np.random.randn(2, 2)
    iW = np.linalg.inv(W)

    # All four links stored at (i, j) pick up iW on the right.
    for k in range(4):
        s.U[i, j, k] = s.U[i, j, k] @ iW

    # Neighbour links that refer back to (i, j) pick up W on the left.
    # NOTE(review): assumes link indices 0..3 are +x, +y, -x, -y, so e.g.
    # link 2 of the +x neighbour points back to (i, j) — confirm vs. solver.
    s.U[i + 1, j, 2] = W @ s.U[i + 1, j, 2]
    s.U[i - 1, j, 0] = W @ s.U[i - 1, j, 0]

    s.U[i, j + 1, 3] = W @ s.U[i, j + 1, 3]
    s.U[i, j - 1, 1] = W @ s.U[i, j - 1, 1]

    Phi2 = Phi.copy()
    Phi2[i, j, 0] = W[:2, :2] @ Phi[i, j, 0] @ iW[2:, 2:]
    Phi2[i, j, 1] = W[2:, 2:] @ Phi[i, j, 1] @ iW[:2, :2]

    S1 = s._core.eval(Phi2)

    # Should be invariant to numerical accuracy
    assert_allclose(S1, S0, atol=1e-9)

    # Current should be conserved to numerical accuracy
    # due to the gauge symmetry
    from usadelndsoc.solver import Result, tr

    res = Result(s, s.omega, Phi2)

    t3 = np.diag([1, 1, -1, -1])
    J = tr(res.J[i, j] @ t3)

    div_J_x = J[0] - J[2]
    div_J_y = J[1] - J[3]
    div_J = div_J_x + div_J_y

    assert_allclose(div_J, 0, atol=1e-9)