Source code for gpe.utils

"""Various utilities used throughout the project"""

from __future__ import absolute_import, division, print_function, unicode_literals

import atexit
try:
    from collections import abc
except ImportError:
    import collections as abc
import contextlib
import decimal
import glob
import inspect
import logging
import math
import os.path
import shutil
import subprocess
import sys
import time
import traceback
import warnings

from six import string_types

import wrapt

import mmfutils.performance.fft
from mmfutils.contexts import NoInterrupt

from persist.objects import Archivable
from persist.archive import Archive

from pytimeode.evolvers import EvolverABM
from pytimeode.mixins import ArrayStateMixin
from pytimeode.interfaces import implementer, IStateForABMEvolvers

import numpy as np

from .interfaces import IExperiment, IStateDFT
from .mixins import StateMixin


# Names exported by ``from gpe.utils import *``.
# NOTE(review): "GPUHelper", "AsNumpyMixin" and "StateWithExperimentMixin" are
# not defined in the part of the module shown alongside this list; presumably
# they are defined further down -- verify.  "IStateDFT" is re-exported from
# ``.interfaces``.
__all__ = [
    "step",
    "x2_2",
    "good_Ns",
    "get_good_N",
    "get_smooth_transition",
    "Frames",
    "GPUHelper",
    "AsNumpyMixin",
    "PerformanceWarning",
    "IStateDFT",
    "StateWithExperimentMixin",
    "use_wisdom",
]

_DATA_DIR = "_data"


# The three 2x2 Pauli matrices stacked along the first axis (x, y, z).
pauli_matrices = np.array(
    [
        [[0, 1], [1, 0]],
        [[0, -1j], [1j, 0]],
        [[1, 0], [0, -1]],
    ]
)

# Levi-Civita symbols keyed by dimension: a 2x2 and a 3x3x3 array.
levi_civita = {2: np.array([[0, 1], [-1, 0]]), 3: np.zeros((3, 3, 3))}
for _i, _j, _k in [(0, 1, 2), (1, 2, 0), (2, 0, 1)]:
    # Cyclic permutations are +1; reversing the indices gives the odd ones.
    levi_civita[3][_i, _j, _k] = 1
    levi_civita[3][_k, _j, _i] = -1
del _i, _j, _k


class PerformanceWarning(Warning):
    """Warning category used to flag potential performance problems."""
def use_wisdom(**kw):
    """Load the fftw_wisdom file (if it exists) and save on exit.

    Enters the context returned by ``mmfutils.performance.fft.fftw_wisdom``
    and arranges for it to be exited when the interpreter shuts down.

    Parameters
    ----------
    **kw
        Passed through unchanged to ``fftw_wisdom()``.
    """
    wisdom_context = mmfutils.performance.fft.fftw_wisdom(**kw)
    # Enter first so that nothing is scheduled if entering fails.
    wisdom_context.__enter__()
    # The context-manager protocol calls __exit__(exc_type, exc_value, tb);
    # pass the "no exception" triple explicitly because atexit would
    # otherwise call it with no arguments.
    atexit.register(wisdom_context.__exit__, None, None, None)
def step(t, t1, alpha=3.0):
    r"""Smooth step function rising from 0 at ``t=0`` to 1 at ``t=t1``.

    The function is $C_\infty$.

    Parameters
    ----------
    t : float
        Time at which to evaluate the step.
    t1 : float
        Time at which the step reaches 1.
    alpha : float
        Scale applied inside the ``tanh``.

    Returns
    -------
    float
        0.0 for ``t < 0``, 1.0 once ``t`` is no longer below ``t1``, and the
        smooth interpolant in between.
    """
    if t < 0.0:
        return 0.0
    if not t < t1:
        return 1.0
    angle = math.pi * (2 * t / t1 - 1) / 2
    return (1 + math.tanh(alpha * math.tan(angle))) / 2
def x2_2(x, order=6):
    r"""Return a periodic approximation for $x^2/2$ with period $2\pi$.

    The approximation is a polynomial in ``cos(x)``, for example::

        order == 0: 1-cos(x)
        order == 2: cos(x)**2/6 - 4*cos(x)/3 + 7/6
        ...

    Parameters
    ----------
    x : float or array_like
        Point(s) at which to evaluate.
    order : int
        One of 0, 2, 4 or 6.

    Raises
    ------
    NotImplementedError
        If `order` is not one of the supported values.
    """
    # Polynomial coefficients in cos(x), highest power first.
    coefficient_table = (
        (0, [-1.0, 1.0]),
        (2, [1.0 / 6, -4.0 / 3, 7.0 / 6]),
        (4, [-2.0 / 45, 3.0 / 10, -22.0 / 15, 109.0 / 90]),
        (6, [1.0 / 70, -32.0 / 315, 27.0 / 70, -32.0 / 21, 386.0 / 315]),
    )
    for known_order, coeffs in coefficient_table:
        if order == known_order:
            return np.polyval(coeffs, np.cos(x))
    raise NotImplementedError("Got order={}. Must be 0, 2, 4, or 6.".format(order))
def x_periodic(x, x0=0.8, p=3):
    """Return a new set of abscissa `xp` such that `V(xp)` will be smooth and
    periodic on the interval [-1, 1] if `V(x) = V(-x)` and `V(x)` is smooth.

    Parameters
    ----------
    x : array_like
        Original abscissa.
    x0 : float
        Parameter that affects the smoothness of the transition.  The
        resulting abscissa will range from [-x0, x0], so this will determine
        the magnitude of `V(x0)` at the boundaries.
    p : int
        Parameter affecting the smoothness.  Should be odd (1 or 3).

    Examples
    --------
    >>> x_periodic([-1.1, -1, 0, 1], x0=0.8)
    array([-0.88, -0.8 ,  0.  ,  0.8 ])
    """
    a = 1.0 / x0**p
    abs_x = np.abs(x)

    # Mapping used strictly inside the unit interval.
    squashed = np.tanh(2 * a / np.pi * np.tan(np.pi * abs_x**p / 2.0)) / a
    inside = np.maximum(0, squashed) ** (1.0 / p)

    # Outside (and on the edge of) the unit interval the map is linear.
    outside = abs_x * x0

    return np.sign(x) * np.where(abs_x < 1, inside, outside)
def good_Ns(Nmax=2**15):
    """Return a sorted array of good N's for the FFT (powers of 2, 3, and 5).

    Every integer ``2 <= N <= Nmax`` of the form ``2**a * 3**b * 5**c`` is
    returned (1 is excluded).

    Parameters
    ----------
    Nmax : int
        Largest size to include.

    Returns
    -------
    numpy.ndarray
        Increasing array of integers (empty if ``Nmax < 2``).
    """
    # Enumerate the products directly with Python integers and stop as soon
    # as a factor exceeds Nmax.  Forming the full outer product of all powers
    # overflows fixed-width integers for Nmax >~ 2**21, and the wrapped
    # (negative) values would pass a ``<= Nmax`` filter.
    Ns = []
    n2 = 1
    while n2 <= Nmax:
        n23 = n2
        while n23 <= Nmax:
            n = n23
            while n <= Nmax:
                Ns.append(n)
                n *= 5
            n23 *= 3
        n2 *= 2
    Ns.sort()
    # The smallest entry (if any) is 1 = 2**0 * 3**0 * 5**0: drop it.
    return np.array(Ns[1:], dtype=np.int64)
def get_good_N(N):
    """Get the lowest good size N greater than or equal to N for the FFT.

    Candidates are taken from ``good_Ns(2 * N)``.

    Examples
    --------
    >>> get_good_N(600)
    600
    >>> get_good_N(601)
    625
    """
    candidates = good_Ns(2 * N)
    # The candidates are sorted, so the first one that is large enough wins.
    large_enough = candidates[candidates >= N]
    return int(large_enough[0])
