From 1f06fbb67d56c6d4e7be39088d1d5bef2e0dba4c Mon Sep 17 00:00:00 2001 From: Pauli Virtanen <pauli.t.virtanen@jyu.fi> Date: Thu, 7 Dec 2023 17:06:33 +0200 Subject: [PATCH] tests: fix numdifftools usage --- tests/test_basic.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/tests/test_basic.py b/tests/test_basic.py index caded57..a5a1c8a 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -87,21 +87,12 @@ def test_result_S_J0(omega, h, terminals): @pytest.mark.parametrize( "omega,use_numdifftools", - [ - (w, t) - for w in (-0.987, 0.987) - for t in [ - False, - pytest.param( - True, - marks=pytest.mark.skipif( - numdifftools is None, reason="numdifftools not installed" - ), - ), - ] - ], + [(w, t) for w in (-0.987, 0.987) for t in (False, True)], ) def test_gauge_invariance(omega, use_numdifftools): + if use_numdifftools and numdifftools is None: + pytest.skip(reason="numdifftools not installed") + nx = 10 ny = 7 h = 0.4 @@ -198,8 +189,15 @@ def test_gauge_invariance(omega, use_numdifftools): if use_numdifftools: # More accurate numdiff: the equality should hold to high # numerical accuracy - D = numdifftools.Derivative(S_s) - dS_ds = D(0.0) / 16 + def rS_s(ds): + return S_s(ds).real + + def iS_s(ds): + return S_s(ds).imag + + rD = numdifftools.Derivative(rS_s) + iD = numdifftools.Derivative(iS_s) + dS_ds = complex(rD(0.0), iD(0.0)) / 16 assert_allclose(div_J, dS_ds, rtol=1e-12) # The derivative must also match the gradient -- GitLab