# Skip to content
# Snippets Groups Projects
# bulk_j_an.py 3.15 KiB
# -*- coding:utf-8; eval: (blacken-mode) -*-
"""
Calculate anomalous current in bulk case
"""
import joblib
import numpy as np
from numpy import pi
import matplotlib.pyplot as plt
from scipy.linalg import expm
import logging

import usadelndsoc
from usadelndsoc.matsubara import get_matsubara_sum
from usadelndsoc.solver import *
from usadelndsoc.util import vectorize_parallel

# joblib cache object created with the location "cache"; main() passes it to
# j_anom through the ``mem=`` keyword.
mem = joblib.Memory("cache")


def get_solver(soc_alpha, eta, L=10, h=0.5, D=1.0, n=5):
    """
    Create a `Solver` on an ``n`` x ``n`` grid with spatially uniform settings.

    Parameters
    ----------
    soc_alpha
        Prefactor of the matrices ``Ax = soc_alpha * kron(S_0, S_y)`` and
        ``Ay = -soc_alpha * kron(S_0, S_x)`` that are exponentiated into the
        link matrices ``Ux`` and ``Uy``.
    eta
        Value assigned to ``sol.eta``.
    L
        Side length; assigned to both ``sol.Lx`` and ``sol.Ly``.
    h
        Prefactor of the ``kron(S_z, S_x)`` term added to ``sol.Omega``.
    D
        Value assigned to ``sol.D``.
    n
        Number of cells requested in each direction (``nx = ny = n``).

    Returns
    -------
    Solver
        The configured solver instance.
    """
    solver = Solver(nx=n, ny=n)
    solver.mask[...] = MASK_NONE
    solver.Lx = solver.Ly = L

    # Start from a clean Omega; the h-dependent term is added at the end.
    solver.Omega[...] = 0
    solver.Delta[:, :] = np.eye(2)

    solver.D = D
    solver.eta = eta

    # Cell sizes follow from the grid shape reported by the solver itself.
    cells_x, cells_y = solver.shape
    step_x = solver.Lx / cells_x
    step_y = solver.Ly / cells_y

    # Matrices entering the link variables U = exp(i * step * A).
    gauge_x = soc_alpha * np.kron(S_0, S_y)
    gauge_y = -soc_alpha * np.kron(S_0, S_x)

    solver.Ux[...] = expm(1j * step_x * gauge_x)
    solver.Uy[...] = expm(1j * step_y * gauge_y)

    solver.Omega[...] += h * np.kron(S_z, S_x)

    return solver


