Skip to content
Snippets Groups Projects
Commit 1f06fbb6 authored by patavirt's avatar patavirt
Browse files

tests: fix numdifftools usage

parent 5c0e8cbd
No related branches found
No related tags found
No related merge requests found
...@@ -87,21 +87,12 @@ def test_result_S_J0(omega, h, terminals): ...@@ -87,21 +87,12 @@ def test_result_S_J0(omega, h, terminals):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"omega,use_numdifftools", "omega,use_numdifftools",
[ [(w, t) for w in (-0.987, 0.987) for t in (False, True)],
(w, t)
for w in (-0.987, 0.987)
for t in [
False,
pytest.param(
True,
marks=pytest.mark.skipif(
numdifftools is None, reason="numdifftools not installed"
),
),
]
],
) )
def test_gauge_invariance(omega, use_numdifftools): def test_gauge_invariance(omega, use_numdifftools):
if use_numdifftools and numdifftools is None:
pytest.skip(reason="numdifftools not installed")
nx = 10 nx = 10
ny = 7 ny = 7
h = 0.4 h = 0.4
...@@ -198,8 +189,15 @@ def test_gauge_invariance(omega, use_numdifftools): ...@@ -198,8 +189,15 @@ def test_gauge_invariance(omega, use_numdifftools):
if use_numdifftools: if use_numdifftools:
# More accurate numdiff: the equality should hold to high # More accurate numdiff: the equality should hold to high
# numerical accuracy # numerical accuracy
D = numdifftools.Derivative(S_s) def rS_s(ds):
dS_ds = D(0.0) / 16 return S_s(ds).real
def iS_s(ds):
return S_s(ds).imag
rD = numdifftools.Derivative(rS_s)
iD = numdifftools.Derivative(iS_s)
dS_ds = complex(rD(0.0), iD(0.0)) / 16
assert_allclose(div_J, dS_ds, rtol=1e-12) assert_allclose(div_J, dS_ds, rtol=1e-12)
# The derivative must also match the gradient # The derivative must also match the gradient
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment