From 1f06fbb67d56c6d4e7be39088d1d5bef2e0dba4c Mon Sep 17 00:00:00 2001
From: Pauli Virtanen <pauli.t.virtanen@jyu.fi>
Date: Thu, 7 Dec 2023 17:06:33 +0200
Subject: [PATCH] tests: fix numdifftools usage

---
 tests/test_basic.py | 28 +++++++++++++---------------
 1 file changed, 13 insertions(+), 15 deletions(-)

diff --git a/tests/test_basic.py b/tests/test_basic.py
index caded57..a5a1c8a 100644
--- a/tests/test_basic.py
+++ b/tests/test_basic.py
@@ -87,21 +87,12 @@ def test_result_S_J0(omega, h, terminals):
 
 @pytest.mark.parametrize(
     "omega,use_numdifftools",
-    [
-        (w, t)
-        for w in (-0.987, 0.987)
-        for t in [
-            False,
-            pytest.param(
-                True,
-                marks=pytest.mark.skipif(
-                    numdifftools is None, reason="numdifftools not installed"
-                ),
-            ),
-        ]
-    ],
+    [(w, t) for w in (-0.987, 0.987) for t in (False, True)],
 )
 def test_gauge_invariance(omega, use_numdifftools):
+    if use_numdifftools and numdifftools is None:
+        pytest.skip(reason="numdifftools not installed")
+
     nx = 10
     ny = 7
     h = 0.4
@@ -198,8 +189,15 @@ def test_gauge_invariance(omega, use_numdifftools):
     if use_numdifftools:
         # More accurate numdiff: the equality should hold to high
         # numerical accuracy
-        D = numdifftools.Derivative(S_s)
-        dS_ds = D(0.0) / 16
+        def rS_s(ds):
+            return S_s(ds).real
+
+        def iS_s(ds):
+            return S_s(ds).imag
+
+        rD = numdifftools.Derivative(rS_s)
+        iD = numdifftools.Derivative(iS_s)
+        dS_ds = complex(rD(0.0), iD(0.0)) / 16
         assert_allclose(div_J, dS_ds, rtol=1e-12)
 
     # The derivative must also match the gradient
-- 
GitLab