Skip to content
Snippets Groups Projects
Commit 1cda610f authored by patavirt's avatar patavirt
Browse files

solver: fix self-consistent iteration to do singlet correctly

parent 13ad50d0
No related branches found
No related tags found
No related merge requests found
...@@ -16,7 +16,7 @@ __all__ = [ ...@@ -16,7 +16,7 @@ __all__ = [
] ]
def plot_Delta_J_2d(x, y, Delta, J, ax=None): def plot_Delta_J_2d(x, y, Delta, xc, yc, Ix, Iy, ax=None):
""" """
Color plot of order parameter and supercurrents. Color plot of order parameter and supercurrents.
""" """
...@@ -24,15 +24,11 @@ def plot_Delta_J_2d(x, y, Delta, J, ax=None): ...@@ -24,15 +24,11 @@ def plot_Delta_J_2d(x, y, Delta, J, ax=None):
ax = plt.gca() ax = plt.gca()
m = pcolormesh_complex(x, y, Delta, ax=ax) m = pcolormesh_complex(x, y, Delta, ax=ax)
I = J.real
xc = (x[1:] + x[:-1]) / 2
yc = (y[1:] + y[:-1]) / 2
Ix = I[0, :-1, :-1, 0].T
Iy = I[1, :-1, :-1, 0].T
lw = np.hypot(Ix, Iy) lw = np.hypot(Ix, Iy)
lw = 6 * (lw / max(1.0, lw.max())) ** 0.5 lw = 6 * (lw / max(1.0, lw.max())) ** 0.5
ax.streamplot(xc, yc, Ix.T, Iy.T, color=(0.5, 0.5, 0.5), linewidth=lw.T)
ax.streamplot(xc, yc, Ix, Iy, color=(0.5, 0.5, 0.5), linewidth=lw)
ax.set_xlabel(r"$x$") ax.set_xlabel(r"$x$")
ax.set_ylabel(r"$y$") ax.set_ylabel(r"$y$")
plt.colorbar(m.phase, label=r"$\varphi\,/ \pi$", ax=ax) plt.colorbar(m.phase, label=r"$\varphi\,/ \pi$", ax=ax)
......
...@@ -937,6 +937,14 @@ def _with_workers(func): ...@@ -937,6 +937,14 @@ def _with_workers(func):
return wrapper return wrapper
def singlet(x):
return (x[..., 0, 0] + x[..., 1, 1]) / 2
def singlet_m(x):
return x[..., None, None] * S_0
@_with_workers @_with_workers
def _self_consistency( def _self_consistency(
solver, solver,
...@@ -967,12 +975,12 @@ def _self_consistency( ...@@ -967,12 +975,12 @@ def _self_consistency(
T = float(T) T = float(T)
T_c0 = np.asarray(T_c0) T_c0 = np.asarray(T_c0)
if T_c0.shape + (2, 2) != solver.Delta.shape: if T_c0.shape != solver.Delta.shape[:-2]:
raise ValueError("T_c0 must have compatible shape with Delta") raise ValueError("T_c0 must have compatible shape with Delta")
Delta_mask = (solver.mask == MASK_NONE) & (T_c0 > 0) & (T < T_c0) Delta_mask = (solver.mask == MASK_NONE) & (T_c0 > 0) & (T < T_c0)
z = solver.Delta[Delta_mask].copy().ravel().view(float) z = singlet(solver.Delta[Delta_mask]).ravel().view(float)
if constraint_fun is not None: if constraint_fun is not None:
constraint_x = np.asarray(constraint_x) constraint_x = np.asarray(constraint_x)
z = np.r_[constraint_x.astype(float), z] z = np.r_[constraint_x.astype(float), z]
...@@ -1010,13 +1018,13 @@ def _self_consistency( ...@@ -1010,13 +1018,13 @@ def _self_consistency(
return v.copy() return v.copy()
cx = z[: constraint_x.size].reshape(constraint_x.shape) cx = z[: constraint_x.size].reshape(constraint_x.shape)
Delta1 = solver.Delta.copy() Delta1 = singlet(solver.Delta)
Delta1[Delta_mask] = z[constraint_x.size :].view(complex).reshape(-1, 2, 2) Delta1[Delta_mask] = z[constraint_x.size :].view(complex).ravel()
if constraint_fun is not None: if constraint_fun is not None:
res_cons = constraint_fun(cx, Delta1).ravel() res_cons = constraint_fun(cx, Delta1).ravel()
# Constraint function updates terminal phases # Constraint function updates terminal phases
solver.Delta[...] = Delta1 solver.Delta[...] = singlet_m(Delta1)
res, J, m = _self_consistent_Delta_f( res, J, m = _self_consistent_Delta_f(
solver, Delta1, T, T_c0, workers=workers, **solver_kw solver, Delta1, T, T_c0, workers=workers, **solver_kw
...@@ -1088,10 +1096,8 @@ def _self_consistency( ...@@ -1088,10 +1096,8 @@ def _self_consistency(
y = np.linspace(-solver.Ly / 2, solver.Ly / 2, solver.shape[1]) y = np.linspace(-solver.Ly / 2, solver.Ly / 2, solver.shape[1])
plt.clf() plt.clf()
ddd = solver.Delta.copy() ddd = singlet(solver.Delta)
ddd[Delta_mask] = z[constraint_x.size :].view(complex).reshape(-1, 2, 2) ddd[Delta_mask] = z[constraint_x.size :].view(complex).ravel()
ddd = ddd[..., 0, 0]
if plot == "circle" or min(ddd.shape) == 1: if plot == "circle" or min(ddd.shape) == 1:
ddd = ddd.squeeze() ddd = ddd.squeeze()
...@@ -1113,7 +1119,7 @@ def _self_consistency( ...@@ -1113,7 +1119,7 @@ def _self_consistency(
success = False success = False
Delta = solver.Delta.copy() Delta = solver.Delta.copy()
Delta[Delta_mask] = z[constraint_x.size :].view(complex).reshape(-1, 2, 2) Delta[Delta_mask] = singlet_m(z[constraint_x.size :].view(complex).ravel())
cx = z[: constraint_x.size] cx = z[: constraint_x.size]
if success and (np.isnan(J).any() or np.isnan(Delta).any()): if success and (np.isnan(J).any() or np.isnan(Delta).any()):
...@@ -1140,7 +1146,7 @@ def _self_consistent_Delta_f( ...@@ -1140,7 +1146,7 @@ def _self_consistent_Delta_f(
w = w[len(w) // 2 :][::-1] w = w[len(w) // 2 :][::-1]
a = a[len(a) // 2 :][::-1] a = a[len(a) // 2 :][::-1]
rtot = np.zeros(solver.Delta.shape, dtype=complex) rtot = np.zeros(tuple(solver.shape), dtype=complex)
Jtot = np.zeros(tuple(solver.shape) + (4, 4, 4), dtype=complex) Jtot = np.zeros(tuple(solver.shape) + (4, 4, 4), dtype=complex)
mask1 = solver.mask == MASK_NONE mask1 = solver.mask == MASK_NONE
...@@ -1148,7 +1154,7 @@ def _self_consistent_Delta_f( ...@@ -1148,7 +1154,7 @@ def _self_consistent_Delta_f(
rtot[mask1 & ~mask2] = Delta[mask1 & ~mask2] rtot[mask1 & ~mask2] = Delta[mask1 & ~mask2]
mask = mask1 & mask2 mask = mask1 & mask2
solver.Delta[mask] = Delta[mask] solver.Delta[mask] = singlet_m(Delta[mask])
if workers is not None: if workers is not None:
jobs = [] jobs = []
...@@ -1167,7 +1173,7 @@ def _self_consistent_Delta_f( ...@@ -1167,7 +1173,7 @@ def _self_consistent_Delta_f(
rtot[mask] += rtotx[mask] rtot[mask] += rtotx[mask]
Jtot += Jtotx Jtot += Jtotx
rtot[mask] -= np.log(T_c0[mask, None, None] / T) * solver.Delta[mask] rtot[mask] -= np.log(T_c0[mask] / T) * singlet(solver.Delta[mask])
return rtot, Jtot, mask return rtot, Jtot, mask
...@@ -1200,7 +1206,7 @@ def _mp_one(args, solver=None, solver_kw=None): ...@@ -1200,7 +1206,7 @@ def _mp_one(args, solver=None, solver_kw=None):
for wx, ax in zip(w, a): for wx, ax in zip(w, a):
res = solver.solve(omega=wx, **solver_kw) res = solver.solve(omega=wx, **solver_kw)
r = solver.Delta.A / abs(wx) - res.F r = singlet(solver.Delta) / abs(wx) - singlet(res.F)
rtot += (2 * np.pi * ax) * r rtot += (2 * np.pi * ax) * r
Jtot += (-2j * np.pi * ax) * res.J Jtot += (-2j * np.pi * ax) * res.J
finally: finally:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment