# -*- coding:utf-8; eval: (blacken-mode) -*-
"""
Calculate anomalous current in bulk case
"""
import joblib
import numpy as np
from numpy import pi
import matplotlib.pyplot as plt
from scipy.linalg import expm
import logging
import contextlib

import usadelndsoc
from usadelndsoc.matsubara import get_matsubara_sum
from usadelndsoc.solver import *
from usadelndsoc.util import vectorize_parallel

import collections

mem = joblib.Memory("cache")


def get_solver(soc_alpha, eta, phi, L=10, W=10, h=0.5, D=1.0, n=5):
    """
    Build an ``n x n`` Solver for a junction between two terminals.

    The first and last rows along axis 0 are marked ``MASK_TERMINAL`` and get
    pair potentials ``eye(2) * exp(+1j*phi/2)`` and ``eye(2) * exp(-1j*phi/2)``.
    ``Omega`` is zeroed everywhere, and ``h * kron(S_z, S_y)`` is then added on
    the interior rows ``Omega[1:-1]`` only.

    Parameters
    ----------
    soc_alpha : float
        Spin-orbit strength; the link matrices are
        ``Ux = expm(1j*dx*soc_alpha*kron(S_0, S_y))`` and
        ``Uy = expm(-1j*dy*soc_alpha*kron(S_0, S_x))``.
    eta : float
        Assigned to ``Solver.eta``.
    phi : float
        Phase difference between the two terminals.
    L, W : float
        System size along x (``Lx``) and y (``Ly``).
    h : float
        Amplitude of the interior ``Omega`` term.
    D : float
        Assigned to ``Solver.D`` (diffusion constant).
    n : int
        Number of grid points in each direction.

    Returns
    -------
    Solver
        The configured (unsolved) solver.
    """
    solver = Solver(nx=n, ny=n)
    solver.mask[...] = MASK_NONE
    solver.Lx, solver.Ly = L, W

    solver.Omega[...] = 0
    solver.Delta[0, :] = np.eye(2) * np.exp(1j * phi / 2)
    solver.Delta[-1, :] = np.eye(2) * np.exp(-1j * phi / 2)

    for edge in (0, -1):
        solver.mask[edge, :] = MASK_TERMINAL

    solver.D, solver.eta = D, eta

    # Lattice spacings, read back from the solver after Lx/Ly were set.
    grid_x, grid_y = solver.shape
    step_x, step_y = solver.Lx / grid_x, solver.Ly / grid_y

    gauge_x = soc_alpha * np.kron(S_0, S_y)
    gauge_y = -soc_alpha * np.kron(S_0, S_x)
    solver.Ux[...] = expm(1j * step_x * gauge_x)
    solver.Uy[...] = expm(1j * step_y * gauge_y)

    # Field term only in the interior, not on the terminal rows.
    solver.Omega[1:-1] += h * np.kron(S_z, S_y)

    return solver


Res = collections.namedtuple("Res", ["x", "y", "J", "Jx", "Jy"])


@vectorize_parallel(returns_object=True, noarray=True)
@usadelndsoc.with_log_level(logging.WARNING)
def j(T, h, phi, n=15, eta=0.1, alpha_soc=0.1, L=10, W=2):
    """
    Compute the Matsubara-summed current for one parameter point.

    Builds the system with ``get_solver``, solves it at the Matsubara
    frequencies from ``get_matsubara_sum(T, E_typical)``, and forms
    ``-1j * pi * sum_w a_w * tr(J_w @ t3)`` with ``t3 = diag(1, 1, -1, -1)``.

    ``vectorize_parallel`` presumably broadcasts array-valued arguments over
    calls (``do`` passes arrays for ``phi`` and ``L``). TODO: confirm in
    ``usadelndsoc.util``.

    Returns
    -------
    tuple
        ``(x, y, J, Jx, Jy)``:

        - ``x``, ``y`` are the solver grid coordinates.
        - ``J`` is the complex summed current.
        - ``Jx = Re(J[:, :, 0] + J[:, :, 2]) / 2``.
        - ``Jy = Re(J[:, :, 1] + J[:, :, 3]) / 2``.

        Wrap it in ``Res`` for named access.
    """
    sol = get_solver(soc_alpha=alpha_soc, eta=eta, L=L, W=W, n=n, phi=phi, h=h)

    # Energy cutoff for the Matsubara sum grows with temperature.
    E_typical = 10 + 10 * abs(T)
    w, a = get_matsubara_sum(T, E_typical)

    res = sol.solve_many(omega=w)

    # tau_3 in the 4x4 space (same matrix as np.kron(S_z, S_0)).
    t3 = np.diag([1, 1, -1, -1])
    J = tr(res.J @ t3)
    # Weighted sum over the leading axis, which pairs with the weights ``a``.
    J = -1j * pi * (J * a[:, None, None, None]).sum(axis=0)

    Jx = (J[:, :, 0] + J[:, :, 2]).real / 2
    Jy = (J[:, :, 1] + J[:, :, 3]).real / 2

    return (sol.x, sol.y, J, Jx, Jy)