@vectorize_parallel(returns_object=True, noarray=True)
def j_anom(T, h, n=5, alpha_soc=0.01, perturbative=False):
    """
    Evaluate the Matsubara-summed current numerically and from a closed form.

    For every Matsubara frequency returned by ``get_matsubara_sum`` the solver
    is run, the current is obtained from ``sol._core.grad_A`` and accumulated
    with the corresponding weight.  Alongside, an analytic expression for the
    anomalous current is accumulated with the same weights.

    Parameters
    ----------
    T
        Temperature passed to ``get_matsubara_sum``.
    h
        Field strength forwarded to `get_solver` and used in the analytic
        expression.
    n
        Grid size forwarded to `get_solver`.
    alpha_soc
        SOC strength forwarded to `get_solver`; enters the analytic result
        as ``alpha_soc**3``.
    perturbative
        If true, the state is solved with a second solver built with
        ``soc_alpha=0`` and ``eta=0``, and the current is then evaluated on
        that state with the SOC solver.  Otherwise the SOC solver is used
        for both steps.

    Returns
    -------
    (Jtot, Jantot)
        ``Jtot`` is the accumulated numerical current (same shape as the
        array derived from ``grad_A``); ``Jantot`` is the accumulated
        analytic value.  Both are the integer 0 if there are no
        frequencies to sum over.

    Notes
    -----
    The function is wrapped by ``vectorize_parallel``; `main` relies on the
    wrapper to accept an array for ``h`` and an extra ``mem`` keyword
    (NOTE(review): wrapper semantics are defined in ``usadelndsoc.util``).
    """
    usadelndsoc.logger.setLevel(logging.WARNING)

    # Solver without SOC, only required in perturbative mode
    sol0 = get_solver(soc_alpha=0, eta=0, h=h, n=n) if perturbative else None

    # Solver for evaluating current perturbatively in SOC
    eta = 1.0
    sol = get_solver(soc_alpha=alpha_soc, eta=eta, h=h, D=1.0, n=n)

    # The solver that produces the state does not change between frequencies:
    # bulk solution without SOC field and S_H (perturbative), or with them.
    state_solver = sol0 if perturbative else sol

    E_typical = 10 + 10 * T
    w, a = get_matsubara_sum(T, E_typical)

    # Cell sizes along x and y
    dx = sol.Lx / sol.shape[0]
    dy = sol.Ly / sol.shape[1]

    Jtot = 0
    Jantot = 0
    for ww, aa in zip(w, a):
        res = state_solver.solve(omega=ww)

        # Evaluate current; the SOC solver must be at the current frequency
        # even when the state came from the SOC-free solver.
        sol.omega = ww
        J = sol._core.grad_A(res.Phi.copy()).transpose(0, 1, 2, 4, 3) * (-1 / 16)
        J[:, :, LEFT] /= dy  # current density: divide by cell width
        J[:, :, UP] /= dx
        J[:, :, RIGHT] /= dy
        J[:, :, DOWN] /= dx
        Jtot += 1j * pi * aa * J

        # Analytic
        wp = ww - 1j * h
        wm = ww + 1j * h
        root_p = np.sqrt(wp**2 + 1)
        root_m = np.sqrt(wm**2 + 1)
        gp = wp / root_p
        gm = wm / root_m
        fp = 1 / root_p
        fm = 1 / root_m
        # XXX: check factors of 2
        Jy_an = (gp - gm) * (1 + gp * gm + fp * fm) * (-1j * alpha_soc**3) * eta
        Jantot += pi * aa * Jy_an

    return Jtot, Jantot


def main():
    """
    Sweep the field ``h`` and plot numerical against analytic current.

    Runs `j_anom` for 79 values of ``h`` in [0, 1.75] at ``T = 0.02`` and
    ``alpha_soc = 0.05`` on a 20 x 20 grid, extracts one component of the
    numerical result, and writes the comparison plot to ``bulk_j_an.pdf``.
    """
    temperature = 0.02
    alpha_soc = 0.05
    d_alpha_sq = alpha_soc**2
    fields = np.linspace(0, 1.75, 79)

    results = j_anom(
        temperature,
        fields,
        n=20,
        alpha_soc=alpha_soc,
        perturbative=False,
        mem=mem,
    )
    # Each result is a (numerical, analytic) pair.
    numeric = np.asarray([pair[0] for pair in results])
    analytic = np.asarray([pair[1] for pair in results])

    # Trace against diag(1, 1, -1, -1), then pick indices [2, 2, UP]
    # (presumably grid cell (2, 2) and the UP direction; confirm against tr()).
    tau3 = np.diag([1, 1, -1, -1])
    numeric = tr(numeric @ tau3)[:, 2, 2, UP]

    plt.plot(fields, -numeric.real, "-", label=r"numerics")
    plt.plot(fields, -analytic.real, "--", label=r"analytic")
    plt.title(
        rf"$T = {temperature} \Delta$, $D\alpha^2 = {d_alpha_sq:.4} \Delta$"
    )
    plt.xlabel(r"$h/\Delta$")
    plt.ylabel(r"$j_{\mathrm{an}}$")
    plt.legend()
    plt.savefig("bulk_j_an.pdf")

if __name__ == "__main__":
    main()