Skip to content
Snippets Groups Projects
Commit 6dae7ef2 authored by patavirt's avatar patavirt
Browse files

Add vectorize_parallel and use in example

parent 013fc4a4
No related branches found
No related tags found
No related merge requests found
......@@ -12,8 +12,7 @@ import logging
import usadelndsoc
from usadelndsoc.matsubara import get_matsubara_sum
from usadelndsoc.solver import *
usadelndsoc.logger.setLevel(logging.WARNING)
from usadelndsoc.util import vectorize_parallel
mem = joblib.Memory("cache")
......@@ -45,8 +44,10 @@ def get_solver(soc_alpha, eta, L=10, h=0.5, D=1.0, n=5):
return sol
@mem.cache
@vectorize_parallel(returns_object=True, noarray=True)
def j_anom(T, h, n=5, alpha_soc=0.01, perturbative=False):
usadelndsoc.logger.setLevel(logging.WARNING)
# Solver without SOC
if perturbative:
sol0 = get_solver(soc_alpha=0, eta=0, h=h)
......@@ -101,20 +102,15 @@ def j_anom(T, h, n=5, alpha_soc=0.01, perturbative=False):
return Jtot, Jantot
j_anom = np.vectorize(
j_anom,
signature="(),()->(m,n,k,p,q),()",
excluded=["n", "alpha_soc", "perturbative"],
)
def main():
T = 0.02
alpha_soc = 0.05
Da2 = alpha_soc**2
h = np.linspace(0, 1.75, 29)
j, j_an = j_anom(T, h, alpha_soc=alpha_soc, perturbative=False)
res = j_anom(T, h, alpha_soc=alpha_soc, perturbative=False, mem=mem)
j = np.asarray([x[0] for x in res])
j_an = np.asarray([x[1] for x in res])
t3 = np.diag([1, 1, -1, -1])
j = tr(j @ t3)[:, 2, 2, UP]
......
......@@ -14,7 +14,8 @@ py3.install_sources(
'core.py',
'matsubara.py',
'plotting.py',
'solver.py'
'solver.py',
'util.py',
],
pure : false,
subdir : 'usadelndsoc',
......
# -*- mode:python; coding: utf-8; eval: (blacken-mode) -*-
"""
Utility functions
"""
# Author: Pauli Virtanen
# License: GNU Affero General Public License, see LICENSE.txt for details
import re
import sys
import functools
import joblib
import inspect
import collections
import textwrap
if sys.version_info[0] < 3:
raise RuntimeError("Python 3 required")
import numpy as np
def call_with_params(func, argdict, **kwargs):
"""
Call function with given kwargs, supplemented by arguments
from `argdict` for those keys that appear in function signature.
Same as ``func(**dict(argdict, **kwargs))``, but ignoring keys
in `argdict` that do not match the signature of the function.
"""
sig = inspect.signature(func)
for p in sig.parameters.values():
if p.kind in (
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
):
if p.name in argdict and p.name not in kwargs:
kwargs[p.name] = argdict[p.name]
elif p.kind == inspect.Parameter.VAR_KEYWORD:
raise ValueError("Function has arbitrary kwargs")
return func(**kwargs)
class _ParallelSafeMemFunc:
"""
Function that updates inspect.linecache when called.
Required to make joblib.Memory work with parallel execution
with IPython cells.
"""
def __init__(self, func, mem):
self._orig_func = func
self._func = mem.cache(func)
inspect.linecache.checkcache()
inspect.getsourcelines(func)
self._linecache = inspect.linecache.cache
def __call__(self, *args, **kwargs):
if self._linecache is not None:
inspect.linecache.cache = self._linecache
self._linecache = None
self._old_checkcache = inspect.linecache.checkcache
try:
inspect.linecache.checkcache = lambda fn=None: None
return self._func(*args, **kwargs)
finally:
inspect.linecache.checkcache = self._old_checkcache
def _unarray(x):
if isinstance(x, np.ndarray):
return x.item()
return x
def vectorize_parallel(
func=None,
backend=None,
n_jobs=-1,
returns_object=False,
batch_size="auto",
included=None,
excluded=None,
noarray=False,
):
if func is None:
def deco(func):
return vectorize_parallel(
func,
backend=backend,
n_jobs=n_jobs,
returns_object=returns_object,
batch_size=batch_size,
included=included,
excluded=None,
noarray=noarray,
)
return deco
argspec = inspect.signature(func)
excluded = list(excluded) if excluded else []
excluded += ["mem", "parallel", "verbose"]
if included is not None:
excluded += [
x.name for x in argspec.parameters.values() if x.name not in included
]
arg_names = [
x.name
for x in argspec.parameters.values()
if x.kind
in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
]
vararg_names = [
x.name
for x in argspec.parameters.values()
if x.kind == inspect.Parameter.VAR_POSITIONAL
]
if vararg_names:
vararg_name = vararg_names[0]
else:
vararg_name = None
star_kwargs_name = None
for x in argspec.parameters.values():
if x.kind == inspect.Parameter.VAR_KEYWORD:
star_kwargs_name = x.name
break
nonvec_arg_names = []
for name in list(arg_names):
if name in excluded:
arg_names.remove(name)
nonvec_arg_names.append(name)
if noarray:
unarray = _unarray
unarray_many = lambda x: map(_unarray, x)
else:
unarray = lambda x: x
unarray_many = lambda x: x
@functools.wraps(func)
def wrapper(*args, **kwargs):
parallel = kwargs.get("parallel", True)
if "parallel" not in nonvec_arg_names:
kwargs.pop("parallel", None)
mem = kwargs.get("mem", None)
if "mem" not in nonvec_arg_names:
kwargs.pop("mem", None)
verbose = kwargs.get("verbose", 0)
if "verbose" not in nonvec_arg_names:
kwargs.pop("verbose", None)
if mem is not None:
filename = inspect.getsourcefile(func)
if parallel and filename.startswith("<") and filename.endswith(">"):
# Lambdas etc. require saving inspect.linecache
call_func = _ParallelSafeMemFunc(func, mem)
else:
call_func = mem.cache(func)
else:
call_func = func
boundspec = argspec.bind(*args, **kwargs)
boundspec.apply_defaults()
callargs = boundspec.arguments
args = tuple(callargs.pop(name) for name in arg_names)
if vararg_name:
args += callargs.pop(vararg_name)
if star_kwargs_name is not None:
extra_kwargs = callargs.pop(star_kwargs_name)
callargs.update(extra_kwargs)
v = tuple(map(np.asarray, args))
v_broadcast = np.broadcast_arrays(*v)
if v_broadcast[0].size > 1:
it = np.nditer(v, ["refs_ok"])
if isinstance(verbose, bool) and verbose:
try:
import tqdm
it = tqdm.tqdm(it, total=it.itersize)
verbose = 0
except ImportError:
pass
if parallel:
delayed = joblib.delayed
else:
delayed = lambda x: x
# Parallelize
if len(v) == 1:
jobs = (delayed(call_func)(unarray(w), **callargs) for w in it)
else:
jobs = (delayed(call_func)(*unarray_many(w), **callargs) for w in it)
if parallel:
results = joblib.Parallel(
n_jobs=n_jobs,
backend=backend,
batch_size=batch_size,
verbose=verbose,
)(jobs)
else:
results = list(jobs)
else:
# Single point
it = np.nditer(v, ["refs_ok"])
if len(v) == 1:
results = [call_func(unarray(w), **callargs) for w in it]
else:
results = [call_func(*unarray_many(w), **callargs) for w in it]
if not returns_object:
results = np.array(results)
else:
results_arr = np.zeros(len(results), dtype=object)
results_arr[:] = results
results = results_arr
del results_arr
results_shape = results.shape[1:]
if results.ndim > 0:
results = np.rollaxis(results, 0, results.ndim)
return results.reshape(results_shape + v_broadcast[0].shape)
return wrapper
def initialize_joblib_function_hashing():
"""
Make joblib hashing sensitive to function code changes.
"""
import types, dis
from io import StringIO
def new_save_global(self, obj, name=None):
if hasattr(obj, "__code__"):
out = StringIO()
dis.dis(obj.__code__, file=out)
self.save(out.getvalue())
return old_save_global(self, obj, name=name)
old_save_global = joblib.hashing.Hasher.save_global
joblib.hashing.Hasher.dispatch[types.FunctionType] = new_save_global
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment