Skip to content
Snippets Groups Projects
Commit 1f06fbb6 authored by patavirt's avatar patavirt
Browse files

tests: fix numdifftools usage

parent 5c0e8cbd
No related branches found
No related tags found
No related merge requests found
......@@ -87,21 +87,12 @@ def test_result_S_J0(omega, h, terminals):
@pytest.mark.parametrize(
"omega,use_numdifftools",
[
(w, t)
for w in (-0.987, 0.987)
for t in [
False,
pytest.param(
True,
marks=pytest.mark.skipif(
numdifftools is None, reason="numdifftools not installed"
),
),
]
],
[(w, t) for w in (-0.987, 0.987) for t in (False, True)],
)
def test_gauge_invariance(omega, use_numdifftools):
if use_numdifftools and numdifftools is None:
pytest.skip(reason="numdifftools not installed")
nx = 10
ny = 7
h = 0.4
......@@ -198,8 +189,15 @@ def test_gauge_invariance(omega, use_numdifftools):
if use_numdifftools:
# More accurate numdiff: the equality should hold to high
# numerical accuracy
D = numdifftools.Derivative(S_s)
dS_ds = D(0.0) / 16
def rS_s(ds):
return S_s(ds).real
def iS_s(ds):
return S_s(ds).imag
rD = numdifftools.Derivative(rS_s)
iD = numdifftools.Derivative(iS_s)
dS_ds = complex(rD(0.0), iD(0.0)) / 16
assert_allclose(div_J, dS_ds, rtol=1e-12)
# The derivative must also match the gradient
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment