From 6dae7ef2a11cd91e952813236d44fdc27f29a604 Mon Sep 17 00:00:00 2001
From: Pauli Virtanen <pauli.t.virtanen@jyu.fi>
Date: Fri, 9 Sep 2022 11:14:07 +0300
Subject: [PATCH] Add vectorize_parallel and use in example

---
 examples/bulk_j_an.py   |  18 ++-
 usadelndsoc/meson.build |   3 +-
 usadelndsoc/util.py     | 255 ++++++++++++++++++++++++++++++++++++++++
 3 files changed, 264 insertions(+), 12 deletions(-)
 create mode 100644 usadelndsoc/util.py

diff --git a/examples/bulk_j_an.py b/examples/bulk_j_an.py
index 139df0b..2858544 100644
--- a/examples/bulk_j_an.py
+++ b/examples/bulk_j_an.py
@@ -12,8 +12,7 @@ import logging
 import usadelndsoc
 from usadelndsoc.matsubara import get_matsubara_sum
 from usadelndsoc.solver import *
-
-usadelndsoc.logger.setLevel(logging.WARNING)
+from usadelndsoc.util import vectorize_parallel
 
 mem = joblib.Memory("cache")
 
@@ -45,8 +44,10 @@ def get_solver(soc_alpha, eta, L=10, h=0.5, D=1.0, n=5):
     return sol
 
 
-@mem.cache
+@vectorize_parallel(returns_object=True, noarray=True)
 def j_anom(T, h, n=5, alpha_soc=0.01, perturbative=False):
+    usadelndsoc.logger.setLevel(logging.WARNING)
+
     # Solver without SOC
     if perturbative:
         sol0 = get_solver(soc_alpha=0, eta=0, h=h)
@@ -101,20 +102,15 @@ def j_anom(T, h, n=5, alpha_soc=0.01, perturbative=False):
     return Jtot, Jantot
 
 
-j_anom = np.vectorize(
-    j_anom,
-    signature="(),()->(m,n,k,p,q),()",
-    excluded=["n", "alpha_soc", "perturbative"],
-)
-
-
 def main():
     T = 0.02
     alpha_soc = 0.05
     Da2 = alpha_soc**2
     h = np.linspace(0, 1.75, 29)
 
-    j, j_an = j_anom(T, h, alpha_soc=alpha_soc, perturbative=False)
+    res = j_anom(T, h, alpha_soc=alpha_soc, perturbative=False, mem=mem)
+    j = np.asarray([x[0] for x in res])
+    j_an = np.asarray([x[1] for x in res])
 
     t3 = np.diag([1, 1, -1, -1])
     j = tr(j @ t3)[:, 2, 2, UP]
diff --git a/usadelndsoc/meson.build b/usadelndsoc/meson.build
index f9c620a..3da9aed 100644
--- a/usadelndsoc/meson.build
+++ b/usadelndsoc/meson.build
@@ -14,7 +14,8 @@ py3.install_sources(
       'core.py',
       'matsubara.py',
       'plotting.py',
-      'solver.py'
+      'solver.py',
+      'util.py',
   ],
   pure : false,
   subdir : 'usadelndsoc',
