From 6dae7ef2a11cd91e952813236d44fdc27f29a604 Mon Sep 17 00:00:00 2001 From: Pauli Virtanen <pauli.t.virtanen@jyu.fi> Date: Fri, 9 Sep 2022 11:14:07 +0300 Subject: [PATCH] Add vectorize_parallel and use in example --- examples/bulk_j_an.py | 18 ++- usadelndsoc/meson.build | 3 +- usadelndsoc/util.py | 255 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 264 insertions(+), 12 deletions(-) create mode 100644 usadelndsoc/util.py diff --git a/examples/bulk_j_an.py b/examples/bulk_j_an.py index 139df0b..2858544 100644 --- a/examples/bulk_j_an.py +++ b/examples/bulk_j_an.py @@ -12,8 +12,7 @@ import logging import usadelndsoc from usadelndsoc.matsubara import get_matsubara_sum from usadelndsoc.solver import * - -usadelndsoc.logger.setLevel(logging.WARNING) +from usadelndsoc.util import vectorize_parallel mem = joblib.Memory("cache") @@ -45,8 +44,10 @@ def get_solver(soc_alpha, eta, L=10, h=0.5, D=1.0, n=5): return sol -@mem.cache +@vectorize_parallel(returns_object=True, noarray=True) def j_anom(T, h, n=5, alpha_soc=0.01, perturbative=False): + usadelndsoc.logger.setLevel(logging.WARNING) + # Solver without SOC if perturbative: sol0 = get_solver(soc_alpha=0, eta=0, h=h) @@ -101,20 +102,15 @@ def j_anom(T, h, n=5, alpha_soc=0.01, perturbative=False): return Jtot, Jantot -j_anom = np.vectorize( - j_anom, - signature="(),()->(m,n,k,p,q),()", - excluded=["n", "alpha_soc", "perturbative"], -) - - def main(): T = 0.02 alpha_soc = 0.05 Da2 = alpha_soc**2 h = np.linspace(0, 1.75, 29) - j, j_an = j_anom(T, h, alpha_soc=alpha_soc, perturbative=False) + res = j_anom(T, h, alpha_soc=alpha_soc, perturbative=False, mem=mem) + j = np.asarray([x[0] for x in res]) + j_an = np.asarray([x[1] for x in res]) t3 = np.diag([1, 1, -1, -1]) j = tr(j @ t3)[:, 2, 2, UP] diff --git a/usadelndsoc/meson.build b/usadelndsoc/meson.build index f9c620a..3da9aed 100644 --- a/usadelndsoc/meson.build +++ b/usadelndsoc/meson.build @@ -14,7 +14,8 @@ py3.install_sources( 'core.py', 'matsubara.py', 
def call_with_params(func, argdict, **kwargs):
    """
    Call *func* with the given *kwargs*, filling in missing keyword
    arguments from *argdict* for keys that appear in the signature.

    Equivalent to ``func(**dict(argdict, **kwargs))`` except that keys of
    *argdict* not matching the signature of *func* are silently ignored.

    Raises ``ValueError`` if *func* accepts arbitrary ``**kwargs``, since
    then the filtering above would be meaningless.
    """
    signature = inspect.signature(func)
    accepted_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    for param in signature.parameters.values():
        if param.kind == inspect.Parameter.VAR_KEYWORD:
            raise ValueError("Function has arbitrary kwargs")
        if param.kind not in accepted_kinds:
            continue
        # Explicit kwargs take precedence over argdict entries.
        if param.name in argdict and param.name not in kwargs:
            kwargs[param.name] = argdict[param.name]
    return func(**kwargs)