def mem_str(bytes):
    """Return the memory usage in nice units.

    Parameters
    ----------
    bytes : int or float
        Number of bytes (truncated to an integer).

    Returns
    -------
    str
        The size with a binary (powers of 1024) unit suffix.

    Examples
    --------
    >>> [str(mem_str(_d)) for _d in
    ...  [1.0, 1000, 2000, 2*1024**2, 3.1*1024**3,
    ...   4*1024**4, 5*1024**5, 6*1024**6]]
    ['1B', '1000B', '1.95kB', '2MB', '3.1GB', '4TB', '5PB', '6EB']
    >>> print(mem_str(123913))
    121kB
    >>> print(mem_str(0))
    0B
    """
    bytes = int(bytes)
    # floor(log_1024(|bytes|)) computed exactly from the bit length.  This
    # avoids log(0) for empty sizes and floating-point round-off at exact
    # powers of 1024.
    power = min(max(abs(bytes).bit_length() - 1, 0) // 10, 6)
    unit = 1024**power
    if bytes % unit == 0:
        mem = str(bytes // unit)
    else:
        mem = "{:.3g}".format(bytes / float(unit))
    return "{}{}B".format(mem, ["", "k", "M", "G", "T", "P", "E"][power])


def hex_mantissa(x):
    """Return the hex mantissa for x.

    Used to ensure that parameters have an exact representation in floating
    point.  The sign of `x` is ignored.

    Examples
    --------
    >>> print(hex_mantissa(0.1))
    1.999999999999a
    >>> print(hex_mantissa(0.25))
    1.
    >>> print(hex_mantissa(-0.25))
    1.
    """
    # float.hex() gives e.g. '0x1.999999999999ap-4' ('-0x...' for negative
    # numbers), so take the magnitude first to keep the '0x' prefix at a
    # fixed position before slicing it off.
    mantissa = abs(float(x)).hex()[2:].split("p")[0].rstrip("0")
    return mantissa
def get_smooth_transition(fs, durations, transitions, alphas=None):
    """Return a C(inf) smooth transition as a function of t.

    Starting from `t=0`, the returned function holds `fs[0]` for
    `durations[0]`, moves smoothly to `fs[1]` over `transitions[0]`, holds
    `fs[1]` for `durations[1]`, and so on.  The last value is held
    indefinitely.

    Arguments
    ---------
    fs : [float]
        List of `N` function values.
    durations : [float]
        List of (at least) `N-1` hold durations, one for each value except
        the last.  Extra entries are ignored.
    transitions : [float]
        List of `N-1` transition durations.
    alphas : [float]
        List of `N-1` alpha values (see `step()`) for each transition.
        Defaults to 1.0 for every transition.

    Returns
    -------
    callable
        Vectorized function `f(t)`.

    Raises
    ------
    ValueError
        If `fs` is empty or the lengths of the arguments are inconsistent.
    """
    fs = np.asarray(fs)
    if np.issubdtype(fs.dtype, np.integer):
        # Fix issue #9.  If fs is an integer type, then vectorize might
        # allocate an integer output array.  Here we make sure it is at least
        # a float... but don't convert complex arrays to floats by mistake!
        fs = fs.astype(float)

    N = len(fs)
    if N < 1:
        raise ValueError("fs must contain at least one value.")
    if alphas is None:
        alphas = [1.0] * (N - 1)

    # Explicit checks rather than asserts so that they survive `python -O`.
    if len(durations) < N - 1:
        raise ValueError(
            "Need at least {} durations, got {}.".format(N - 1, len(durations))
        )
    if len(alphas) != N - 1:
        raise ValueError("Need {} alphas, got {}.".format(N - 1, len(alphas)))
    if len(transitions) != N - 1:
        raise ValueError(
            "Need {} transitions, got {}.".format(N - 1, len(transitions))
        )
    durations = durations[: N - 1]

    # Interleave hold and transition lengths, then accumulate so that ts[i]
    # is the time at which the i-th interval ends.  Even intervals are
    # holds, odd intervals are transitions.
    ts = np.empty(2 * N - 2)
    ts[::2] = durations
    ts[1::2] = transitions
    ts = np.cumsum(ts)

    def smooth_transition(t):
        if ts.size == 0 or t <= ts[0]:
            # Single value, or still inside the first hold.
            return fs[0]
        if t >= ts[-1]:
            # Past the final transition.
            return fs[-1]

        # Index of the interval containing t (first end time beyond t).
        i1 = np.where(t < ts)[0][0]
        ind = i1 // 2
        if i1 % 2 == 0:
            # Hold interval
            return fs[ind]

        # Transition interval, which started at ts[i1 - 1].
        f0, f1 = fs[ind : ind + 2]
        s = step(t - ts[i1 - 1], transitions[ind], alpha=alphas[ind])
        return f1 * s + f0 * (1 - s)

    # Call form rather than decorator form: the keyword-only decorator usage
    # of np.vectorize is not accepted by older NumPy releases.
    return np.vectorize(smooth_transition, otypes=[fs.dtype])
def evolve_to(
    state, t, Evolver=EvolverABM, dt_t_scale=0.1, callback=None, plot_steps=100
):
    """Evolve state to time t.

    The time step is `dt_t_scale * state.t_scale`, reduced slightly if needed
    so that a whole number of steps covers `t - state.t`.

    Arguments
    ---------
    state : IState
        State to evolve.
    t : float
        Time to evolve to.
    Evolver : IEvolver
        Evolver to use.
    dt_t_scale : float
        Size of time-step in units of state.t_scale.
    callback : None, function
        If provided, then call callback(state) every plot_steps steps.
    plot_steps : int
        Steps to take between plots.

    Returns
    -------
    The evolver's final state `evolver.y`.
    """
    dt = dt_t_scale * state.t_scale
    t_max = t - state.t
    steps = int(np.ceil(t_max / dt))
    dt = t_max / steps
    evolver = Evolver(state, dt=dt)

    if callback is None:
        evolver.evolve(steps)
        return evolver.y

    # Chunks of plot_steps followed by the remainder.  This must be a list
    # (not a tuple) because the last two entries may be adjusted below.
    # Note that the remainder chunk is present even when it is 0 steps long.
    stepss = [plot_steps] * (steps // plot_steps) + [steps % plot_steps]
    if stepss[-1] == 1 and len(stepss) > 1:
        # Avoid a final chunk of a single step by borrowing one step from
        # the previous chunk.
        stepss[-2] -= 1
        stepss[-1] += 1
    for _steps in stepss:
        evolver.evolve(_steps)
        callback(evolver.get_y())
    return evolver.y
def evolve(
    state=None,
    history=None,
    Evolver=EvolverABM,
    t_max=np.inf,
    dt_t_scale=0.1,
    steps=100,
    display=True,
):
    """Iterator that runs an evolver.

    Parameters
    ----------
    state : IState
        State to evolve.  If this supports `len()`, it is treated as a
        history instead (see below).
    history : [State]
        List of states.  If provided (and `state` is not), then use the last
        state to start.  This is mutated.
    Evolver : IEvolver
        Evolver to use.
    t_max : float
        Maximum time to evolve to.  (Default - evolve until interrupted.)
    dt_t_scale : float
        Size of time-step in units of state.t_scale.
    steps : int
        Steps to evolve between yielded states.
    display : bool
        If True, then display the current figure (using plt.gcf()).

    Yields
    ------
    The initial state, then the state after each batch of steps.

    Example::

        for y in evolve(s):
            plt.clf()
            y.plot()
    """
    if state is None:
        state = history[-1]
    else:
        try:
            len(state)
        except TypeError:
            # A single state: start a fresh history.
            history = [state]
        else:
            # A sequence was passed as `state`: use it as the history.
            history = state
            state = history[-1]

    dt = dt_t_scale * state.t_scale
    if t_max < np.inf:
        # Shrink dt so that t_max is a whole number of steps.
        steps_ = int(np.ceil(t_max / dt))
        dt = t_max / steps_
    steps_ = steps

    # Tolerance (in steps) when counting the steps left before t_max.
    # Without it, round-off such as (1.0 - 0.9)/0.1 == 0.9999999999999998
    # truncates to 0 and the evolution stops one step short of t_max.
    step_tol = 1e-8

    evolver = Evolver(state, dt=dt)
    yield evolver.y
    with NoInterrupt(ignore=True) as interrupted:
        while evolver.t < t_max and not interrupted:
            if t_max < np.inf:
                steps_ = min(steps, int((t_max - evolver.t) / dt + step_tol))
            evolver.evolve(steps_)
            history.append(evolver.get_y())
            yield history[-1]
            if steps_ < steps:
                # Short final batch: t_max has been reached.
                break
            if display:
                from matplotlib import pyplot as plt
                import IPython.display

                IPython.display.display(plt.gcf())
                IPython.display.clear_output(wait=True)
def evolves(
    states=None,
    histories=None,
    Evolver=EvolverABM,
    t_max=np.inf,
    dt_t_scale=0.1,
    steps=100,
    fig=None,
    display=True,
):
    """Iterator that runs an evolver like evolve() but with multiple states.

    Parameters
    ----------
    states : IState
        List of states to evolve.  The first of these is used as the
        reference state for units such as t_scale etc.
    histories : [[State]]
        List of histories: each is a list of states.  If provided (and
        `states` is not), then use the last states to start.  These lists
        are mutated.
    Evolver : IEvolver
        Evolver to use.
    t_max : float
        Maximum time to evolve to.  (Default - evolve until interrupted.)
    dt_t_scale : float
        Size of time-step in units of state.t_scale.
    steps : int
        Steps to evolve between yielded states.
    fig : Figure, None
        Figure to display.  If None, `plt.gcf()` is used.
    display : bool
        If True, then display the current figure (using fig or plt.gcf()).

    Yields
    ------
    A list with the initial states, then a list with the states after each
    batch of steps.
    """
    if states is None:
        states = [_history[-1] for _history in histories]
    else:
        histories = [[_state] for _state in states]

    state = states[0]
    dt = dt_t_scale * state.t_scale
    if t_max < np.inf:
        # Shrink dt so that t_max is a whole number of steps.
        steps_ = int(np.ceil(t_max / dt))
        dt = t_max / steps_
    steps_ = steps

    # Tolerance (in steps) when counting the steps left before t_max.
    # Without it, round-off such as (1.0 - 0.9)/0.1 == 0.9999999999999998
    # truncates to 0 and the evolution stops one step short of t_max.
    step_tol = 1e-8

    evolvers = [Evolver(_s, dt=dt) for _s in states]
    evolver = evolvers[0]  # Reference evolver used for timing.
    yield [_e.y for _e in evolvers]
    with NoInterrupt(ignore=True) as interrupted:
        while evolver.t < t_max and not interrupted:
            if t_max < np.inf:
                steps_ = min(steps, int((t_max - evolver.t) / dt + step_tol))
            for _e in evolvers:
                _e.evolve(steps_)
            for _e, _h in zip(evolvers, histories):
                _h.append(_e.get_y())
            yield [_h[-1] for _h in histories]
            if steps_ < steps:
                # Short final batch: t_max has been reached.
                break
            if display:
                from matplotlib import pyplot as plt
                import IPython.display

                if fig is None:
                    fig = plt.gcf()
                IPython.display.display(fig)
                IPython.display.clear_output(wait=True)
class Frames:
    """Represents a series of frames (arrays) on disk for checkpointing and
    making movies.

    Each frame is stored in a file with a key that is usually the time at
    which the frame is valid.  The main interface is through item access,
    i.e.::

        state.set_data(frames[key])
        with frames as frames:
            frames[key] = state.get_data()

    Checkpoints should be created explicitly::

        with frames as frames:
            frames.checkpoint(key, state.get_data())

    These will be treated as regular frames, but will be deleted when a new
    checkpoint is created.

    Context
    =======
    Frames instances should generally be used as a context if writing to
    disk.  This will suspend immediate mode until the end of the context,
    improving performance.

    Key Conversions
    ===============
    To ensure safe comparisons between frame keys, we convert to and from an
    ikey with the methods `key_to_ikey()` and `ikey_to_key()`.  These should
    perform appropriate manipulations like rounding so that equality
    comparison between keys is meaningful.  (Direct comparison of floating
    point values is dangerous since round-off error could cause a key failure
    between keys that are practically the same, but obtained by slightly
    different orders of operations.)

    key :
        This is what the user supplies as a key.  In the default
        implementation this is either a floating point number or a tuple of
        floating point numbers.
    ikey :
        This is the internal representation of the key, used for comparison,
        indexing, etc.  In the default implementation, we use the Decimal()
        class to truncate and provide an exact representation of the floating
        point number.

    Finally, the ikey needs to be converted to a string for use in the
    filenames.  These conversions are done by the ikey_to_str() and
    str_to_ikey() methods.  All four conversion methods should be redefined
    if a different type of key is used.

    The default implementation produces filenames such as::

        frame_0.1000_image_0.0500.npy

    which would represent a frame evolved after 0.1 time units and then
    imaged after an additional 0.05 time units of expansion.

    Attributes
    ----------
    data_dir : str
        The frames will be stored in this directory.
    mode : str containing 'r', 'w', 'm'
        Read ('r'), write ('w'), or in-memory ('m').  Data not written to
        disk unless 'w' in mode.  Setting a frame without 'w' issues a
        warning that the data will not be stored on disk.  This can be
        suppressed by setting the mode to 'm' which indicates in-memory mode.
        Setting `mode='m'` will not access the disk at all (even read access
        will be disabled).  The default `mode=''` is 'w' in a context and 'r'
        outside a context.
    prefix : str
        Frame filenames start with this.
    checkpoint_prefix : str
        Checkpoint filenames start with this.
    sep : str
        When the key is iterable (as defined by `self.isiterable()`), then
        keys are joined by this separator.  The default is `"_image_"`.
    immediate : bool, None
        If `True`, then frames will be saved to disk upon assignment,
        otherwise they will be saved upon `flush()` (also called at the end
        of a context.)  The default (`None`) is False if used in a context,
        but True otherwise.
    checkpoints_to_retain : int
        When saving a checkpoint, keep this many previous checkpoints and
        delete the rest.
    mem_limit_bytes : None (default) or int
        If not `None`, limit in bytes imposed by `self.limit_memory()` on the
        total size the frames may take up in memory.  `limit_memory` is
        called upon exiting a context, whenever a frame is saved immediately
        or checkpointed, and (if `self.immediate`) whenever a frame is
        loaded.  Otherwise, `limit_memory` must be invoked manually.
    decimal_precision : int
        This is a specialized argument for the default version of the class.
        Internally, we use the Decimal class for keys, normalized to this
        precision.  If precision is lost (according to `np.allclose`), then a
        ValueError is raised.  When loading data from a file, this may be
        increased to prevent loss of precision.

    Notes
    -----
    * If `del` is called, the underlying file will be removed immediately,
      even if the immediate flag is False.  (Nothing on disk is touched in
      in-memory mode 'm'.)
    * The list of `keys()` will only be updated at the start of a context,
      the first time it is called, or after `flush()`.
    * One can only set an item if it is not already existing.  (Call del
      first if needed.)
    * Files whose names cannot be parsed as keys are ignored (with a
      warning).
    """

    def __init__(
        self,
        data_dir,
        mode="",
        prefix="frame_",
        sep="_image_",
        checkpoint_prefix="check_",
        checkpoints_to_retain=1,
        immediate=None,
        mem_limit_bytes=None,
        decimal_precision=4,
    ):
        self.data_dir = data_dir
        self.prefix = prefix
        self.checkpoint_prefix = checkpoint_prefix
        self.checkpoints_to_retain = checkpoints_to_retain
        self.sep = sep
        self.mode = mode
        self.immediate = immediate
        self.mem_limit_bytes = mem_limit_bytes
        self.decimal_precision = decimal_precision
        self._data = {}  # In-memory frames: {ikey: array}.
        self._context = False  # True while inside a `with` block.

    @property
    def mode(self):
        """Return the mode, making context corrections for defaults."""
        mode = self._mode
        if mode == "":
            mode = "w" if self._context else "r"
        return mode

    @mode.setter
    def mode(self, mode):
        self._mode = mode.lower()

    def __enter__(self):
        """Enter the context."""
        if self._context:
            raise NotImplementedError("Nested contexts not supported.")
        self._context = True
        if set("wr").intersection(self.mode):
            self._get_data_dir()
        self.keys(update=True)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the context, flushing to disk if in a write mode."""
        try:
            with NoInterrupt(ignore=False):
                if "w" in self.mode:
                    if hasattr(self, "_issue_2"):
                        # Hook that sends SIGINT to this process during the
                        # flush.  NOTE(review): presumably set only by the
                        # test for issue 2 -- confirm.
                        import signal

                        os.kill(os.getpid(), signal.SIGINT)
                    self.flush()
                self.limit_memory()
        finally:
            self._context = False

    ######################################################################
    # Customizations
    #
    # These methods perform various conversions between key, ikey, and
    # the filename str.  Subclasses using keys that are not floats
    # should redefine all of these methods to be consistent.
    def isiterable(self, key):
        """Return `True` if the key represents a list or tuple of subkeys."""
        # `str` is what `six.string_types` amounts to on Python 3, which this
        # module already requires (f-strings).
        return isinstance(key, abc.Iterable) and not isinstance(key, str)

    def Decimal(self, key, increase_precision=False):
        """Return `Decimal(key)` rounded to `self.decimal_precision`.

        Arguments
        ---------
        key : number or str
            Value to convert.
        increase_precision : bool
            If `True`, then increase as needed `self.decimal_precision` to
            ensure accuracy of the keys, otherwise, raise `ValueError` if
            precision is lost.

        Raises
        ------
        ValueError
            If rounding changes the value and `increase_precision` is False.
        """
        # decimal.Decimal does not accept numpy scalars directly.
        if isinstance(key, np.integer):
            key = int(key)
        elif isinstance(key, np.floating):
            key = float(key)

        t = decimal.Decimal(key)
        if increase_precision:
            self.decimal_precision = max(self.decimal_precision, -t.as_tuple().exponent)
        else:
            t = t.quantize(decimal.Decimal(10) ** (-self.decimal_precision))
            if not np.allclose(float(t), float(key), rtol=1e-12, atol=1e-12):
                raise ValueError(
                    "Precision lost when converting key: {}->{}".format(
                        repr(key), repr(t)
                    )
                )
        return t

    def key_to_ikey(self, key):
        """Convert user-supplied key to internal format (a tuple)."""
        if self.isiterable(key):
            return tuple(self.Decimal(_k) for _k in key)
        return (self.Decimal(key),)

    def ikey_to_key(self, ikey):
        """Convert internal key format to user format."""
        if ikey is None:
            return ()
        if len(ikey) == 1:
            return float(ikey[0])
        return tuple(map(float, ikey))

    def ikey_to_str(self, ikey):
        """Convert internal key to string for filename."""
        return self.sep.join("{}".format(_k) for _k in ikey)

    def str_to_ikey(self, key_string):
        """Convert string to internal key, increasing `decimal_precision` if
        required.
        """
        return tuple(
            self.Decimal(_k, increase_precision=True)
            for _k in key_string.split(self.sep)
        )

    # End of customizable methods
    ######################################################################
    def __contains__(self, key):
        ikey = self.key_to_ikey(key)
        return ikey in self._files or ikey in self._data

    def ikeys(self, update=False):
        """Return a sorted list of available ikeys.

        Arguments
        ---------
        update : bool
            If True, then reset the list of previously loaded files, and
            reread data from disk.  If False, then only previously detected
            frames and computed frames will be seen - any additional frames
            (i.e. saved by another process) will not be detected.
        """
        if update:
            self.__dict__.pop("_files", None)
        return sorted(set(self._files).union(self._data))

    def keys(self, update=False):
        """Return a sorted list of available keys.

        Arguments
        ---------
        update : bool
            If True, then reset the list of previously loaded files, and
            reread data from disk.  If False, then only previously detected
            frames and computed frames will be seen - any additional frames
            (i.e. saved by another process) will not be detected.
        """
        return [self.ikey_to_key(_ikey) for _ikey in self.ikeys(update=update)]

    def __getitem__(self, key):
        ikey = self.key_to_ikey(key)
        if ikey not in self._data and ikey in self._files:
            filename = self._files[ikey]
            if not os.path.exists(filename):
                # Join with a real newline (not the raw two-character r"\n").
                raise LookupError(
                    "\n".join(
                        [
                            "File {} for key={} has gone missing!",
                            "Do you need to run git annex?",
                            "",
                            "    git annex get {}",
                        ]
                    ).format(filename, key, os.path.dirname(filename))
                )
            self._data[ikey] = np.load(filename)
        value = self._data[ikey]
        if self.immediate:
            self.limit_memory()
        return value

    def checkpoint(self, key, value):
        """Save data to file, but as a checkpoint."""
        self.__setitem__(key, value, checkpoint=True)

    def _is_checkpoint(self, ikey):
        """Return True if key is saved as a checkpoint."""
        files = self._files
        return ikey in files and (
            os.path.basename(files[ikey]).startswith(self.checkpoint_prefix)
        )

    def _convert_checkpoint(self, ikey):
        """Convert a saved checkpoint to a real frame.

        Return True if the checkpoint file was converted.
        """
        if not self._is_checkpoint(ikey=ikey):
            return False
        filename = self._get_filename(ikey=ikey, data_dir=self.data_dir, checkpoint=False)
        os.rename(self._files[ikey], filename)
        self._files[ikey] = filename
        return True

    def __setitem__(self, key, value, checkpoint=False):
        ikey = self.key_to_ikey(key)
        mode = self.mode
        if not set("wm").intersection(mode):
            warnings.warn(f"Setting Frame[{key}] in {mode=} will not get saved to disk!")

        converted = self._convert_checkpoint(ikey=ikey)
        if converted:
            # Special case of key already being a checkpoint file: it has
            # just been renamed to a regular frame file.  The data is assumed
            # to be what was checkpointed; if it is not in memory (e.g. the
            # checkpoint was written by an earlier run) it is loaded from the
            # file on first access.
            pass
        elif ikey in self._data:
            raise LookupError(
                "Data for key={} already set. (Call del first.)".format(key)
            )
        elif ikey in self._files:
            raise LookupError(
                "File for key={} already exists. (Call del first.)".format(key)
            )
        else:
            self._data[ikey] = np.asarray(value)

        # REV: Make sure all cases here are tested!
        immediate = self.immediate or (self.immediate is None and not self._context)
        if immediate or checkpoint:
            # A promoted checkpoint is already on disk under the frame name,
            # so saving it again would fail with "already exists".
            already_saved = converted and not checkpoint
            if "w" in mode and not already_saved:
                self._save(ikey, value, checkpoint=checkpoint)
            self.limit_memory()

    def __delitem__(self, key, ikey=None):
        if ikey is None:
            ikey = self.key_to_ikey(key)
        files = self._files
        if ikey not in files and ikey not in self._data:
            raise LookupError(key)
        if "m" not in self.mode:
            # In-memory mode must never touch the disk.
            filename = files.get(
                ikey, self._get_filename(ikey=ikey, data_dir=self.data_dir)
            )
            if filename and os.path.exists(filename):
                os.remove(filename)
        files.pop(ikey, None)
        self._data.pop(ikey, None)

    def flush(self):
        """Write current data to disk.

        Raises
        ------
        ValueError
            If 'w' is not in the current mode.
        """
        mode = self.mode
        if "w" not in mode:
            raise ValueError(f"Cannot flush data in {mode=}")
        data_dir = self._get_data_dir()
        self._remove_checkpoints()
        for ikey, value in list(self._data.items()):
            files = self._files
            if ikey in files and os.path.exists(files[ikey]):
                continue  # Already on disk.
            self._save(ikey, value=value, data_dir=data_dir)
        self.keys(update=True)

    def get_mem_size(self):
        """Return the size of the frames stored in memory in bytes.

        Note that this only calls `sys.getsizeof()` on each value of
        `self._data`; the actual size of a Frames instance will be slightly
        bigger.
        """
        # Keyed by id() so frames that are the same object count only once.
        unique = {id(_val): _val for _val in self._data.values()}
        return sum(sys.getsizeof(_val) for _val in unique.values())

    def limit_memory(self, update_files=True):
        """Drop the oldest frames from memory to fit under the memory limit.

        Arguments
        ---------
        update_files : bool
            If True (and the limit is exceeded), refresh the list of files
            on disk before deciding what can be dropped.
            NOTE(review): refreshing also discards in-memory frames that
            have no file on disk (see `_files`) -- confirm this is intended.

        Raises
        ------
        MemoryError
            If the oldest frame in memory is not saved on disk.
        """
        if self.mem_limit_bytes is None:
            return
        refreshed = False
        while self.get_mem_size() > self.mem_limit_bytes:
            if update_files and not refreshed:
                # Refresh list of files on disk (once is enough).
                self.__dict__.pop("_files", None)
                refreshed = True
            # Since python 3.7, dict_keys are ordered according to when they
            # were added, so we just need the first key to get the oldest
            # frame.
            oldest_ikey = next(iter(self._data))
            if oldest_ikey not in self._files:
                key = self.ikey_to_key(oldest_ikey)
                raise MemoryError(
                    f"`mem_limit_bytes` (={self.mem_limit_bytes}) exceeded, and Frame[{key}] cannot be dropped because it is not saved on disk"
                )
            self._data.pop(oldest_ikey)

    ######################################################################
    # Private methods
    @property
    def _prefixes(self):
        """Return a tuple of the possible filename prefixes."""
        return (self.prefix, self.checkpoint_prefix)

    def _filename_to_ikey(self, filename):
        """Extract the time key from the filename.

        Returns `()` (with a warning) if the name is malformed.

        Raises
        ------
        ValueError
            If the filename starts with none of the known prefixes.
        """
        if not filename:
            return ()

        framename = os.path.basename(filename)
        key_string = None
        for prefix in self._prefixes:
            if framename.startswith(prefix):
                # Strip the prefix and the ".npy" extension.
                key_string = framename[len(prefix) : -4]
                break
        if key_string is None:
            raise ValueError(f"Unknown file prefix for frame {framename}")

        try:
            ikey = self.str_to_ikey(key_string=key_string)
        except (decimal.InvalidOperation, ValueError, TypeError):
            # TypeError covers names such as "nan"/"inf", which parse as
            # Decimals but have no numeric exponent.
            ikey = ()
            warnings.warn(
                "Skipping malformed frame name {} (should be {}).".format(
                    framename, prefix + "###.####.npy"
                )
            )
        return ikey

    @property
    def _files(self):
        """Dictionary `{ikey: filename}` of the frames found on disk.

        The directory listing is cached in `self.__dict__["_files"]` until
        that entry is removed.  When the cache is (re)built, frames held in
        memory that have no file on disk are deleted.
        """
        if "m" in self.mode:
            return {}  # Only in-memory use, so no files.

        if "_files" not in self.__dict__:
            data_dir = self._get_data_dir()
            files = {}
            for _prefix in self._prefixes:
                for _f in glob.glob(os.path.join(data_dir, _prefix + "*.npy")):
                    ikey = self._filename_to_ikey(filename=_f)
                    if ikey:  # Malformed names give () and are skipped.
                        files[ikey] = _f
            # Store before purging: __delitem__ consults self._files.
            self.__dict__["_files"] = files
            for _ikey in list(self._data):
                if _ikey not in files:
                    self.__delitem__(key=None, ikey=_ikey)
        return self.__dict__["_files"]

    def _get_data_dir(self):
        """Return `data_dir`, checking that it exists, is a directory, and
        making it if needed and 'w' in mode.
        """
        mode = self.mode
        if mode == "m":
            raise ValueError(f"No get data_dir in {mode=}.")
        data_dir = self.data_dir
        if not os.path.exists(data_dir):
            if "w" in mode:
                os.makedirs(data_dir)
            else:
                raise ValueError(
                    f"Cannot read ({mode=}) since directory {data_dir=} does not exist."
                )
        elif not os.path.isdir(data_dir):
            raise IOError(f"Specified {data_dir=} is not a directory.")
        return data_dir

    def get_filename(self, key, checkpoint=False):
        """Return the filename associated with key."""
        return self._get_filename(
            ikey=self.key_to_ikey(key=key),
            data_dir=self.data_dir,
            checkpoint=checkpoint,
        )

    def _get_filename(self, ikey, data_dir, checkpoint=False):
        """Return the filename associated with ikey."""
        prefix = self.checkpoint_prefix if checkpoint else self.prefix
        name = prefix + self.ikey_to_str(ikey=ikey)
        return os.path.join(data_dir, name + ".npy")

    def _save(self, ikey, value, data_dir=None, checkpoint=False):
        """Write `value` to the file for `ikey` and record the filename.

        Raises
        ------
        IOError
            If the file already exists.
        """
        if data_dir is None:
            data_dir = self._get_data_dir()
        filename = self._get_filename(ikey=ikey, data_dir=data_dir, checkpoint=checkpoint)
        if os.path.exists(filename):
            raise IOError("File {} already exists!".format(filename))

        if checkpoint:
            # Remove old checkpoints, leaving room for the new one.
            self._remove_checkpoints(checkpoints_to_retain=self.checkpoints_to_retain - 1)
        np.save(filename, value)
        self._files[ikey] = filename

    def _remove_checkpoints(self, checkpoints_to_retain=None):
        """Remove old checkpoints in key order, keeping the last
        `checkpoints_to_retain` (default `self.checkpoints_to_retain`).
        """
        if checkpoints_to_retain is None:
            checkpoints_to_retain = self.checkpoints_to_retain
        checkpoints = sorted(ikey for ikey in self._files if self._is_checkpoint(ikey))
        # Clamp at zero: a negative slice bound would count from the end and
        # delete checkpoints when fewer exist than should be retained.
        n_remove = max(0, len(checkpoints) - checkpoints_to_retain)
        for ikey in checkpoints[:n_remove]:
            self.__delitem__(key=None, ikey=ikey)
@implementer(IExperiment) class ExperimentBase(Archivable): """Base for Experiment classes. Inherit from this class to provide an interface to the problem. It's main role should be to accept experimentally relevant parameters, and then produce an appropriate initial state that representing the experimental protocol. See `gpe.utils.ExperimentExample` for a demonstration of how to use this class. Note: All times are expressed in terms of `t_unit` such that `t_ = t/t_unit`. We try to consistently use the name `t_` for such dimensionless quantities except in class variables which are assumed to be in the specified `t_unit`. 1. Also provides a mechanism for recording time-dependent parameters. These should be provided through a set of methods names `*_t_()` which take the normalized time `t_` as an argument and return the time-dependent value of this parameter. For plotting purposes an accompanying method `*_info()` should be defined which returns the corresponding unit value and a label. 2. Internal methods should use the `get(param, t_)` method which will delegate to the appropriate time-dependent function if it exists (or fall back to the basic parameter access). Simulations and Imaging ======================= The idea of an "Experiment" is some sort of simulation run defined by a set of parameters (attributes of this class) that is evolved through a set of `image_ts_` under a set of "normal" experimental conditions. These states would be what is observed in "in situ imaging". Typically, however, from these states, one evolves for an additional time `t__image` without any interactions or traps to allow the clouds to "expand", after which an "expansion image" is taken, usually resolving better details like vortices and domain walls. """ State = NotImplemented # State class used t_unit = 1.0 # Conversion factor for time units t_name = "" # Optional name image_ts_ = () # Times at which experiment starts imaging. t__image = 0 # Length of expansion imaging. 
_initializing = False @property def t__final(self): """Return `t__final = t_final/t_unit`.""" if not self._initializing: # We do some inspection when initializing, so don't warn then. warnings.warn( "t__final is deprecated: please use image_ts_ or state.t_final", DeprecationWarning, stacklevel=2, ) if len(self.image_ts_) > 0: return max(self.image_ts_) return None @t__final.setter def t__final(self, t__final): warnings.warn( "experiment.t__final is deprecated: please use state.t_final", DeprecationWarning, stacklevel=2, ) self.image_ts_ = tuple(sorted(set(self.image_ts_).union({t__final}))) # These control dir_name. max_key_length = 3 # Limit key lengths (if not directory_per_key) sort_keys = True # Sort keys for unique dir_name (issue #4). # only set False to emulate old behavior. key_order = () # If provided, these keys come first. directory_per_key = True # Use a directory for each key (shorter filenames) # For debugging purposes we keep track of which keys are used so that # spelling mistakes can be checked for (see `_unused_keys()`). Here we # pre-populate this set with some keys that are only used on saving etc. _used_keys = ["State", "t_name", "t_unit"] # We also have a set of special keys that should not be used in # dir_name. _special_keys = set( [ "image_ts_", "t__image", "max_key_length", "directory_per_key", "sort_keys", "key_order", ] ) _deprecated_keys = set(["t__final"]) # Deprecated keys are only allowed to be defined in the class with this # full name (currently `'gpe.utils.ExperimentBase'`) _base_class_name = __module__ + "." + __qualname__ def __init__(self, _local_dict=None, **kw): """Constructor. The constructor takes a dictionary of the local variables passed to the subclass. These will be assigned to variables of the same name in `__dict__` and stored to generate the `dir_name` where data will be stored. Names starting with `_` will be ignored. The `kw` argument is provided to allow subclasses to pass in additional parameters. 
Note: We generally recommend that users DO NOT overload the constructor for several reasons. Although one can use this approach to define additional parameters, this use is generally discouraged for the following reasons: 1. These will override any parameters defined at the class level - EVEN IF OVERRIDDEN IN SUBCLASSES. If you want subclasses to override these parameters, they MUST do so in their `__init__()` method. 2. These parameters will ALWAYS be included in the directory name, even if the default values are used. This behavior can be achived anyway by simply passing the default value to the default constructor. If you proceed with this approach, use the following model:: def __init__(self, a=0.0, **_kw): super().__init__(locals(), **_kw) Make sure you pass through `_kw` with an underscore so it does not get set as an attribute. The following approach of using `locals()` allows you to forgo repeating the keyword arguments you specify in the signature. Do not do any initialization here - only set parameters. Initialization should be performed in the `init()` method which will be called by `ExperimentBase.__init__()`. """ self._initializing = True # This is a list of all parameters known to the class. For debugging # purposes, we do not allow the user to set additional parameters in # the constrctor, which are usually "spelling" mistakes (see issue 1). # This is a little different from self._keys which are all the "active" # keys that are used in the file name etc. Only active keys can be set # after construction. _used_keys = set() for kls in inspect.getmro(self.__class__): _used_keys.update(getattr(kls, "_used_keys", [])) self._known_keys = sorted( [ _key for _key, _value in inspect.getmembers( self, lambda a: not (inspect.isroutine(a)) ) if not _key.startswith("_") and _key not in self._deprecated_keys ] ) # Set this *after* doing the inspect stuff because it uses all the keys! 
self._used_keys = _used_keys if _local_dict is None: args = dict(kw) else: args = dict(_local_dict, **kw) args = {_k: args[_k] for _k in args if (not _k.startswith("_") and _k != "self")} keys = set(args).union(self._special_keys) # Check if any child classes re-define a deprecated key family = inspect.getmro(self.__class__) family_names = [kls.__module__ + "." + kls.__qualname__ for kls in family] base_class = family[family_names.index(self._base_class_name)] children = [c for c in family if issubclass(c, base_class) and c != base_class] for child in children: redefined_deprecated_keys = set(child.__dict__).intersection( self._deprecated_keys ) if redefined_deprecated_keys: warnings.warn( ("{} redefines deprecated key(s) {}.").format( child.__name__, sorted(redefined_deprecated_keys), ), DeprecationWarning, stacklevel=2, ) # Check if any arguments are deprecated deprecated_keys = set(args).intersection(self._deprecated_keys) if deprecated_keys: warnings.warn( ("{} got deprecated keyword argument(s) {}.").format( self.__class__.__name__, sorted(deprecated_keys), ), DeprecationWarning, stacklevel=2, ) # Check that all arguments are indeed attributes unknown_keys = set(args).difference( set(self._known_keys).union(self._deprecated_keys) ) if unknown_keys: raise ValueError( ( "{} got unexpected keyword argument(s) {}.\n" + "(Known keys: {})" ).format(self.__class__.__name__, sorted(unknown_keys), self._known_keys) ) self._keys = sorted(keys) for _k in args: setattr(self, _k, args[_k]) # Placing this before setting _initializing=False allows # attributes to be set by self.init() self.init() self._initializing = False def _unused_keys(self): """Return set of unused keys (parameters). 
This is for debugging purposes to make sure that parameters are not misspelled.""" used_keys = set.union(self._used_keys, self._special_keys) return set(self._known_keys).difference(used_keys) def items(self): """Provides support for Archivable.""" return [(_k, getattr(self, _k)) for _k in self._keys] def get_diff(self, experiment): """Return a list of differing keys (excludes _special_keys).""" if experiment is self: return [] assert set(self._known_keys) == set(experiment._known_keys) set_keys = set(self._keys).union(experiment._keys).difference(self._special_keys) return [ key for key in set_keys if not getattr(self, key) == getattr(experiment, key) ] def init(self): """Overload this to perform any initial computations.""" # Save and restore _used_keys so inspect call does used # everything. _used_keys = set(self._used_keys) members = inspect.getmembers(self, predicate=callable) self._used_keys = _used_keys self.time_dependent_parameters = { _name[:-3]: _meth for (_name, _meth) in members if _name.endswith("_t_") } for name in self.time_dependent_parameters: # method_name = name + "_t_" info_name = name + "_info" if not hasattr(self, info_name): setattr(self, info_name, (1.0, name)) def __getattribute__(self, key): val = super(ExperimentBase, self).__getattribute__(key) if not key.startswith("_") and not (inspect.isroutine(val)): self.__dict__.setdefault("_used_keys", set()).add(key) return val def __setattr__(self, key, value): """Prohibit setting of non-key attributes.""" if key.startswith("_") or self._initializing or key in self._keys: self.__dict__[key] = value else: raise AttributeError("Cannot set {}: only {}".format(key, self._keys)) @property def dir_name(self): """Return the name of the directory in which to store checkpoint data. This name should encode all of the parameter values. 
The default version here uses a nested structure with the class name and then the parameters: `<Experiment>/<param1=...,param2=...,...>/` To ensure that the directories are unique, we first use the list param_order, then sort alphabetically. """ keys = set(self._keys).difference(self._special_keys) sorted_keys = [_key for _key in self.key_order if _key in keys] new_keys = keys.difference(sorted_keys) if self.sort_keys: new_keys = sorted(new_keys) sorted_keys.extend(new_keys) # These should be unique. keys = sorted_keys if not self.directory_per_key: short_keys = [] for k in keys: if len(k) > self.max_key_length: k = k[: self.max_key_length].strip(" -_") if k in short_keys: suffixes = ( "0123456780" + "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" ) short_k = k[:-1] for _s in suffixes: if short_k + _s not in short_keys: k = short_k + _s if k in short_keys: raise ValueError( "Could not find unique {}-letter name for parameter {}.".format( self.max_key_length, k ) + " (Collides with {})".format(short_keys) ) short_keys.append(k) else: short_keys = keys key_pairs = [ "{}={}".format(_k, getattr(self, _key)).replace( " ", "" ) # Remove inner spaces. for _k, _key in zip(short_keys, keys) ] dirs = [self.__class__.__name__] if self.directory_per_key: dirs.extend(key_pairs) else: dirs.append(",".join(key_pairs)) return os.path.join(*dirs) def copy(self): """Return a copy of the experiment.""" return self.__class__(**dict(self.items())) ###################################################################### # Parameter access def get(self, name, t_): """Return the value of the parameter `name` at time `t_`. If the method `name_t_()` exists, it is called, otherwise the attribute `name` is used. 
""" _meth = self.time_dependent_parameters.get(name) if _meth is None: return getattr(self, name) else: return _meth(t_) @implementer(IStateForABMEvolvers) class StateExample(StateMixin, ArrayStateMixin): """Minimial State class that implements the ODE `dy_dt = a*y + b`.""" @property def t_scale(self): """Return the natural time-scale for evolvers. This is used by simulations to determine the evolver timestep. For quantum problems, this is usually something like `hbar/E_max` where `E_max` is the largest energy-scale appearing (for example, `E_max = hbar**2*k_max**2/2/m` where `k_max = np.pi * Nx/Lx` is the largest momentum representable in the system. """ return 1.0 / self.a def __init__(self, a, b, experiment=None): self.a = a self.b = b self.experiment = experiment self.data = np.ones(1, dtype=float) def compute_dy_dt(self, dy): if self.experiment is not None: # Simulates an "imaging procedure". If t is less than t_final, # we use the specified value of b, otherwise for imaging we set # b=0. if self.t <= self.t_final: b = self.b else: b = 0.0 dy.set_data(self.a * self.get_data() + b) return dy def answer(self, t=None): """Return the analytic solution starting from `y(0) = 1`.""" if t is None: t = self.t if self.experiment is not None: t0 = min(t, self.t_final) else: t0 = t # y1 = (y0 + b/a)*exp(a*(t-t0)) - b/a a, b = self.a, self.b y0 = 1.0 y0 = (y0 + b / a) * np.exp(a * (t0 - 0.0)) - b / a # Now b = 0 y = y0 * np.exp(a * (t - t0)) return y def plot(self, fig): from matplotlib import pyplot as plt plt.plot(self.t, self[0], "o") return fig class ExperimentExample(ExperimentBase): """Skeleton class demonstrating how to use ExperimentBase. Fixed experimental parameters that won't change can be specified as class members. Specifying the parameters here is recommended and has several advantages over defining an `__init__()` method: 1. They will not be part of the filename unless manually overridden by user and passed the to constructor. 2. 
By inheriting from this class, the default parameter values can be overridden, essentially encapsulated in the new class name. 3. If all parameters are specified as class variables, there is no need to write the constructor `__init__()`. `ExperimentBase.__init__()` will allow the user to override parameters if needed. 4. If directory names are desired for these parameters, they can be explicitly passed to the default constructor and then the appropriate directories will be created. Note: If you need to perform any calculations for initialization, do this in the `init()` method which will be called by the default constructor `ExperimentBase.__init__()`. """ ###################################################################### # Attributes required by IExperiment t_unit = 1.2 # Specifies units for experiment times. image_ts_ = (10.0,) # Times for imaging (in units of t_unit). t__image = 1.5 # Imaging for 1.5 units # End of attributes required by IExperiment ###################################################################### # Experimental parameters amplitude = 1.2 offset = 3.4 # There is a debugging feature which keeps track of the keys (parameters) # used during the course of execution to make sure that parameters are not # mispelled (see issue 1). Sometimes parameters will only conditionally # used and you want to manually exclude it from the checks. You can do # this by adding it to your class's _used_keys. _used_keys = [] ###################################################################### # Methods required by IExperiment def init(self): """Initializes the experiment. All initialization should be done here, not in a constructor. """ # For example, one might translate the experimental parameters into # appropriate units, or perform conversions so that the State object # can be coded in the most natural set of units, but the experiment can # be specified in the most physically relevant units. 
# The use of a state_args dict would, for example, allow subclasses to # share `get_state()` and `get_initial_state()` while adding additional # parameters. However, overloading `get_state()` might be simpler. self.state_args = dict(a=self.amplitude, b=self.offset) # Don't forget to call the parent init(). You might want to do this # first if you need to use the results of the parent's init() method, # or after if you need to modify parameters before the parent's init() # is called. (`ExperimentBase.init()` does not do anything, but we # demonstrate here for good measure.) super().init() def get_state(self, initialize=True): """Quickly return a valid `State` object.""" return StateExample(experiment=self, **self.state_args) def get_initial_state(self, _E_tol=1e-12, _psi_tol=1e-12): """Return the valid `t=0` state to initialize the simulations.""" # It might be convenient while debugging to play with minimization # parameters here, but when running and checkpointing, the parameters # that work should be fixed here and this should be called only without # arguments. state = self.get_state() # Here is how you might go about preparing the initial state. Note: # rely on `get_state()` to set all parameters - just set the data after # minimizing. Do not use the state returned by the minimizer # # from minimize import MinimizeState # minimizer = MinimizeState(state, fix_N=True) # psi0 = minimizer.minimize(E_tol=_E_tol, psi_tol=_psi_tol).get_psi() # state = self.get_state() # state.set_psi(psi0) return state def get_initialized_state(self, state): """Return a valid state initialized from `state`. This is used in chained simulations where a specified state of one simulation is used to initialize a state for further use. 
For example, for expansion.""" state_ = self.get_state() state_.set_psi(state.get_psi()) return state_ # End of methods required by IExperiment ###################################################################### class SimulationBase: """Manages a simulation, including checkpointing and restarts. A simulation is a collection of runs associated with an experiment with an associated collection of frames corresponding to different parameters. These could be different times during a long-running simulation, or different parameter values for a scan. <incomplete> """ class Simulation: """Manages a simulation, including checkpointing and restarts. This class requires a State object which can be evolved forward (and backward) in time. It will then initialize the state and evolve to the desired time with the checkpoint requirements. Persistence is managed through checkpoints on disk in the directory `self.dir_name == os.path.join(self.data_dir, self.experiment.dir_name)`. This directory will contain the following file: <dir_name>/experiment.py <dir_name>/run_*.txt The first is an archive of the Experiment class which can be used to initialize the experiment. The second are logs produced when `run()` is called. The directory may also contain data for the checkpoints. The format is up to the Experiment class, but will generally be of the form: <dir_name>/frame_*.npy The `self.frames` object is an instance of `Frames` with keys that map to these frames. They can be length one or two tuples. For example: ikey = (Decimal('0.2000'),) -> frame_0.2000.npy ikey = (Decimal('0.2000'), Decimal('1.000')) -> frame_0.2000_image_1.000.npy See `Frames` for a discussion of the distinction between `key` and `ikey` (internal key). These simulations will be run, and checkpoints made at a set of times (controlled by the Simulation class). 
These checkpoints are stored as data frames in a directory specified by the name of the experiment and the parameter values that are different from the default values. Additionally, one might make images of the state at various times. When evolving, this is controlled by the `t__image` parameter which is special (described below). For example, consider a class `Experiment1` with parameter `a` with different values `0` and `1.1`, specifying that imaging should take place at `image_ts_ = (0.2, ...)` after expanding for `t__image = 0.05`, we might run a `Simulation` with `dt_ = 0.1` ending up with following file-structure: Experiment1/a=0/frame_0.1000.npy /frame_0.2000.npy /frame_0.2000_image_0.0500.npy /frame_0.3000.npy ... Experiment1/a=1.1/frame_0.1000.npy /frame_0.2000.npy /frame_0.2000_image_0.0500.npy /frame_0.3000.npy ... These data are generated in a "run" which is a sequence of times that a state is run through. Typically we start with a single run which saves intermediate checkpoints and checkpoints at all the times `experiment.image_ts_` where we will later perform imaging. Then, a series of runs are performed from these saved `experiment.image_ts_` up to time `t__final + t__image`. By setting `state.t_final`, the state can change the potentials to perform imaging for times `t > state.t_final`. Parameters ---------- experiment : Experiment, None Experiment instance. Can be None if `dir_name` is specified and contains an executable `experiment.py` file. dir_name : str, None Location of data. Can be `None` if an `experiment` is provided. dt_ : float Time interval between frames (in units of `experiment.t_unit`) dt__image : float, None Time interval between frames (in units of `experiment.t_unit`) for imaging (default is to use dt_) checkpoint_dt_ : float, None Checkpoint interval (in units of `experiment.t_unit`). If `None`, then `dt_` is used. 
During a simulation, checkpoints are made with a minimum step of this interval -- the actual `dt_` might be larger depending on the underlying step size (see `dt_t_scale`). When a new checkpoint is made, previous checkpoints are removed according to the value of `checkpoints_to_retain`. checkpoints_to_retain : int Number of checkpoints to retain. max_t_ : float Maximum time for simulation (in units of `experiment.t_unit`). Can be `None` if the experiment lists times `experiment.image_ts_`. If both are provided, then this flag will be used to limit execution (but a warning will be emitted). If set, this will also set the upper limit. extends : (Simulation, t_), None If provided, then this simulation will call `self.experiment.get_initialized_state(simulation.get_state(t))` in order to extend the previous simulation. This is useful when expanding for example. data_dir : str Where to store the results. checkpoint : bool If `True`, then checkpoint the results to disk (along with information about the run as returned by `record_computer_state()`). image_ts_ : [float] List of times to perform expansion imaging at. After the simulation is run to these times, a separate simulation is started which will do the expansion for imaging. Deprecated - use Experiment.image_ts_ The expansion is only done when the run_images() allow_negative_dt : bool If True, then allow the evolver to evolve backwards from the nearest state when computing. mem_limit_bytes : None (default) or int If not `None`, the maximum space in bytes that data stored in `self.frames` may take up in memory. The following parameters are provided in case they are not specified by the experiment, but the experiment-defined values will be preferred: dt_t_scale : float Timestep to use for evolver (in units of `state.t_scale`) if not provided by the attribute `experiment.dt_t_scale`. Evolver : IEvolver Which time evolver to use if not provided as an attribute `experiment.Evolver`. 
""" frame_prefix = "frame_" def __init__( self, experiment=None, dir_name=None, dt_=1, dt__image=None, max_t_=None, extends=None, dt_t_scale=0.1, data_dir=_DATA_DIR, Evolver=EvolverABM, logging_level=logging.INFO, checkpoint=True, checkpoint_dt_=None, checkpoints_to_retain=1, image_ts_=None, allow_negative_dt=False, mem_limit_bytes=None, ): if experiment is None: if dir_name is None: raise ValueError("Must provide either and experiment or a dir_name") else: experiment_file = os.path.join(dir_name, "experiment.py") if not os.path.exists(experiment_file): raise ValueError( "experiment not provided and couldn't be loaded from {}.".format( experiment_file ) ) d = {} with open(experiment_file) as f: exec(f.read(), d) experiment = d["experiment"] assert dir_name == os.path.join(data_dir, experiment.dir_name) self._experiment = experiment self._extends = extends self._dir_name = dir_name self.allow_negative_dt = allow_negative_dt self.data_dir = data_dir self.max_t_ = max_t_ self.dt_ = dt_ self.dt__image = dt__image self.checkpoint = checkpoint self.checkpoint_dt_ = checkpoint_dt_ self.checkpoints_to_retain = checkpoints_to_retain self.logger = logging.Logger(self.experiment.__class__.__name__) self.logger.setLevel(logging_level) self.logger.addHandler(logging.StreamHandler()) self.evolve_times = [] self.mem_limit_bytes = mem_limit_bytes if image_ts_ is not None: warnings.warn( "image_ts_ is deprecated: please use experiment.image_ts_", DeprecationWarning, stacklevel=2, ) self.image_ts_ = image_ts_ self._frames = None self._Evolver = Evolver self._dt_t_scale = dt_t_scale @property def experiment(self): if self._experiment is None: self.initialize() return self._experiment @property def dir_name(self): if self._dir_name is None: self.initialize(create_dir=False) return self._dir_name def info(self, msg): self.logger.info(msg) def warning(self, msg): self.logger.warning(msg) def error(self, msg): self.logger.error(msg) @property def Evolver(self): return 
getattr(self.experiment, "Evolver", self._Evolver) @property def dt_t_scale(self): return getattr(self.experiment, "dt_t_scale", self._dt_t_scale) @contextlib.contextmanager def msg(self, msg): self.info(msg + "...") tic = time.time() try: yield except Exception: self.error(msg + ". Failed!") raise toc = time.time() - tic self.info(msg + ". Done. ({:.2g}s)".format(toc)) def initialize(self, create_dir=True): """Initialize the simulation object, and perform sanity checks.""" write_to_disk = create_dir and self.checkpoint # Check for existing data or make sure dir_name exists. if self._experiment is not None: self._dir_name = os.path.join(self.data_dir, self.experiment.dir_name) if os.path.exists(self.dir_name): self.warning("Existing simulation directory found: {}".format(self.dir_name)) with self.get_frames(tmp=not write_to_disk) as frames: self.info("The following frames exist: {}".format(frames.ikeys())) elif write_to_disk: with self.msg("Creating simulation directory {}".format(self.dir_name)): os.makedirs(self.dir_name) experiment_file = os.path.join(self.dir_name, "experiment.py") if not os.path.exists(experiment_file): if self._experiment is None: raise ValueError( "No experiment provided or experiment_file={}".format(experiment_file) ) if write_to_disk: # Archive the experiment object with self.msg("Archiving experiment.py"): archive = Archive(scoped=False) archive.insert(experiment=self.experiment) with open(experiment_file, "w") as file: file.write(str(archive)) # Now load the experiment object to make sure that the archival was # successful if os.path.exists(experiment_file): with self.msg("Loading experiment from experiment.py"): d = {} with open(experiment_file) as f: exec(f.read(), d) experiment = d["experiment"] assert self.dir_name == os.path.join(self.data_dir, experiment.dir_name) if self._experiment is None: self._experiment = experiment else: # Check that experiments are consistent different_keys = experiment.get_diff(self._experiment) if 
different_keys: raise ValueError( f"{experiment_file=} inconsistent with self._experiment: " + f"{different_keys=}" ) else: # This should only happen if not write_to_disk assert not write_to_disk and self._experiment is not None experiment = self._experiment # Check that the state can be created and saved with self.msg("Creating test state"): state = self.experiment.get_state(initialize=False) if write_to_disk: tmp_prefix = "_tmp" + str(os.getpid()) + self.frame_prefix args = dict( data_dir=self.dir_name, prefix=tmp_prefix, checkpoints_to_retain=self.checkpoints_to_retain, ) with self.msg("Saving test state"): with Frames(mode="w", **args) as frames: frames[0.0] = state.get_data() statinfo = os.stat(frames.get_filename(0.0)) frame_size = statinfo.st_size with self.msg("Loading test state"): with Frames(**args) as frames: assert np.allclose(state.get_data(), frames[0.0]) del frames[0.0] self.frame_size_MB = frame_size / 1024.0**2 self.info("Frame size = {}".format(mem_str(frame_size))) # Compute number of frames to store and file size ts_ = self.get_ts_() Nframes = len(ts_) if write_to_disk: self.info( "All {} frames will take {} of space on disk.".format( Nframes, mem_str(Nframes * frame_size) ) ) def get_hostname(self): """Return the hostname.""" try: return subprocess.check_output(["hostname"]) except Exception: return b"Unknown" def get_cpuinfo(self): """Return info about the CPU.""" # For Linux try: return subprocess.check_output(["lsinfo"]) except Exception: pass try: return subprocess.check_output(["cpuinfo"]) except Exception: pass # For Mac OS X try: return subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"]) except Exception: pass return b"Unknown" def get_hginfo(self): """Return mercurial info about the repo.""" try: hg1 = subprocess.check_output(["hg", "summary"]) hg2 = subprocess.check_output(["hg", "status"]) return hg1 + hg2 except Exception: pass return b"Unknown" def get_gitinfo(self): """Return info about a git repository, if one 
exists.""" try: git1 = subprocess.check_output(["git", "show", "-s"]) git2 = subprocess.check_output(["git", "status"]) return ( b"HEAD (`git show -s`):\n" + git1 + b"\nStatus (`git status`):\n" + git2 ) except Exception: pass return b"Unknown" def get_jjinfo(self): """Return info about a JJ (Jujutsu) repository, if one exists.""" try: jj1 = subprocess.check_output(["jj", "show", "-s"]) jj2 = subprocess.check_output(["jj", "show", "-s", "@-"]) return ( b"Working copy (`jj show -s`):\n" + jj1 + b"\nParent (`jj show -s @-`):\n" + jj2 ) except Exception: pass return b"Unknown" def get_time(self): """Return the time in a string that can be used in a filename.""" return time.strftime("%Y%m%d-%H%M%S") def get_computer_state(self, timestr=None): """Record the state of the computer, repository, etc. before starting the run. """ hostname = self.get_hostname().decode() cpuinfo = self.get_cpuinfo().decode() hginfo = self.get_hginfo().decode() gitinfo = self.get_gitinfo().decode() jjinfo = self.get_jjinfo().decode() if timestr is None: timestr = self.get_time() computer_state = [ "Starting run at time {}".format(timestr), "sys.argv = {}".format(sys.argv), "hostname = {}".format(hostname), "cpuinfo", "=======", cpuinfo, "", "hginfo", "======", hginfo, "", "gitinfo", "=======", gitinfo, "", "jjinfo", "======", jjinfo, ] return "\n".join(computer_state) @staticmethod def get_human_duration(seconds): """Return a nice human-readable duration""" seconds = int(np.ravel(seconds)[0]) res = [] for msg, seconds_per in [ ("d", 24 * 60 * 60), ("h", 60 * 60), ("m", 60), ("s", 1), ]: if seconds > seconds_per: res.append("{}{}".format(seconds // seconds_per, msg)) seconds = seconds % seconds_per return ", ".join(res) def get_frames(self, tmp=False): """Return the `Frames` instance. Arguments --------- tmp : bool If `True`, then return a temporary instance for accessing properties like Decimal. 
""" if self._frames is None: args = dict( data_dir=self.dir_name, prefix=self.frame_prefix, checkpoints_to_retain=self.checkpoints_to_retain, mem_limit_bytes=self.mem_limit_bytes, ) frames = Frames(**args) if tmp: return frames self._frames = frames frames = self._frames if self.checkpoint: if not os.path.exists(self.dir_name): self.initialize(create_dir=True) frames.mode = "rw" elif os.path.exists(self.dir_name): frames.mode = "rm" else: frames.mode = "m" return frames frames = property(get_frames) @property def Decimal(self): return self.get_frames(tmp=True).Decimal def delete_data(self): """Remove all saved data from disk.""" if os.path.exists(self.dir_name): shutil.rmtree(self.dir_name) dir_name = os.path.dirname(self.dir_name) while dir_name and not os.listdir(dir_name): # Remove containing folders if they are empty os.rmdir(dir_name) dir_name = os.path.dirname(dir_name) def _unique(self, ts_, max_t_=None, ikeys=None): """Return a sorted array of unique times <= max_t_.""" if max_t_ is None: max_t_ = self.max_t_ Decimal = self.Decimal times_ = list(map(Decimal, ts_)) if times_ and max_t_ is not None and max_t_ < max(times_): self.warning(f"Some times exceed {max_t_=}. Discarding.") times_ = [t for t in times_ if t <= Decimal(max_t_)] if ikeys is None: return np.unique(times_).astype(float) else: times, inds = np.unique(times_, return_index=True) ikeys = [ikeys[i] for i in inds] return times.astype(float), ikeys def get_image_ts_(self): """Return the sorted list of times at which imaging will be run.""" image_ts_ = list(self.experiment.image_ts_) if self.image_ts_ is not None: image_ts_.extend(self.image_ts_) return self._unique(image_ts_) def get_ts_(self, image_t_=None): """Return a sorted list of times at which we want simulation data. Arguments --------- image_t_ : float, None If `None`, then this is the list of times without imaging. 
Otherwise, it is the list of times up to `image_t_ + experiment.t__image` for the imaging run for `image_t_` """ Decimal = self.Decimal if image_t_ is None: max_t_ = self.max_t_ if max_t_ is None: max_t_ = 0 image_ts_ = self.get_image_ts_().tolist() if image_ts_: max_t_ = max(max(image_ts_), max_t_) else: image_ts_ = [0] T_ = Decimal(max_t_) dt_ = Decimal(self.dt_) N = int(np.ceil(T_ / dt_)) # We don't use np.arange because we want the last element if it is <= T_ ts_ = [n * dt_ for n in range(N + 1) if n * dt_ <= T_] # Explicitly include the image times so that we can directly evolve up to # these without restarting. This can improve the accuracy slightly with the # ABM method which introduces a small error on restarts. ts_.extend(image_ts_) else: max_t_ = image_t_ + self.experiment.t__image T_ = Decimal(max_t_) dt__image = self.dt__image if dt__image is None: dt__image = self.dt_ ts_ = ( np.arange(Decimal(0), Decimal(image_t_), Decimal(self.dt_)).tolist() + np.arange(Decimal(image_t_), T_, Decimal(dt__image)).tolist() ) ts_.append(T_) return self._unique(ts_, max_t_=max_t_) def get_saved_ts_(self, image_t_=None, t__image=None): """Return `(ts_, ikeys)`, lists of times and keys with simulation data. Arguments --------- t__image : None, float If `None`, then return the times along a saved solution path, otherwise return all states that have been imaged for this length of time. image_t_ : None, float If `None`, return the times along the saved solution path without any imaging, otherwise return the saved solution path with imaging starting at `image_t_`. Not used if `t__image` is specified. Returns ------- ts_ : list(float) Sorted array of unique times. ikeys : list(ikey) Corresponding list of ikeys. 
""" with self.frames as frames: Decimal = self.Decimal max_t_ = self.max_t_ if max_t_ is None: max_t_ = np.inf else: max_t_ = Decimal(max_t_) if t__image is not None: t__image = Decimal(t__image) ikeys = [ ikey for ikey in frames.ikeys() if len(ikey) > 1 and ikey[1] == t__image and ikey[0] <= max_t_ ] ts_ = [self.get_t_(ikey=ikey, frames=frames) for ikey in ikeys] max_t_ = max(ts_) elif image_t_ is None: ikeys = [ikey for ikey in frames.ikeys() if len(ikey) == 1] ts_ = [self.get_t_(ikey=ikey, frames=frames) for ikey in ikeys] max_t_ = self.max_t_ else: ts_ = [] ikeys = [] for ikey in frames.ikeys(): t_ = self.get_t_(ikey=ikey, frames=frames) if (len(ikey) > 1 and ikey[0] == Decimal(image_t_)) or ( len(ikey) == 1 and Decimal(t_) <= Decimal(image_t_) ): ts_.append(t_) ikeys.append(ikey) if ts_: max_t_ = max(ts_) return self._unique(ts_, ikeys=ikeys, max_t_=max_t_) def get_wanted_ts_(self, image_t_=None): """Return a list of wanted ts.""" ts_ = self.get_ts_(image_t_=image_t_) saved_ts_ = set(self.get_saved_ts_(image_t_=image_t_)[0]) return [t_ for t_ in ts_ if t_ not in saved_ts_] def get_solution_path(self, image_t_=None): """Return the solution path. This is the optimal ordering of times to minimize execution assuming that the time to evolve is proportional to the difference in times. """ saved_ts_, ikeys = self.get_saved_ts_(image_t_=image_t_) solution_path = [] computed_ts_ = list(saved_ts_) wanted_ts_ = self.get_wanted_ts_(image_t_=image_t_) while wanted_ts_: dts_ = np.asarray(wanted_ts_)[:, None] - np.asarray(computed_ts_)[None, :] # Find the desired time that is closest to a computed state i = abs(dts_).min(axis=1).argmin() t_ = wanted_ts_[i] # Find the closest computed state if self.allow_negative_dt: raise NotImplementedError() # Don't do this until tested! computation of steps # below fails. 
i0 = abs(dts_)[i, :].argmin() t0_ = computed_ts_[i0] else: _dts = dts_[i, :] _dts = np.where(_dts < 0, np.inf, _dts) i0 = _dts.argmin() t0_ = computed_ts_[i0] solution_path.append((t0_, t_)) del wanted_ts_[i] computed_ts_.append(t_) return solution_path def get_t_(self, ikey, frames): """Return `t_` for the corresponding `ikey` from `frames`.""" return float(sum(ikey)) def get_ikey(self, t_, image_t_=None): if image_t_ is None or t_ <= image_t_: key = t_ else: key = (image_t_, t_ - image_t_) return self.get_frames(tmp=True).key_to_ikey(key) def set_frame(self, t_, data, image_t_, checkpoint=False): """Set the frame data and return key.""" with self.frames as frames: ikey = self.get_ikey(t_=t_, image_t_=image_t_) if checkpoint: frames.checkpoint(ikey, data) else: frames[ikey] = data return ikey def get_state(self, t_=None, t__image=False, compute=False): """Return the state at time t_ (if it has been computed). Arguments --------- t_ : float Time specifying which state to get. t__image : float, bool If provided, then load the image associated with state after time `t__image`. If simply `True`, then use the `t__image=self.experiment.t__image`. compute : bool If `True` then compute the state from the nearest saves state. """ if t__image: if t__image is True: t__image = self.experiment.t__image image_t_ = t_ t_ = image_t_ + t__image else: image_t_ = None ikey = self.get_ikey(t_=t_, image_t_=image_t_) return self.get_state_from_ikey(ikey=ikey, image_t_=image_t_, compute=compute) def get_state_from_ikey(self, ikey, image_t_=None, compute=False, check_exists=True): """Return the state with key (if it has been computed). Arguments --------- ikey : ikey image_t_ : float Imaging time. This controls `state.t_final`. If `len(ikey) > 1` (i.e. you are requesting a state after some imaging), then `image_t_=ikey[0]`. compute : bool If `True` then compute the state from the nearest saves state. 
check_exists : bool Check that the state at `ikey` exists, potentially computing it if not. Set this to `False` if you already know the state at `ikey` exists on disk. """ if len(ikey) > 1: ikey_image_t_ = ikey[:1] if ( image_t_ is not None and self.frames.key_to_ikey(image_t_) != ikey_image_t_ ): raise ValueError( f"Imaging time from ikey (={ikey_image_t_}) is different than `image_t_`(={image_t_})" ) image_t_ = float(ikey_image_t_[0]) if not check_exists or ikey in self.get_saved_ts_(image_t_=image_t_)[1]: with self.frames as frames: t_ = self.get_t_(ikey, frames=frames) data = frames[ikey] if image_t_ is None: t__final = np.inf else: t__final = image_t_ state = self.experiment.get_state(initialize=False) state.t_final = float(t__final) * self.experiment.t_unit state.set_data(data) state.t = t_ * self.experiment.t_unit state.initializing = False return state elif not compute: raise IndexError(f"No data for state with {ikey=}.") else: raise NotImplementedError def get_states(self, image_t_=None, t__image=None, progress=False): """Return a list of computed states. Arguments --------- image_t_ : float If provided, then return all states up to this time and all images after this. Otherwise, return all non-imaging states. t__image : float If provided, then get all states that have been imaged after expanding to this time. progress : bool, float If provided, then print progress messages every progress seconds. 
""" ts_, ikeys = self.get_saved_ts_(image_t_=image_t_, t__image=t__image) if progress is not None: msg = "About to load {} states".format(len(ikeys)) with self.frames as frames: image_ikeys = [ikey for ikey in ikeys if len(ikey) > 1] if image_t_ is not None: msg = " ".join([msg, f"({len(image_ikeys)} during imaging)"]) self.msg(msg) states = [] tic = time.time() tic_last_message = 0 for _n, ikey in enumerate(ikeys): state = self.get_state_from_ikey( ikey=ikey, image_t_=image_t_, check_exists=False ) states.append(state) if progress and time.time() > tic_last_message + progress: toc = tic_last_message = time.time() time_per_state = (toc - tic) / (_n + 1) remaining_states = len(ikeys) - len(states) time_remaining = time_per_state * remaining_states msg = "... {:.1f}s to go ({} states at {:.1f}s/state)".format( time_remaining, remaining_states, time_per_state ) self.msg(msg) print(msg) return states def run(self, image_t_=None): """Run the simulation, optionally checkpointing the data to disk. Arguments --------- image_t_ : float, None If provided, then run the imaging procedure for this time. I.e. we will set `state.t_final = image_t` and then evolve up to `image_t_ + experiment.t__image`. These results will be stored with `key=(image_t, t_)` where `t_` varies in steps of `dt_` up to `experiment.t__image`. """ self.initialize() timestr = self.get_time() computer_state = self.get_computer_state(timestr=timestr) if self.checkpoint: if image_t_ is None: info_file = f"run_{timestr}.txt" else: info_file = f"run_{self.frames.Decimal(image_t_)}_image_{timestr}.txt" info_file = os.path.join(self.dir_name, info_file) with self.msg(f"Writing {info_file=}"): with open(info_file, "w") as f: f.write(computer_state) else: self.info(computer_state) saved_ts_, saved_ikeys = self.get_saved_ts_(image_t_=image_t_) saved = dict(zip(saved_ikeys, saved_ts_)) t_unit = self.experiment.t_unit if not saved: # No existing data. 
Find initial state with self.msg("Finding initial state"): if self._extends is None: with NoInterrupt(): state = self.experiment.get_initial_state() else: simulation, t_ = self._extends state = self.experiment.get_initialized_state( state=simulation.get_state( t_=t_, image_t_=image_t_ ) # REV: t__image ) t_ = state.t / t_unit with self.msg("Saving initial state (t_={})".format(t_)): self.set_frame(t_=t_, data=state.get_data().copy(), image_t_=image_t_) ikey = self.get_ikey(t_=t_, image_t_=image_t_) saved[ikey] = t_ else: state = self.experiment.get_state() solution_path = self.get_solution_path(image_t_=image_t_) dt = self.dt_t_scale * state.t_scale dt_ = dt / t_unit self.info( "Solution path: {}".format( " ".join( [ "{}->{}".format( self.get_ikey(t_=_t0_, image_t_=image_t_), self.get_ikey(t_=_t_, image_t_=image_t_), ) for (_t0_, _t_) in solution_path ] ) ) ) with NoInterrupt(ignore=True) as interrupted: while solution_path and not interrupted: if self.evolve_times: # Estimate remaining time dts_ = np.diff(solution_path, axis=1) steps = sum(np.ceil(dts_ / dt_).astype(int)) tics_, steps_, times_ = np.asarray(self.evolve_times).T time_per_step = times_ / steps_ time_remaining = steps * time_per_step.mean() time_remaining_err = steps * time_per_step.std() self.info( "Estimated time to completion: {}+-{}".format( self.get_human_duration(time_remaining), self.get_human_duration(time_remaining_err), ) ) t0_, t_ = solution_path.pop(0) state = self.compute_state(t_=t_, image_t_=image_t_) with self.msg("Saving state t_={}".format(t_)): saved[ self.set_frame( t_=t_, data=state.get_data().copy(), image_t_=image_t_ ) ] = t_ def compute_state(self, t_, image_t_=None): """Compute the state at the specified time from the nearest existing state, checkpointing if needed. 
""" saved_ts_, ikeys = self.get_saved_ts_(image_t_=image_t_) ikey = self.get_ikey(t_=t_, image_t_=image_t_) if ikey in ikeys: return self.get_state_from_ikey(ikey=ikey, image_t_=image_t_) # Find the closest existing state dts_ = t_ - np.asarray(saved_ts_) if self.allow_negative_dt: raise NotImplementedError() # Don't do this until tested! computation of steps # below fails. dts_ = abs(dts_) else: dts_ = np.where(dts_ < 0, np.inf, dts_) i0 = dts_.argmin() t0_, ikey0 = saved_ts_[i0], ikeys[i0] state = self.get_state_from_ikey(ikey=ikey0, image_t_=image_t_) t_unit = self.experiment.t_unit dt = self.dt_t_scale * state.t_scale dt_ = dt / t_unit with self.msg("Evolving from t0_={} to t_={}".format(t0_, t_)): # Evolvers need at least 2 steps if not self.allow_negative_dt: assert t_ > t0_ steps = max(2, int(np.ceil((t_ - t0_) / dt_))) dt_ = (t_ - t0_) / steps dt = dt_ * t_unit evolver = self.Evolver(state, dt=dt) if getattr(evolver, "fixed_dt", True): # Break evolution up for checkpoints. if self.checkpoint_dt_ is not None and self.checkpoint_dt_ < (t_ - t0_): checkpoint_steps = max(2, int(np.floor(self.checkpoint_dt_ / dt_))) stepss = [checkpoint_steps] * (steps // checkpoint_steps) last_steps = steps % checkpoint_steps if last_steps < 2: stepss[-1] += last_steps else: stepss.append(last_steps) assert min(stepss) >= 2 stepss, last_steps = stepss[:-1], stepss[-1] assert sum(stepss) + last_steps == steps else: stepss = [] last_steps = steps tic = time.time() for _steps in stepss: evolver.evolve(_steps) self.set_frame( t_=evolver.t / t_unit, ##### REV Should use get_psi() and set_psi() here. data=evolver.y[...].copy(), checkpoint=True, image_t_=image_t_, ) assert last_steps > 0 evolver.evolve(last_steps) else: # Break evolution up for checkpoints. 
if self.checkpoint_dt_ is not None and self.checkpoint_dt_ < (t_ - t0_): checkpoint_ts_ = np.arange(t0_, t_, self.checkpoint_dt_) if checkpoint_ts_[-1] >= t_: checkpoint_ts_ = checkpoint_ts_[:-1] else: checkpoint_ts_ = [] tic = time.time() for _t_ in checkpoint_ts_: _t = _t_ * t_unit evolver.evolve_to(_t) assert np.allclose(_t, evolver.t) self.set_frame( t_=_t_, ##### REV Should use get_psi() and set_psi() here. data=evolver.y[...].copy(), checkpoint=True, image_t_=image_t_, ) evolver.evolve_to(t_ * t_unit) self.evolve_times.append((tic, steps, time.time() - tic)) assert np.allclose(evolver.t, t_ * t_unit) return evolver.get_y() def run_images(self): """Make the images from the simulation.""" image_ts_ = self.get_image_ts_() if image_ts_ is None: self.warning( "No image_ts_ specified (got {}). Doing nothing.".format(self.image_ts_) ) return self.initialize() for image_t_ in image_ts_: self.run(image_t_=image_t_) def view(self, plot_state=None, t__image=None, **kw): """Run this to load the generated data and plot it. Arguments --------- plot_state(state): """ from IPython.display import display, clear_output from matplotlib import pyplot as plt if plot_state is None: def plot_state(state, fig=None): try: return state.plot(fig=fig, **kw) except TypeError: if fig is not None: fig = plt.figure(fig.number) else: fig = plt.gcf() state.plot() return fig # t_unit = self.experiment.t_unit # state = self.experiment.get_state() fig = None with NoInterrupt(ignore=True) as interrupted: saved_ts_, ikeys = self.get_saved_ts_(t__image=t__image) for t_ in saved_ts_: plt.clf() fig = plot_state(self.get_state(t_=t_, t__image=t__image), fig=fig) plt.draw() plt.pause(0.01) display(fig) if interrupted: break clear_output(wait=True) class SimulationManager: """Provides access to a bunch of simulations stored on disk. Each set of simulations is accessible through an attribute with the same name as the experiment class. 
This attribute will return a Simulations() instance with the corresponding simulations. """ def __init__(self, data_dir, simulations=None): self._data_dir = data_dir def find_simulations(self): self._simulation_dirs = simulation_dirs = [ dir_name for dir_name, dirnames, filenames in os.walk(self._data_dir) if "experiment.py" in filenames ] return simulation_dirs def load_simulations(self, simulation_dirs=None): if simulation_dirs is None: simulation_dirs = self.find_simulations() simulations = [] for dir_name in simulation_dirs: try: simulation = Simulation(dir_name=dir_name) simulations.append(simulation) except Exception: traceback.print_exc() self._simulations = Simulations(simulations) names = dict() for s in self._simulations: name = s.experiment.__class__.__name__ names.setdefault(name, []).append(s) for name in names: self.__dict__[name] = Simulations(names[name]) class Simulations(abc.Sequence): """Represents a list of simulations where one can access the simulations through attribute access with tab completion. 
""" def __init__(self, simulations): self._simulations = simulations self._keys = {} for s in simulations: for key in s.experiment._keys: value = getattr(s.experiment, key) self._keys.setdefault(key, set()).add(value) def __dir__(self): """Custom attribute access for tab completion.""" keys = sorted(_k for _k in self._keys if len(self._keys[_k]) > 1) res = [] for key in keys: new = ["{}[{}]".format(key, value) for value in self._keys[key]] if res: res.extend(["{}.{}".format(_res, _new) for _res in res for _new in new]) else: res.extend(new) return res def __getattr__(self, key): return {value: self.get(**{key: value}) for value in self._keys[key]} def __len__(self): return len(self._simulations) def __getitem__(self, key): return self._simulations[key] def keys(self): keys = set() for s in self.simulations: keys.update(s.experiment._keys) return sorted(keys) def get(self, **kw): """Return the simulations satisfying the specified criteria.""" results = [] for s in self._simulations: if all([getattr(s.experiment, _k) == kw[_k] for _k in kw]): results.append(s) return Simulations(results) def __repr__(self): return "{}():\n{}".format( self.__class__.__name__, "\n".join([_s.dir_name for _s in self._simulations]), ) # try: # # from . import visualize # # class Simulation2(visualize.SimulationMixin2, Simulation): # pass # # except ImportError: # pass ###################################################################### # GPU Support def i_know_this_is_slow(func): """Decorator to suppress PerformanceWarning.""" func._i_know_this_is_slow = True return func class _GPU: """Various decorators for helping with GPU and other accelerator support.""" @staticmethod def _asnumpy(result, instance=None): """Convert result to a numpy array or a tuple of arrays. 
Uses the following in order: * `instance.asnumpy(result)` if provided * `result.get()` * `cupy.asnumpy(result)` """ if instance: assert hasattr(instance, "asnumpy") asnumpy = instance.asnumpy try: if isinstance(result, tuple): return tuple(map(asnumpy, result)) else: return instance.asnumpy(result) except ValueError: # Special case triggered if the result is a list of inhomogeneous arrays. # Perhaps we should force the user to return a tuple instead? We enforce # this in our code tests. if "PYTEST_CURRENT_TEST" in os.environ: raise try: return tuple(map(asnumpy, result)) except ValueError: pass except AttributeError: pass try: if isinstance(result, tuple): return tuple(_r.get() for _r in result) else: return result.get() except AttributeError: pass try: from cupy import asnumpy except ImportError: from numpy import asarray as asnumpy if isinstance(result, tuple): return tuple(map(asnumpy, result)) else: return asnumpy(result) @wrapt.decorator @classmethod def _return_asnumpy(_GPU, wrapped, instance, args, kwargs): """Decorate the specified function to return a numpy array. To be used on functions that return GPU arrays for performance to generate a method for users to get these as numpy arrays. 
Uses the following in order: * `state.asynumpy(result)` if 'state' in kwargs * `self.asnumpy(result)` if provided * `result.get()` * `cupy.asnumpy(result)` """ result = wrapped(*args, **kwargs) if hasattr(kwargs.get("state", None), "asnumpy"): instance = kwargs["state"] elif not hasattr(instance, "asnumpy"): instance = None return _GPU._asnumpy(result, instance=instance) @classmethod def from_GPU(_GPU, method_GPU): """Decorator for _GPU methods to fetch the result as a numpy array.""" wrapped = _GPU._return_asnumpy(method_GPU) wrapped._GPU = True return wrapped @classmethod def GPU_delegate(_GPU, method): """Decorator for GPU methods to delegate to the non-GPU method.""" wrapped = _GPU._performance_warning(method) return wrapped @wrapt.decorator @classmethod def _performance_warning(_GPU, wrapped, instance, args, kwargs): """Decorate the specified function warning of performance.""" if not hasattr(wrapped, "_i_know_this_is_slow"): warnings.warn( f"""Default {wrapped.__qualname__}_GPU() delegates to non-GPU version. This is a potential performance issue, even if not using the GPU. To ensure optimal performance, you should make sure that you overload {wrapped.__name__}_GPU() instead. """, category=PerformanceWarning, ) return wrapped(*args, **kwargs) @classmethod def noop_GPU(_GPU, method_GPU): """Decorator to protect _GPU methods from being converted.""" wrapped = _GPU._return_asnumpy(method_GPU) wrapped._GPU = True wrapped._noop_GPU = True return wrapped @staticmethod def default(method_GPU): "Mark with `._default=True` denoting that it delegates to the non-GPU version." method_GPU._default = True return method_GPU @classmethod def add_non_GPU_methods(_GPU, cls): """Decorator that defines non-GPU methods for cls from _GPU methods. Generally, it is an error to define both the non-gpu and GPU methods, and this will raise an AttributeError in this case. The following exceptions are provided. 1. 
If the method has the attribute `._default = True` (which is set with the `_GPU.default` decorator), then we allow the user to define the non-GPU version even though the GPU version is defined because the default GPU versions delegate to the non-GPU versions. """ warnings.warn( "_GPU.add_non_GPU_methods. Inherit from GPUHelper instead.", DeprecationWarning, ) # Don't dig into subclasses. # CHECK: This is a little suspect, at least with staticmethods. # getattr(cls, name) is the function, but cls.__dict__[name] is the # staticmethod wrapped function. Can we safely wrap a staticmethod, or # do we need to wrap the function, then wrap with staticmethod outside. # The tests work... but should check wrapt docs. names = {name for name in cls.__dict__ if inspect.isfunction(getattr(cls, name))} members = {name: cls.__dict__[name] for name in names} GPU_names = set(name for name in members if name.endswith("_GPU")) non_GPU_names = set(_name[:-4] for _name in GPU_names) duplicates = set( name for name in non_GPU_names.intersection(names) if ( name in members and not hasattr(members[name], "_GPU") and not getattr(members[name], "_default", False) ) ) if duplicates: # It is an error to define both a _GPU and a non_GPU method. This avoids # accidentally overloading the regular method when the _GPU method should be # overloaded, but allows classes with no _GPU methods to still function # normally. raise AttributeError( f"Class {cls.__name__} has _GPU methods but also defines {duplicates}" ) methods_to_add = non_GPU_names.difference(names) for name in methods_to_add: setattr(cls, name, _GPU.from_GPU(members[f"{name}_GPU"])) return cls @classmethod def add_GPU_and_non_GPU_methods(_GPU, cls): """Decorator that defines the GPU/non-GPU pairs for cls from _GPU methods. Generally, it is an error to define both the non-gpu and GPU methods, and this will raise an AttributeError in this case. The following exceptions are provided. 1. 
If the method has the attribute `._default = True` (which is set with the `_GPU.default` decorator), then we allow the user to define the non-GPU version even though the GPU version is defined because the default GPU versions delegate to the non-GPU versions. """ # CHECK: This is a little suspect, at least with staticmethods. # getattr(cls, name) is the function, but cls.__dict__[name] is the # staticmethod wrapped function. Can we safely wrap a staticmethod, or # do we need to wrap the function, then wrap with staticmethod outside. # The tests work... but should check wrapt docs. names = {name for name in cls.__dict__ if inspect.isfunction(getattr(cls, name))} members = {name: cls.__dict__[name] for name in names} # These are the *_GPU methods defined by the user. These need non-GPU methods # defined. GPU_names = set(name for name in members if name.endswith("_GPU")) non_GPU_names_needed = set(_name[:-4] for _name in GPU_names) # These are the non-GPU methods defined by the user where there is a _GPU method # somewhere. These need GPU methods defined. non_GPU_names = set( name for name in members if not name.endswith("_GPU") and hasattr(cls, name + "_GPU") ) GPU_names_needed = set(name + "_GPU" for name in non_GPU_names) duplicates = set( name for name in non_GPU_names_needed.union(GPU_names_needed).intersection(names) if ( name in members and not hasattr(members[name], "_GPU") and not getattr(members[name], "_default", False) ) ) if duplicates: # It is an error to define both a _GPU and a non_GPU method. This avoids # accidentally overloading the regular method when the _GPU method should be # overloaded, but allows classes with no _GPU methods to still function # normally. 
raise AttributeError( f"Class {cls.__name__} defines both non-GPU and GPU methods {sorted(duplicates)}" ) for name in non_GPU_names_needed.difference(names): setattr(cls, name, _GPU.from_GPU(members[f"{name}_GPU"])) for GPU_name in GPU_names_needed.difference(names): setattr(cls, GPU_name, _GPU.GPU_delegate(members[f"{GPU_name[:-4]}"])) return cls @staticmethod def _check_class(self): """Check the class for consistency.""" bad = { method.__qualname__: method for name, method in inspect.getmembers(self) if name.endswith("_GPU") and not getattr(method, "_GPU", False) } names = list(bad) classes = set([name.rsplit(".", 1)[0] for name in names]) if bad: cls = self.__class__.__qualname__ raise ValueError( "\n".join( [ f"{names} uninitialized for GPU in {cls}.", "Did you forget @_GPU.add_non_GPU_methods? I.e.:\n", ] + [f"@_GPU.add_non_GPU_methods\nclass {cls}...\n" for cls in classes] ) )
class GPUHelper:
    """Class for helping with GPU and other accelerator support.

    Subclasses are processed by `_GPU.add_GPU_and_non_GPU_methods` when they are
    defined, so that methods named `*_GPU` get a partner without the suffix (and
    vice versa).

    Examples
    --------
    >>> class B(GPUHelper):
    ...     def asnumpy(self, result):
    ...         return result + ['from B.asnumpy']  # Simulate a custom asnumpy method

    >>> class A(GPUHelper):
    ...     def asnumpy(self, result):
    ...         return result + ['from A.asnumpy']  # Simulate a custom asnumpy method
    ...
    ...     @staticmethod
    ...     def get_list1_GPU():
    ...         "Return a list.  Non-GPU method will return an array."
    ...         return [1, 2]
    ...
    ...     def get_list2_GPU(self):
    ...         "Return a list.  Non-GPU will use self.asnumpy."
    ...         return [1, 2]
    ...
    ...     def get_list3_GPU(self, state):
    ...         "Return a list.  Non-GPU will use state.asnumpy."
    ...         return [1, 2]

    >>> a = A()
    >>> a.get_list1_GPU()    # GPU method returns list (representing GPU array)
    [1, 2]
    >>> a.get_list2_GPU()
    [1, 2]
    >>> a.get_list3_GPU(state=B())
    [1, 2]
    >>> a.get_list1()        # Non GPU method returns an array.
    array([1, 2])
    >>> a.get_list2()        # Non GPU instance method uses A.asnumpy.
    [1, 2, 'from A.asnumpy']
    >>> a.get_list3(state=B())  # Non GPU instance method uses B.asnumpy.
    [1, 2, 'from B.asnumpy']
    >>> set(['get_list1', 'get_list2', 'get_list3']).issubset(dir(a))
    True
    """

    def __init_subclass__(cls, **kwargs):
        """Prepare the class by defining all non-GPU/GPU pairs "*" and "*_GPU".

        Keyword arguments from the class statement are forwarded to the next
        `__init_subclass__` so this class cooperates with other bases that accept
        class keywords.
        """
        super().__init_subclass__(**kwargs)
        _GPU.add_GPU_and_non_GPU_methods(cls)
class AsNumpyMixin(StateMixin, GPUHelper):
    """Mixin providing asnumpy method."""

    # Conversion used for results of the non-GPU partner methods.
    asnumpy = staticmethod(np.asarray)

    def get_data_GPU(self):
        """Partner to set_data().

        This allows access to the data via self.get_data() which is guaranteed to be
        a numpy array.  This method itself returns `self.data[...]` unconverted.
        """
        return self.data[...]
class StateWithExperimentMixin(StateMixin, GPUHelper):
    """Mixin to delegate to self.experiment.

    Each method below calls the method of the same name on `self.experiment`,
    passing this object as `state`.
    """

    def __init__(self, experiment, **kw):
        # Set before calling the base initializer so the experiment is already
        # available there.
        self.experiment = experiment
        super().__init__(**kw)

    def get_Vext_GPU(self):
        """Return `self.experiment.get_Vext_GPU(state=self)`."""
        return self.experiment.get_Vext_GPU(state=self)

    def get_Vint_GPU(self):
        """Return `self.experiment.get_Vint_GPU(state=self)`."""
        return self.experiment.get_Vint_GPU(state=self)

    def get_Eint(self):
        """Return `self.experiment.get_Eint(state=self)`."""
        return self.experiment.get_Eint(state=self)

    def get_n_TF(self, V_TF):
        """Return `self.experiment.get_n_TF(state=self, V_TF=V_TF)`."""
        return self.experiment.get_n_TF(state=self, V_TF=V_TF)

    def get_ns_TF(self, Vs_TF):
        """Return `self.experiment.get_ns_TF(state=self, Vs_TF=Vs_TF)`."""
        return self.experiment.get_ns_TF(state=self, Vs_TF=Vs_TF)