def Gamma_to_alpha(Gamma_DP, Gamma_ST):
    r"""
    Map the rates (Gamma_DP, Gamma_ST) onto solver parameters (alpha_soc, eta).

    The rates are defined by

    .. math::

       \Gamma_{DP} &= 2 \alpha^2 p_F^2 \tau \\
       \Gamma_{ST} &= \Gamma_{DP} \frac{\alpha\tau}{\xi_0}

    with :math:`\xi_0^2 = D/(2\pi\Delta_0)`. Lengths are measured in
    :math:`L_0 = \sqrt{D/E_0}`, where :math:`E_0` is the energy unit.

    Returns
    -------
    alpha_soc : float
        ``(Gamma_DP / 4) ** 0.5``
    eta : float
        ``Gamma_ST / (4 * sqrt(2*pi) * alpha_soc**3)``
    """
    quarter_rate = Gamma_DP / 4
    alpha_soc = quarter_rate**0.5

    denominator = 4 * np.sqrt(2 * pi) * alpha_soc**3
    eta = Gamma_ST / denominator

    return alpha_soc, eta


def _jx_mean(res):
    """
    Reduce an object array of ``j`` results to a mean x-current per entry.

    For each result, drop the first and last rows along axis 0 (presumably the
    terminal sites set in ``get_solver``). Then sum ``Jx`` over axis 1 and
    average over the remaining rows. The output has the same shape as ``res``.
    """
    return np.asarray(
        [Res(*x).Jx[1:-1].sum(axis=1).mean(axis=0) for x in res.flat]
    ).reshape(res.shape)


def do(W_xi=6, multin=False):
    """
    Compute and plot the current-phase relation for several junction lengths.

    Saves a two-panel figure to ``cpr_sns.pdf``:

    - Left: the mean x-current versus phase, normalized per length by its
      maximum magnitude.
    - Right: the asymmetry ``(|max| - |min|) / (|max| + |min|)`` in percent,
      versus length.

    Parameters
    ----------
    W_xi : float
        Junction width in units of ``xi = 1/sqrt(2*pi)``.
    multin : bool
        If true, also compute on a coarser (``n=5``) and a finer (``n=20``)
        grid at every other phase point. These are overlaid as dotted black
        lines as a discretization check.
    """
    T = 0.1
    Gamma_DP = 10
    Gamma_ST = 1
    h = -10.0
    phi = np.linspace(0, 2 * pi, 37)
    xi = 1 / np.sqrt(2 * pi)
    # Keep the lengths in units of xi so that labels show exact values rather
    # than the float round-trip (L_xi * xi) / xi.
    L_xi = np.array([0.1, 1.5, 2, 3])
    L = L_xi * xi
    W = W_xi * xi

    alpha, eta = Gamma_to_alpha(Gamma_DP, Gamma_ST)

    def run(phases, n):
        # Lengths go along the first axis and phases along the second, so the
        # reduced result indexes as (length, phase).
        return j(
            T,
            h,
            phases[None, :],
            alpha_soc=alpha,
            eta=eta,
            L=L[:, None],
            W=W,
            n=n,
            mem=mem,
        )

    def eff(Jx):
        # Relative difference of |max| and |min| over the phase axis.
        Jm = Jx.min(axis=1)
        Jp = Jx.max(axis=1)
        return (abs(Jp) - abs(Jm)) / (abs(Jp) + abs(Jm))

    Jx_mean = _jx_mean(run(phi, n=10))

    # Optional discretization check: (mean current, extra plot style) pairs.
    checks = []
    if multin:
        mphi = phi[::2]
        checks = [
            (_jx_mean(run(mphi, n=5)), {"alpha": 0.25}),
            (_jx_mean(run(mphi, n=20)), {}),
        ]

    fig, axs = plt.subplots(1, 2, layout="compressed")

    ax = axs[0]
    ax.plot(phi / pi, Jx_mean.T / abs(Jx_mean).max(axis=1))
    for Jx_check, style in checks:
        ax.plot(mphi / pi, Jx_check.T / abs(Jx_check).max(axis=1), "k:", **style)
    ax.set_xlabel(r"$\varphi / \pi$")
    ax.set_ylabel(r"$I / I_{\mathrm{max}}$")
    # The first len(L_xi) lines are the main curves, so the labels pair with them.
    ax.legend([f"{v:g}" for v in L_xi], title=r"$L/\xi$", loc="lower right")

    ax = axs[1]
    ax.plot(L_xi, 100 * eff(Jx_mean))
    for Jx_check, style in checks:
        ax.plot(L_xi, 100 * eff(Jx_check), "k:", **style)
    ax.set_xlabel(r"$L / \xi$")
    ax.set_ylabel(r"$\eta$  [%]")

    fig.suptitle(
        rf"$W = {W_xi:g} \xi_0$ $\Gamma_{{DP}} = {Gamma_DP} \Delta_0$, $\Gamma_{{ST}} = {Gamma_ST} \Delta_0$   ($\tilde{{\eta}} = {eta:.3g}$, $\tilde{{\alpha}} = {alpha:.3g}$)"
    )

    fig.savefig("cpr_sns.pdf")


def main():
    """Script entry point: run ``do`` with its defaults (writes ``cpr_sns.pdf``)."""
    return do()


if __name__ == "__main__":
    main()