+ """ + + def __init__(self, func, mem): + self._orig_func = func + self._func = mem.cache(func) + inspect.linecache.checkcache() + inspect.getsourcelines(func) + self._linecache = inspect.linecache.cache + + def __call__(self, *args, **kwargs): + if self._linecache is not None: + inspect.linecache.cache = self._linecache + self._linecache = None + + self._old_checkcache = inspect.linecache.checkcache + try: + inspect.linecache.checkcache = lambda fn=None: None + return self._func(*args, **kwargs) + finally: + inspect.linecache.checkcache = self._old_checkcache + + +def _unarray(x): + if isinstance(x, np.ndarray): + return x.item() + return x + + +def vectorize_parallel( + func=None, + backend=None, + n_jobs=-1, + returns_object=False, + batch_size="auto", + included=None, + excluded=None, + noarray=False, +): + if func is None: + + def deco(func): + return vectorize_parallel( + func, + backend=backend, + n_jobs=n_jobs, + returns_object=returns_object, + batch_size=batch_size, + included=included, + excluded=None, + noarray=noarray, + ) + + return deco + + argspec = inspect.signature(func) + + excluded = list(excluded) if excluded else [] + excluded += ["mem", "parallel", "verbose"] + if included is not None: + excluded += [ + x.name for x in argspec.parameters.values() if x.name not in included + ] + + arg_names = [ + x.name + for x in argspec.parameters.values() + if x.kind + in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD) + ] + vararg_names = [ + x.name + for x in argspec.parameters.values() + if x.kind == inspect.Parameter.VAR_POSITIONAL + ] + if vararg_names: + vararg_name = vararg_names[0] + else: + vararg_name = None + + star_kwargs_name = None + for x in argspec.parameters.values(): + if x.kind == inspect.Parameter.VAR_KEYWORD: + star_kwargs_name = x.name + break + + nonvec_arg_names = [] + for name in list(arg_names): + if name in excluded: + arg_names.remove(name) + nonvec_arg_names.append(name) + + if noarray: + unarray = 
_unarray + unarray_many = lambda x: map(_unarray, x) + else: + unarray = lambda x: x + unarray_many = lambda x: x + + @functools.wraps(func) + def wrapper(*args, **kwargs): + parallel = kwargs.get("parallel", True) + if "parallel" not in nonvec_arg_names: + kwargs.pop("parallel", None) + + mem = kwargs.get("mem", None) + if "mem" not in nonvec_arg_names: + kwargs.pop("mem", None) + + verbose = kwargs.get("verbose", 0) + if "verbose" not in nonvec_arg_names: + kwargs.pop("verbose", None) + + if mem is not None: + filename = inspect.getsourcefile(func) + if parallel and filename.startswith("<") and filename.endswith(">"): + # Lambdas etc. require saving inspect.linecache + call_func = _ParallelSafeMemFunc(func, mem) + else: + call_func = mem.cache(func) + else: + call_func = func + + boundspec = argspec.bind(*args, **kwargs) + boundspec.apply_defaults() + callargs = boundspec.arguments + args = tuple(callargs.pop(name) for name in arg_names) + if vararg_name: + args += callargs.pop(vararg_name) + + if star_kwargs_name is not None: + extra_kwargs = callargs.pop(star_kwargs_name) + callargs.update(extra_kwargs) + + v = tuple(map(np.asarray, args)) + v_broadcast = np.broadcast_arrays(*v) + + if v_broadcast[0].size > 1: + it = np.nditer(v, ["refs_ok"]) + + if isinstance(verbose, bool) and verbose: + try: + import tqdm + + it = tqdm.tqdm(it, total=it.itersize) + verbose = 0 + except ImportError: + pass + + if parallel: + delayed = joblib.delayed + else: + delayed = lambda x: x + + # Parallelize + if len(v) == 1: + jobs = (delayed(call_func)(unarray(w), **callargs) for w in it) + else: + jobs = (delayed(call_func)(*unarray_many(w), **callargs) for w in it) + + if parallel: + results = joblib.Parallel( + n_jobs=n_jobs, + backend=backend, + batch_size=batch_size, + verbose=verbose, + )(jobs) + else: + results = list(jobs) + else: + # Single point + it = np.nditer(v, ["refs_ok"]) + if len(v) == 1: + results = [call_func(unarray(w), **callargs) for w in it] + else: + results = 
[call_func(*unarray_many(w), **callargs) for w in it] + + if not returns_object: + results = np.array(results) + else: + results_arr = np.zeros(len(results), dtype=object) + results_arr[:] = results + results = results_arr + del results_arr + results_shape = results.shape[1:] + + if results.ndim > 0: + results = np.rollaxis(results, 0, results.ndim) + return results.reshape(results_shape + v_broadcast[0].shape) + + return wrapper + + +def initialize_joblib_function_hashing(): + """ + Make joblib hashing sensitive to function code changes. + """ + import types, dis + from io import StringIO + + def new_save_global(self, obj, name=None): + if hasattr(obj, "__code__"): + out = StringIO() + dis.dis(obj.__code__, file=out) + self.save(out.getvalue()) + return old_save_global(self, obj, name=name) + + old_save_global = joblib.hashing.Hasher.save_global + joblib.hashing.Hasher.dispatch[types.FunctionType] = new_save_global -- GitLab