diff --git a/usadelndsoc/solver.py b/usadelndsoc/solver.py
index 5514e42402a511b5423fa67121400407b8cefc95..664b9098a33ccb49550a6bd2d32a33bc9e682aec 100644
--- a/usadelndsoc/solver.py
+++ b/usadelndsoc/solver.py
@@ -23,6 +23,8 @@ from scipy.sparse.linalg import (
 )
 from scipy.special import zeta
 from scipy.linalg import blas, signm
+import scipy.interpolate
+import textwrap
 
 try:
     from scikits import umfpack
@@ -1580,6 +1582,9 @@ def interpolate_J(x0, y0, J, x, y, method="bilinear"):
     nx = len(x0)
     ny = len(y0)
 
+    if not (np.allclose(np.diff(x0), dx) and np.allclose(np.diff(y0), dy)):
+        raise ValueError("Original grid must be regular.")
+
     i = np.searchsorted(x0 + dx / 2, x).clip(0, nx - 1)
     j = np.searchsorted(y0 + dy / 2, y).clip(0, ny - 1)
 
@@ -1595,8 +1600,8 @@ def interpolate_J(x0, y0, J, x, y, method="bilinear"):
 
     Jshape = J.shape[3:]
     sl = (Ellipsis,) + (None,) * len(Jshape)
-    slx = (Ellipsis, 0) + (None,) * len(Jshape)
-    sly = (Ellipsis, 1) + (None,) * len(Jshape)
+    slx = (Ellipsis, 0) + (slice(None),) * len(Jshape)
+    sly = (Ellipsis, 1) + (slice(None),) * len(Jshape)
 
     I = np.empty(m.shape + (2,) + Jshape, dtype=J.dtype)
 
@@ -1647,14 +1652,14 @@ def interpolate_J(x0, y0, J, x, y, method="bilinear"):
 
         rp = np.moveaxis(r, 1, 0).reshape(ny, -1)
         z = np.linalg.lstsq(M(ny), rp, rcond=None)[0]
-        Ixt = np.moveaxis(z.reshape(ny + 1, *((nx + 1,) + rp.shape[2:])), 0, 1)
+        Ixt = np.moveaxis(z.reshape(ny + 1, *((nx + 1,) + r.shape[2:])), 0, 1)
 
         # Interpolate with the bilinear hat functions
-        ax = 0.5 - xc / dx
-        bx = 0.5 + xc / dx
+        ax = (0.5 - xc / dx)[sl]
+        bx = (0.5 + xc / dx)[sl]
 
-        ay = 0.5 - yc / dy
-        by = 0.5 + yc / dy
+        ay = (0.5 - yc / dy)[sl]
+        by = (0.5 + yc / dy)[sl]
 
         I[slx] = (
             ax * ay * Ixt[i, j]
@@ -1673,3 +1678,80 @@ def interpolate_J(x0, y0, J, x, y, method="bilinear"):
 
     I[m] = 0
     return I
+
+
+def interpolate_S(x0, y0, S, x, y, method="bilinear"):
+    r"""Interpolate densities at given coordinates.
+
+    Parameters
+    ----------
+    x0 : array, shape (nx,)
+        Cell x-coordinates
+    y0 : array, shape (ny,)
+        Cell y-coordinates
+    S : array, shape (nx, ny, ...)
+        Matrix current, defined on links between cells.
+    x : array
+        x-coordinate to interpolate J at.
+    y : array
+        y-coordinate to interpolate J at.
+    method : {'bilinear'}
+        Interpolation method.
+
+        - ``bilinear``: :math:`S` is continuous piecewise linear
+          functions both along :math:`x` and :math:`y`.
+
+    Returns
+    -------
+    S_interp : array, shape broadcast_shape(x, y) + ...
+        Array containing the density interpolated at
+        the coordinates (x, y).
+
+    Notes
+    -----
+    This is just simple bilinear interpolation on a grid, with
+    zero gradient enforced at grid edges.
+
+    """
+
+    dx = x0[1] - x0[0]
+    dy = y0[1] - y0[0]
+    nx = len(x0)
+    ny = len(y0)
+
+    if not (np.allclose(np.diff(x0), dx) and np.allclose(np.diff(y0), dy)):
+        raise ValueError("Original grid must be regular.")
+
+    try:
+        meth = {"bilinear": "linear"}[method]
+    except KeyError:
+        raise ValueError(f"Unknown {method=}")
+
+    x, y = np.broadcast_arrays(x, y)
+    xy = np.moveaxis([x, y], 0, -1)
+
+    m = (
+        (x < x0[0] - dx / 2)
+        | (x > x0[-1] + dx / 2)
+        | (y < y0[0] - dy / 2)
+        | (y > y0[-1] + dy / 2)
+    )
+
+    x1 = np.r_[x0[0] - dx / 2, x0, x0[-1] + dx / 2]
+    y1 = np.r_[y0[0] - dy / 2, y0, y0[-1] + dy / 2]
+
+    # Pad to constant value at boundaries
+    s = np.zeros((nx + 2, ny + 2) + S.shape[2:], dtype=S.dtype)
+    s[1:-1, 1:-1] = S
+    s[1:-1, 0] = S[:, 0]
+    s[1:-1, -1] = S[:, -1]
+    s[0, 1:-1] = S[0, :]
+    s[-1, 1:-1] = S[-1, :]
+    s[0, 0] = S[0, 0]
+    s[0, -1] = S[0, -1]
+    s[-1, 0] = S[-1, 0]
+    s[-1, -1] = S[-1, -1]
+
+    return scipy.interpolate.interpn(
+        (x1, y1), s, xy, method=meth, bounds_error=False, fill_value=0
+    )