diff --git a/usadelndsoc/solver.py b/usadelndsoc/solver.py
index 671fbd887649b9ffa7d4c46f01d75ca09e89cc32..75da18283e591a26de0ecf80e156de010c5b963a 100644
--- a/usadelndsoc/solver.py
+++ b/usadelndsoc/solver.py
@@ -53,6 +53,7 @@ __all__ = [
     "S_y",
     "S_z",
     "S_0",
+    "interpolate_J",
 ]
 
 _log = logging.getLogger(__name__)
@@ -1521,3 +1522,136 @@ def tr(M):
     Trace over last two dimensions
     """
     return np.trace(M, axis1=-2, axis2=-1)
+
+
+def interpolate_J(x0, y0, J, x, y, method="bilinear"):
+    r"""Interpolate (matrix) current components at given coordinates.
+
+    The current values are defined on links between cells, and not at
+    the cell centers, so special interpolation is necessary.
+
+    The interpolant inside each cell is linear :math:`J(x,y) = (J_0x + x J_1,
+    J_0y + y J_2)`. The component perpendicular to the cell faces is
+    piecewise linear, but the component parallel to cell faces is
+    piecewise constant.  The interpolant is divergenceless inside
+    cells, if the matrix current values do not indicate divergence.
+
+    Parameters
+    ----------
+    x0 : array, shape (nx,)
+        Cell x-coordinates
+    y0 : array, shape (ny,)
+        Cell y-coordinates
+    J : array, shape (nx, ny, 4, ...)
+        Matrix current, defined on links between cells.
+    x : array
+        x-coordinate to interpolate J at.
+    y : array
+        y-coordinate to interpolate J at.
+    method : {'linear', 'bilinear'}
+        Interpolation method.
+
+    Returns
+    -------
+    J_interp : array, shape broadcast_shape(x, y) + (2, ...)
+        Array containing x and y components of the current,
+        interpolated at the coordinates (x, y).
+
+    """
+
+    dx = x0[1] - x0[0]
+    dy = y0[1] - y0[0]
+    nx = len(x0)
+    ny = len(y0)
+
+    i = np.searchsorted(x0 + dx / 2, x).clip(0, nx - 1)
+    j = np.searchsorted(y0 + dy / 2, y).clip(0, ny - 1)
+
+    xc = x - x0[i]
+    yc = y - y0[j]
+
+    m = (
+        (x < x0[0] - dx / 2)
+        | (x > x0[-1] + dx / 2)
+        | (y < y0[0] - dy / 2)
+        | (y > y0[-1] + dy / 2)
+    )
+
+    Jshape = J.shape[3:]
+    sl = (Ellipsis,) + (None,) * len(Jshape)
+    slx = (Ellipsis, 0) + (None,) * len(Jshape)
+    sly = (Ellipsis, 1) + (None,) * len(Jshape)
+
+    I = np.empty(m.shape + (2,) + Jshape, dtype=J.dtype)
+
+    if method == "linear":
+        J0x = (J[i, j, 0] + J[i, j, 2]) / 2
+        J0y = (J[i, j, 1] + J[i, j, 3]) / 2
+
+        J1x = (J[i, j, 0] - J[i, j, 2]) / dx
+        J2y = (J[i, j, 1] - J[i, j, 3]) / dy
+
+        I[slx] = J0x + xc[sl] * J1x
+        I[sly] = J0y + yc[sl] * J2y
+    elif method == "bilinear":
+        # Bilinear polynomial hat functions at corner sites of cells.
+        # They are nonzero only on neighboring cells.
+        #
+        # On each of the edges, they are linear functions along the
+        # edge, and their average should coincide with the current
+        # perpendicular to the edge.
+
+        def M(n):
+            M = np.zeros((n, n + 1))
+            p = np.arange(n)
+            M[p, p] = 1 / 2
+            M[p, p + 1] = 1 / 2
+            return M
+
+        I0 = J[:, :, 0]
+        I1 = J[:, :, 1]
+        I2 = J[:, :, 2]
+        I3 = J[:, :, 3]
+
+        # Solve corner values Iyc: (Iyc(i,j) + Iyc(i+1,j))/2 == I3(i,j)
+        r = np.empty((nx, ny + 1) + I0.shape[2:], dtype=I0.dtype)
+        r[:, :ny] = I3
+        r[:, ny] = I1[:, -1]
+
+        rp = r.reshape(nx, -1)
+        z = np.linalg.lstsq(M(nx), rp, rcond=None)[0]
+        Iyt = z.reshape(nx + 1, *r.shape[1:])
+
+        # Solve corner values Ixc: (Ixc(i,j) + Ixc(i,j+1))/2 == I2(i,j)
+        r = np.empty((nx + 1, ny) + I0.shape[2:], dtype=I0.dtype)
+        r[:nx, :] = I2
+        r[nx, :] = I0[-1, :]
+
+        rp = np.moveaxis(r, 1, 0).reshape(ny, -1)
+        z = np.linalg.lstsq(M(ny), rp, rcond=None)[0]
+        Ixt = np.moveaxis(z.reshape(ny + 1, *((nx + 1,) + rp.shape[2:])), 0, 1)
+
+        # Interpolate with the bilinear hat functions
+        ax = 0.5 - xc / dx
+        bx = 0.5 + xc / dx
+
+        ay = 0.5 - yc / dy
+        by = 0.5 + yc / dy
+
+        I[slx] = (
+            ax * ay * Ixt[i, j]
+            + bx * ay * Ixt[i + 1, j]
+            + ax * by * Ixt[i, j + 1]
+            + bx * by * Ixt[i + 1, j + 1]
+        )
+        I[sly] = (
+            ax * ay * Iyt[i, j]
+            + bx * ay * Iyt[i + 1, j]
+            + ax * by * Iyt[i, j + 1]
+            + bx * by * Iyt[i + 1, j + 1]
+        )
+    else:
+        raise ValueError(f"Unknown {method=}")
+
+    I[m] = 0
+    return I