diff --git a/usadelndsoc/util.py b/usadelndsoc/util.py
new file mode 100644
index 0000000..15766c1
--- /dev/null
+++ b/usadelndsoc/util.py
@@ -0,0 +1,255 @@
+# -*- mode:python; coding: utf-8; eval: (blacken-mode) -*-
+"""
+Utility functions
+"""
+# Author: Pauli Virtanen
+# License: GNU Affero General Public License, see LICENSE.txt for details
+
+import re
+import sys
+import functools
+import joblib
+import inspect
+import collections
+import textwrap
+
+if sys.version_info[0] < 3:
+    raise RuntimeError("Python 3 required")
+
+import numpy as np
+
+
+def call_with_params(func, argdict, **kwargs):
+    """
+    Call function with given kwargs, supplemented by arguments
+    from `argdict` for those keys that appear in function signature.
+
+    Same as ``func(**dict(argdict, **kwargs))``, but ignoring keys
+    in `argdict` that do not match the signature of the function.
+    """
+    sig = inspect.signature(func)
+    for p in sig.parameters.values():
+        if p.kind in (
+            inspect.Parameter.POSITIONAL_OR_KEYWORD,
+            inspect.Parameter.KEYWORD_ONLY,
+        ):
+            if p.name in argdict and p.name not in kwargs:
+                kwargs[p.name] = argdict[p.name]
+        elif p.kind == inspect.Parameter.VAR_KEYWORD:
+            raise ValueError("Function has arbitrary kwargs")
+    return func(**kwargs)
+
+
+class _ParallelSafeMemFunc:
+    """
+    Function that updates inspect.linecache when called.
+    Required to make joblib.Memory work with parallel execution
+    with IPython cells.
+    """
+
+    def __init__(self, func, mem):
+        self._orig_func = func
+        self._func = mem.cache(func)
+        inspect.linecache.checkcache()
+        inspect.getsourcelines(func)
+        self._linecache = inspect.linecache.cache
+
+    def __call__(self, *args, **kwargs):
+        if self._linecache is not None:
+            inspect.linecache.cache = self._linecache
+            self._linecache = None
+
+        self._old_checkcache = inspect.linecache.checkcache
+        try:
+            inspect.linecache.checkcache = lambda fn=None: None
+            return self._func(*args, **kwargs)
+        finally:
+            inspect.linecache.checkcache = self._old_checkcache
+
+
+def _unarray(x):
+    if isinstance(x, np.ndarray):
+        return x.item()
+    return x
+
+
+def vectorize_parallel(
+    func=None,
+    backend=None,
+    n_jobs=-1,
+    returns_object=False,
+    batch_size="auto",
+    included=None,
+    excluded=None,
+    noarray=False,
+):
+    if func is None:
+
+        def deco(func):
+            return vectorize_parallel(
+                func,
+                backend=backend,
+                n_jobs=n_jobs,
+                returns_object=returns_object,
+                batch_size=batch_size,
+                included=included,
+                excluded=excluded,  # FIX: was `excluded=None`, dropping user-specified exclusions
+                noarray=noarray,
+            )
+
+        return deco
+
+    argspec = inspect.signature(func)
+
+    excluded = list(excluded) if excluded else []
+    excluded += ["mem", "parallel", "verbose"]
+    if included is not None:
+        excluded += [
+            x.name for x in argspec.parameters.values() if x.name not in included
+        ]
+
+    arg_names = [
+        x.name
+        for x in argspec.parameters.values()
+        if x.kind
+        in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
+    ]
+    vararg_names = [
+        x.name
+        for x in argspec.parameters.values()
+        if x.kind == inspect.Parameter.VAR_POSITIONAL
+    ]
+    if vararg_names:
+        vararg_name = vararg_names[0]
+    else:
+        vararg_name = None
+
+    star_kwargs_name = None
+    for x in argspec.parameters.values():
+        if x.kind == inspect.Parameter.VAR_KEYWORD:
+            star_kwargs_name = x.name
+            break
+
+    nonvec_arg_names = []
+    for name in list(arg_names):
+        if name in excluded:
+            arg_names.remove(name)
+            nonvec_arg_names.append(name)
+
+    if noarray:
+        unarray = _unarray
+        unarray_many = lambda x: map(_unarray, x)
+    else:
+        unarray = lambda x: x
+        unarray_many = lambda x: x
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        parallel = kwargs.get("parallel", True)
+        if "parallel" not in nonvec_arg_names:
+            kwargs.pop("parallel", None)
+
+        mem = kwargs.get("mem", None)
+        if "mem" not in nonvec_arg_names:
+            kwargs.pop("mem", None)
+
+        verbose = kwargs.get("verbose", 0)
+        if "verbose" not in nonvec_arg_names:
+            kwargs.pop("verbose", None)
+
+        if mem is not None:
+            filename = inspect.getsourcefile(func)
+            if parallel and filename and filename.startswith("<") and filename.endswith(">"):
+                # Lambdas etc. require saving inspect.linecache
+                call_func = _ParallelSafeMemFunc(func, mem)
+            else:
+                call_func = mem.cache(func)
+        else:
+            call_func = func
+
+        boundspec = argspec.bind(*args, **kwargs)
+        boundspec.apply_defaults()
+        callargs = boundspec.arguments
+        args = tuple(callargs.pop(name) for name in arg_names)
+        if vararg_name:
+            args += callargs.pop(vararg_name)
+
+        if star_kwargs_name is not None:
+            extra_kwargs = callargs.pop(star_kwargs_name)
+            callargs.update(extra_kwargs)
+
+        v = tuple(map(np.asarray, args))
+        v_broadcast = np.broadcast_arrays(*v)
+
+        if v_broadcast[0].size > 1:
+            it = np.nditer(v, ["refs_ok"])
+
+            if isinstance(verbose, bool) and verbose:
+                try:
+                    import tqdm
+
+                    it = tqdm.tqdm(it, total=it.itersize)
+                    verbose = 0
+                except ImportError:
+                    pass
+
+            if parallel:
+                delayed = joblib.delayed
+            else:
+                delayed = lambda x: x
+
+            # Parallelize
+            if len(v) == 1:
+                jobs = (delayed(call_func)(unarray(w), **callargs) for w in it)
+            else:
+                jobs = (delayed(call_func)(*unarray_many(w), **callargs) for w in it)
+
+            if parallel:
+                results = joblib.Parallel(
+                    n_jobs=n_jobs,
+                    backend=backend,
+                    batch_size=batch_size,
+                    verbose=verbose,
+                )(jobs)
+            else:
+                results = list(jobs)
+        else:
+            # Single point
+            it = np.nditer(v, ["refs_ok"])
+            if len(v) == 1:
+                results = [call_func(unarray(w), **callargs) for w in it]
+            else:
+                results = [call_func(*unarray_many(w), **callargs) for w in it]
+
+        if not returns_object:
+            results = np.array(results)
+        else:
+            results_arr = np.zeros(len(results), dtype=object)
+            results_arr[:] = results
+            results = results_arr
+            del results_arr
+        results_shape = results.shape[1:]
+
+        if results.ndim > 0:
+            results = np.rollaxis(results, 0, results.ndim)
+        return results.reshape(results_shape + v_broadcast[0].shape)
+
+    return wrapper
+
+
+def initialize_joblib_function_hashing():
+    """
+    Make joblib hashing sensitive to function code changes.
+    """
+    import types, dis
+    from io import StringIO
+
+    def new_save_global(self, obj, name=None):
+        if hasattr(obj, "__code__"):
+            out = StringIO()
+            dis.dis(obj.__code__, file=out)
+            self.save(out.getvalue())
+        return old_save_global(self, obj, name=name)
+
+    old_save_global = joblib.hashing.Hasher.save_global
+    joblib.hashing.Hasher.dispatch[types.FunctionType] = new_save_global
-- 
GitLab