"""Various utilities used throughout the project"""
from __future__ import absolute_import, division, print_function, unicode_literals
import atexit
try:
from collections import abc
except ImportError:
import collections as abc
import contextlib
import decimal
import glob
import inspect
import logging
import math
import os.path
import shutil
import subprocess
import sys
import time
import traceback
import warnings
from six import string_types
import wrapt
import mmfutils.performance.fft
from mmfutils.contexts import NoInterrupt
from persist.objects import Archivable
from persist.archive import Archive
from pytimeode.evolvers import EvolverABM
from pytimeode.mixins import ArrayStateMixin
from pytimeode.interfaces import implementer, IStateForABMEvolvers
import numpy as np
from .interfaces import IExperiment, IStateDFT
from .mixins import StateMixin
# Names exported by ``from <module> import *``.
__all__ = [
    "step",
    "x2_2",
    "good_Ns",
    "get_good_N",
    "get_smooth_transition",
    "Frames",
    "GPUHelper",
    "AsNumpyMixin",
    "PerformanceWarning",
    "IStateDFT",
    "StateWithExperimentMixin",
    "use_wisdom",
]

# NOTE(review): presumably the default data directory name; it is not
# referenced in the visible part of this file - confirm against callers.
_DATA_DIR = "_data"

# The three 2x2 Pauli matrices (sigma_x, sigma_y, sigma_z) stacked along the
# first axis, so ``pauli_matrices[i]`` is the i-th matrix.
pauli_matrices = np.array([[[0, 1], [1, 0]], [[0, -1j], [1j, 0]], [[1, 0], [0, -1]]])

# Rank-3 Levi-Civita symbol: +1 on cyclic permutations of (0, 1, 2), -1 on
# anti-cyclic permutations, and 0 elsewhere.
_l3 = np.zeros((3, 3, 3))
_l3[0, 1, 2] = _l3[1, 2, 0] = _l3[2, 0, 1] = 1
_l3[2, 1, 0] = _l3[1, 0, 2] = _l3[0, 2, 1] = -1

# Levi-Civita symbols keyed by the number of dimensions (2 or 3).
levi_civita = {2: np.array([[0, 1], [-1, 0]]), 3: _l3}
del _l3  # Temporary helper: keep it out of the module namespace.
[docs]
def use_wisdom(**kw):
    """Load the fftw_wisdom file (if it exists) and save on exit.

    All keyword arguments are forwarded to
    `mmfutils.performance.fft.fftw_wisdom()`.  The returned context manager
    is entered immediately and its `__exit__()` is scheduled with `atexit`
    so that it runs when the interpreter shuts down.
    """
    wisdom_context = mmfutils.performance.fft.fftw_wisdom(**kw)

    # Enter first so that the exit hook is only registered for a context that
    # was actually entered.
    wisdom_context.__enter__()

    # The context-manager protocol calls `__exit__(exc_type, exc_val, exc_tb)`.
    # `atexit` only passes the arguments given here, so supply the three
    # `None` values that signal a clean exit; without them the hook would be
    # called with no arguments at shutdown.
    atexit.register(wisdom_context.__exit__, None, None, None)
[docs]
def step(t, t1, alpha=3.0):
    r"""Smooth ($C_\infty$) step from 0 at time ``t=0`` to 1 at time ``t=t1``.

    Arguments
    ---------
    t : float
        Time at which to evaluate the step.
    t1 : float
        Time at which the step reaches 1.
    alpha : float
        Steepness parameter multiplying the tangent inside the `tanh`.

    Returns
    -------
    float
        ``0.0`` for ``t < 0``, ``1.0`` once ``t`` is no longer below ``t1``,
        and ``(1 + tanh(alpha*tan(pi*(2*t/t1 - 1)/2)))/2`` in between.
    """
    if t < 0.0:
        return 0.0
    if not t < t1:
        return 1.0

    # Map [0, t1] onto [-pi/2, pi/2] so that the tangent covers the whole
    # real line and the tanh saturates at both ends of the interval.
    phase = math.pi * (2 * t / t1 - 1) / 2
    return (1 + math.tanh(alpha * math.tan(phase))) / 2
[docs]
def x2_2(x, order=6):
    r"""Return a periodic approximation for $x^2/2$ with period $2\pi$.

    The approximation is a polynomial in ``cos(x)``:

    order == 0: 1-cos(x)
    order == 2: cos(x)**2/6 - 4*cos(x)/3 + 7/6
    ...

    Arguments
    ---------
    x : float or array
        Abscissa.
    order : int
        One of 0, 2, 4, or 6.  Any other value raises `NotImplementedError`.
    """
    # (order, coefficients in cos(x), highest power first)
    coefficient_table = (
        (0, [-1.0, 1.0]),
        (2, [1.0 / 6, -4.0 / 3, 7.0 / 6]),
        (4, [-2.0 / 45, 3.0 / 10, -22.0 / 15, 109.0 / 90]),
        (6, [1.0 / 70, -32.0 / 315, 27.0 / 70, -32.0 / 21, 386.0 / 315]),
    )
    for known_order, coeffs in coefficient_table:
        if order == known_order:
            return np.polyval(coeffs, np.cos(x))
    raise NotImplementedError("Got order={}. Must be 0, 2, 4, or 6.".format(order))
def x_periodic(x, x0=0.8, p=3):
    """Return new abscissa `xp` such that `V(xp)` is smooth and periodic on
    the interval [-1, 1] if `V(x) = V(-x)` and `V(x)` is smooth.

    Parameters
    ----------
    x : array-like
        Original abscissa.
    x0 : float
        Parameter that affects the smoothness of the transition.  The
        resulting abscissa will range from [-x0, x0], so this will
        determine the magnitude of `V(x0)` at the boundaries.
    p : int
        Parameter affecting the smoothness.  Should be odd (1 or 3).

    Examples
    --------
    >>> x_periodic([-1.1, -1, 0, 1], x0=0.8)
    array([-0.88, -0.8 ,  0.  ,  0.8 ])
    """
    scale = 1.0 / x0**p
    radius = np.abs(x)

    # Inside |x| < 1: saturating map whose magnitude approaches x0 at |x| = 1.
    saturated = np.tanh(2 * scale / np.pi * np.tan(np.pi * radius**p / 2.0)) / scale
    inside = np.maximum(0, saturated) ** (1.0 / p)

    # Outside: simple linear scaling by x0.
    outside = radius * x0

    return np.sign(x) * np.where(radius < 1, inside, outside)
[docs]
def good_Ns(Nmax=2**15):
    """Return a sorted array of good N's for the FFT (powers of 2, 3, and 5).

    The result contains every integer ``2 <= N <= Nmax`` of the form
    ``2**a * 3**b * 5**c`` in increasing order (the trivial size 1 is
    excluded).  An empty array is returned if ``Nmax < 2``.

    Arguments
    ---------
    Nmax : int
        Largest size to include.  Must be finite.

    Raises
    ------
    ValueError
        If `Nmax` is infinite.
    """
    if math.isinf(Nmax):
        raise ValueError("Nmax must be finite; got {}.".format(Nmax))

    # Enumerate the products with exact Python integers and stop each factor
    # as soon as the bound is exceeded.  This avoids building the full
    # (a, b, c) grid, whose largest entries overflow fixed-width integers for
    # large Nmax.
    found = []
    n2 = 1
    while n2 <= Nmax:
        n3 = n2
        while n3 <= Nmax:
            n5 = n3
            while n5 <= Nmax:
                found.append(n5)
                n5 *= 5
            n3 *= 3
        n2 *= 2

    # The smallest entry is always 1 (a = b = c = 0): drop it.
    return np.array(sorted(found)[1:])
[docs]
def get_good_N(N):
    """Get the lowest good size N greater than or equal to N for the FFT.

    Candidate sizes are taken from `good_Ns(2 * N)`.

    Examples
    --------
    >>> get_good_N(600)
    600
    >>> get_good_N(601)
    625
    """
    candidates = good_Ns(2 * N)
    # `candidates` is sorted, so the first survivor of the mask is the answer.
    return int(candidates[candidates >= N][0])
def mem_str(bytes):
    """Return the memory usage in nice units.

    Arguments
    ---------
    bytes : int or float
        Number of bytes (truncated to an integer).  Must be non-negative.

    Returns
    -------
    str
        The size expressed in the largest unit (B, kB, MB, ..., EB; powers of
        1024) not exceeding it, as an exact integer when the size is a
        multiple of the unit, otherwise with 3 significant digits.

    Raises
    ------
    ValueError
        If `bytes` is negative.

    Examples
    --------
    >>> [str(mem_str(_d)) for _d in
    ...  [1.0, 1000, 2000, 2*1024**2, 3.1*1024**3,
    ...   4*1024**4, 5*1024**5, 6*1024**6]]
    ['1B', '1000B', '1.95kB', '2MB', '3.1GB', '4TB', '5PB', '6EB']
    >>> print(mem_str(123913))
    121kB
    >>> print(mem_str(0))
    0B
    """
    bytes = int(bytes)
    if bytes < 0:
        raise ValueError("Memory size must be non-negative; got {}.".format(bytes))

    # floor(log_1024(bytes)) computed exactly: 1024 = 2**10, so each unit
    # corresponds to 10 bits.  Unlike a ratio of floating-point logarithms
    # this is well defined for 0 and cannot be off by one at exact powers.
    power = min(max(bytes.bit_length() - 1, 0) // 10, 6)
    unit = 1024**power
    if bytes % unit == 0:
        mem = str(bytes // unit)
    else:
        mem = "{:.3g}".format(bytes / float(unit))
    return "{}{}B".format(mem, ["", "k", "M", "G", "T", "P", "E"][power])
def hex_mantissa(x):
    """Return the hex mantissa for x.

    Used to ensure that parameters have an exact representation in floating
    point.  The sign and the exponent are discarded and trailing zeros are
    stripped, so a short result means `x` needs few bits.  An empty string
    is returned for values without a hex mantissa (infinities and NaN).

    Examples
    --------
    >>> print(hex_mantissa(0.1))
    1.999999999999a
    >>> print(hex_mantissa(0.25))
    1.
    >>> print(hex_mantissa(-0.25))
    1.
    """
    # `float.hex()` gives e.g. '-0x1.8000000000000p+0'.  Splitting on the
    # '0x' marker (rather than slicing off two characters) also handles the
    # leading '-' of negative numbers.
    digits = float(x).hex().partition("0x")[2]
    mantissa = digits.split("p")[0].rstrip("0")
    return mantissa
[docs]
def get_smooth_transition(fs, durations, transitions, alphas=None):
    """Return a C(inf) smooth transition as a function of t.

    Starting from `t=0`, the returned function holds the value `fs[0]` for
    time `durations[0]`, then smoothly transitions to `fs[1]` over time
    `transitions[0]`, holds `fs[1]` for time `durations[1]`, etc.  The last
    value `fs[-1]` is held indefinitely.  Transitions use `step()`.

    Arguments
    ---------
    fs : [float]
        List of `N` function values.  Integer values are converted to floats.
    durations : [float]
        List of (at least) `N-1` durations for each function values (the last
        value will be held indefinitely).  Extra entries are ignored.
    transitions : [float]
        List of `N-1` transition durations.
    alphas : [float]
        List of `N-1` alpha values for each transition.  Defaults to 1.0 for
        every transition.

    Returns
    -------
    numpy.vectorize
        Vectorized function `f(t)` with output dtype `fs.dtype`.
    """
    fs = np.asarray(fs)
    if np.issubdtype(fs.dtype, np.integer):
        # Fix issue #9.  If fs is an integer type, then vectorize might allocate an
        # integer output array.  Here we make sure it is at least a float... but don't
        # convert complex arrays to floats by mistake!
        fs = fs.astype(float)

    N = len(fs)
    if alphas is None:
        alphas = [1.0] * (N - 1)

    assert N - 1 <= len(durations)
    assert N - 1 == len(alphas)
    durations = durations[: N - 1]
    assert N - 1 == len(transitions)

    # Interleave [hold, transition, hold, transition, ...] and accumulate to
    # get the end time of each interval.
    ts = np.empty(2 * N - 2)
    ts[::2] = durations
    ts[1::2] = transitions
    ts = np.cumsum(ts)

    def smooth_transition(t):
        if ts.size == 0 or t <= ts[0]:
            # Single value (nothing to transition to) or t within the first hold.
            return fs[0]
        if t >= ts[-1]:
            # Special case for t > t_max
            return fs[-1]

        # Index of the interval containing t (first end time beyond t).
        i1 = np.where(t < ts)[0][0]
        i0 = i1 - 1
        ind = i1 // 2
        if 0 == i1 % 2:
            # Hold interval
            return fs[ind]

        # Transition interval
        f0, f1 = fs[ind : ind + 2]
        s = step(t - ts[i0], transitions[ind], alpha=alphas[ind])
        return f1 * s + f0 * (1 - s)

    # Pass the function positionally instead of using the keyword-only
    # decorator form: older NumPy releases require `pyfunc` as the first
    # argument of `np.vectorize`.
    return np.vectorize(smooth_transition, otypes=[fs.dtype])
[docs]
def evolve_to(
    state, t, Evolver=EvolverABM, dt_t_scale=0.1, callback=None, plot_steps=100
):
    """Evolve state to time t.

    The time step is chosen as close as possible to
    `dt_t_scale * state.t_scale` such that an integer number of steps
    reaches `t` from `state.t`.

    Arguments
    ---------
    state : IState
        State to evolve.  Must provide `t` and `t_scale`.
    t : float
        Time to evolve to.
    Evolver : IEvolver
        Evolver to use.
    dt_t_scale : float
        Size of time-step in units of state.t_scale.
    callback : None, function
        If provided, then call callback(state) every plot_steps steps.
    plot_steps : int
        Steps to take between plots.

    Returns
    -------
    The evolved state `evolver.y`.
    """
    dt = dt_t_scale * state.t_scale
    t_max = t - state.t
    steps = int(np.ceil(t_max / dt))
    dt = t_max / steps
    evolver = Evolver(state, dt=dt)

    if callback is None:
        evolver.evolve(steps)
        return evolver.y

    # Split the evolution into chunks of `plot_steps` plus a remainder.  This
    # must be a list (not a tuple) since the last two chunks may be adjusted.
    chunks = [plot_steps] * (steps // plot_steps)
    remainder = steps % plot_steps
    if remainder:
        # Skip an empty final chunk: it would only repeat the callback for a
        # state that has not changed.
        chunks.append(remainder)

    if len(chunks) > 1 and chunks[-1] == 1 and chunks[-2] > 1:
        # Avoid a final chunk of a single step by borrowing a step from the
        # previous chunk.
        chunks[-2] -= 1
        chunks[-1] += 1

    for chunk in chunks:
        evolver.evolve(chunk)
        callback(evolver.get_y())
    return evolver.y
[docs]
def evolve(
    state=None,
    history=None,
    Evolver=EvolverABM,
    t_max=np.inf,
    dt_t_scale=0.1,
    steps=100,
    display=True,
):
    """Iterator that runs an evolver.

    Parameters
    ----------
    state : IState
        State to evolve.  If this has a length, it is treated as a history
        (see below) and its last entry is used as the state.
    history : [State]
        List of states.  If provided, then use the last state to
        start.  This is mutated.
    Evolver : IEvolver
        Evolver to use.
    t_max : float
        Maximum time to evolve to.  (Default - evolve until
        interrupted.)
    dt_t_scale : float
        Size of time-step in units of state.t_scale.
    steps : int
        Steps to evolve between yielded states.
    display : bool
        If True, then display the current figure (using plt.gcf()).

    Example::

        for y in evolve(s):
            plt.clf()
            y.plot()
    """
    if state is None:
        state = history[-1]
    else:
        try:
            # A "state" with a length is really a history.
            len(state)
            history = state
            state = history[-1]
        except TypeError:
            history = [state]

    bounded = t_max < np.inf
    dt = dt_t_scale * state.t_scale
    if bounded:
        # Adjust dt so that an integer number of steps covers t_max.
        dt = t_max / int(np.ceil(t_max / dt))
    chunk = steps

    evolver = Evolver(state, dt=dt)
    yield evolver.y

    with NoInterrupt(ignore=True) as interrupted:
        while evolver.t < t_max and not interrupted:
            if bounded:
                # Do not step past t_max.
                chunk = min(steps, int((t_max - evolver.t) / dt))
            evolver.evolve(chunk)
            history.append(evolver.get_y())
            yield history[-1]
            if chunk < steps:
                # A shortened chunk means t_max was reached.
                break
            # NOTE(review): the display block is assumed to run once per
            # yielded frame (inside the loop) - confirm against the original
            # indentation.
            if display:
                from matplotlib import pyplot as plt
                import IPython.display

                IPython.display.display(plt.gcf())
                IPython.display.clear_output(wait=True)
[docs]
def evolves(
    states=None,
    histories=None,
    Evolver=EvolverABM,
    t_max=np.inf,
    dt_t_scale=0.1,
    steps=100,
    fig=None,
    display=True,
):
    """Iterator that runs an evolver like evolve() but with multiple states.

    Parameters
    ----------
    states : IState
        List of states to evolve.  The first of these is used as the
        reference state for units such as t_scale etc.  If provided, a fresh
        history is started for each state.
    histories : [[State]]
        List of histories: each is a list of states.  Used only if `states`
        is None, in which case the last states are used to start.  These
        lists are mutated.
    Evolver : IEvolver
        Evolver to use.
    t_max : float
        Maximum time to evolve to.  (Default - evolve until
        interrupted.)
    dt_t_scale : float
        Size of time-step in units of state.t_scale.
    steps : int
        Steps to evolve between yielded states.
    fig : None, Figure
        Figure to display.  Defaults to plt.gcf().
    display : bool
        If True, then display the current figure (using fig or plt.gcf()).
    """
    if states is None:
        states = [_history[-1] for _history in histories]
    else:
        histories = [[_state] for _state in states]

    reference = states[0]
    bounded = t_max < np.inf
    dt = dt_t_scale * reference.t_scale
    if bounded:
        # Adjust dt so that an integer number of steps covers t_max.
        dt = t_max / int(np.ceil(t_max / dt))
    chunk = steps

    evolvers = [Evolver(_s, dt=dt) for _s in states]
    lead = evolvers[0]
    yield [_e.y for _e in evolvers]

    with NoInterrupt(ignore=True) as interrupted:
        while lead.t < t_max and not interrupted:
            if bounded:
                # Do not step past t_max.
                chunk = min(steps, int((t_max - lead.t) / dt))
            for _e in evolvers:
                _e.evolve(chunk)
            for _e, _h in zip(evolvers, histories):
                _h.append(_e.get_y())
            yield [_h[-1] for _h in histories]
            if chunk < steps:
                # A shortened chunk means t_max was reached.
                break
            # NOTE(review): the display block is assumed to run once per
            # yielded frame (inside the loop) - confirm against the original
            # indentation.
            if display:
                from matplotlib import pyplot as plt
                import IPython.display

                if fig is None:
                    fig = plt.gcf()
                IPython.display.display(fig)
                IPython.display.clear_output(wait=True)
[docs]
class Frames:
    """Represents a series of frames (arrays) on disk for checkpointing and
    making movies.  Each frame is stored in a file with a key that is usually
    the time at which the frame is valid.

    The main interface is through item access, i.e.::

        state.set_data(frames[key])
        with frames as frames:
            frames[key] = state.get_data()

    Checkpoints should be created explicitly::

        with frames as frames:
            frames.checkpoint(key, state.get_data())

    These will be treated as regular frames, but will be deleted when
    a new checkpoint is created.

    Context
    =======
    Frames instances should generally be used as a context if writing to disk.  This
    will suspend immediate mode until the end of the context, improving performance.

    Key Conversions
    ===============
    To ensure safe comparisons between frame keys, we convert to and
    from an ikey with the methods `key_to_ikey()` and `ikey_to_key()`.
    These should perform appropriate manipulations like rounding so
    that equality comparison between keys is meaningful.  (Direct
    comparison of floating point values is dangerous since round-off
    error could cause a key failure between keys that are practically
    the same, but obtained by slightly different orders of operations.)

    key : This is what the user supplies as a key.  In the default
       implementation this is either a floating point number or a
       tuple of floating point numbers.
    ikey : This is the internal representation of the key, use for
       comparison, indexing, etc.  In the default implementation, we
       use the Decimal() class to truncate and provide an exact
       representation of the floating point number.

    Finally, the ikey needs to be converted to a string for use in the
    filenames.  These conversions are done by the ikey_to_str() and
    str_to_ikey() methods.  All four conversion methods should be
    redefined if a different type of key is used.

    The default implementation produces filenames such as::

        frame_0.1000_image_0.0500.npy

    which would represent a frame evolved after 0.1 time units and
    then imaged after an additional 0.05 time units of expansion.

    Attributes
    ----------
    data_dir : str
        The frames will be stored in this directory.
    mode : str containing 'r', 'w', 'm'
        Read ('r'), write ('w'), or in-memory ('m').  Data not written to disk unless
        'w' in mode.  Setting a frame without 'w' issues a warning that the data will
        not be stored on disk.  This can be suppressed by setting the mode to 'm' which
        indicates in-memory mode.  Setting `mode='m'` will not access the disk at all
        (even read access will be disabled).

        The default `mode=''` is 'w' in a context and 'r' outside a context.
    prefix : str
        Frame filenames start with this.
    sep : str
        When the key is iterable (as defined by `self.isiterable()`),
        then keys are joined by this separator.  The default is `"_image_"`.
    checkpoint_prefix : str
        Checkpoint filenames start with this.
    immediate : bool, None
        If `True`, then frames will be saved to disk upon assignment, otherwise
        they will be saved upon `flush()` (also called at the end of a context.)
        The default (`None`) is False if used in a context, but True
        otherwise.
    checkpoints_to_retain : int
        When saving a checkpoint, keep this many previous checkpoints
        and delete the rest.
    mem_limit_bytes : None (default) or int
        If not `None`, limit in bytes imposed by `self.limit_memory()` on the
        total size the frames may take up in memory.  If `self.immediate is
        None` (default), `limit_memory` is called upon exiting a context.  If
        `self.immediate`, `limit_memory` is automatically called whenever a
        frame is set or loaded.  Otherwise, `limit_memory` must be invoked
        manually.
    decimal_precision : int
        This is a specialized argument for the default version of the class.
        Internally, we use the Decimal class for keys, normalized to this
        precision.  If precision is lost (according to `np.allclose`), then a
        ValueError is raised.  When loading data from a file, this may be
        increased to prevent loss of precision.

    Notes
    -----
    * If `del` is called, the underlying file will be removed immediately, even
      if the immediate flag is False.
    * The list of `keys()` will only be updated at the start of a context, the
      first time it is called, or after `flush()`.
    * One can only set an item if it is not already existing.  (Call del first
      if needed.)
    """

    def __init__(
        self,
        data_dir,
        mode="",
        prefix="frame_",
        sep="_image_",
        checkpoint_prefix="check_",
        checkpoints_to_retain=1,
        immediate=None,
        mem_limit_bytes=None,
        decimal_precision=4,
    ):
        self.data_dir = data_dir
        self.prefix = prefix
        self.sep = sep
        self.checkpoint_prefix = checkpoint_prefix
        self.checkpoints_to_retain = checkpoints_to_retain
        self.mode = mode
        self.immediate = immediate
        self.mem_limit_bytes = mem_limit_bytes
        self.decimal_precision = decimal_precision
        self._context = False

        # Frames held in memory, keyed by ikey.  Insertion order matters:
        # `limit_memory()` drops the oldest entries first.
        self._data = {}

    @property
    def mode(self):
        """Return the mode, making context corrections for defaults."""
        mode = self._mode
        if mode == "":
            mode = "w" if self._context else "r"
        return mode

    @mode.setter
    def mode(self, mode):
        self._mode = mode.lower()

    def __enter__(self):
        """Enter the context."""
        if self._context:
            raise NotImplementedError("Nested contexts not supported.")
        self._context = True
        if set("wr").intersection(self.mode):
            # Make sure the directory is usable and refresh the list of files.
            self._get_data_dir()
            self.keys(update=True)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the context, flushing data to disk if 'w' is in the mode."""
        try:
            with NoInterrupt(ignore=False):
                if "w" in self.mode:
                    if hasattr(self, "_issue_2"):
                        # NOTE(review): looks like a testing hook that sends
                        # SIGINT to this process just before flushing - confirm.
                        import signal

                        os.kill(os.getpid(), signal.SIGINT)
                    self.flush()
                # NOTE(review): assumed to apply in every mode, not only when
                # writing - confirm against the original indentation.
                self.limit_memory()
        finally:
            self._context = False

    ######################################################################
    # Customizations
    #
    # These methods perform various conversions between key, ikey, and
    # the filename str.  Subclasses using keys that are not floats
    # should redefine all of these methods to be consistent.
    def isiterable(self, key):
        """Return `True` if the key represents a list or tuple of subkeys."""
        # Strings are iterable but represent a single key.
        return isinstance(key, abc.Iterable) and not isinstance(key, str)

    def Decimal(self, key, increase_precision=False):
        """Return `Decimal(key)` rounded to `self.decimal_precision`.

        Arguments
        ---------
        increase_precision : bool
            If `True`, then increase as needed `self.decimal_precision` to
            ensure accuracy of the keys, otherwise, raise `ValueError` if
            precision is lost.
        """
        # Convert numpy scalars to the corresponding builtin types first.
        if isinstance(key, np.integer):
            key = int(key)
        elif isinstance(key, np.floating):
            key = float(key)
        t = decimal.Decimal(key)
        if increase_precision:
            self.decimal_precision = max(self.decimal_precision, -t.as_tuple().exponent)
        else:
            t = t.quantize(decimal.Decimal(10) ** (-self.decimal_precision))
        if not np.allclose(float(t), float(key), rtol=1e-12, atol=1e-12):
            raise ValueError(
                "Precision lost when converting key: {}->{}".format(
                    repr(key), repr(t)
                )
            )
        return t

    def key_to_ikey(self, key):
        """Convert user-supplied key to internal format (tuple of Decimals)."""
        if self.isiterable(key):
            ikey = tuple(self.Decimal(_k) for _k in key)
        else:
            ikey = (self.Decimal(key),)
        return ikey

    def ikey_to_key(self, ikey):
        """Convert internal key format to user format."""
        if ikey is None:
            return ()
        elif len(ikey) == 1:
            return float(ikey[0])
        else:
            return tuple(map(float, ikey))

    def ikey_to_str(self, ikey):
        """Convert internal key to string for filename."""
        fmt = "{}".format
        return self.sep.join(map(fmt, ikey))

    def str_to_ikey(self, key_string):
        """Convert string to internal key, increasing `decimal_precision`
        if required.
        """
        ikey = tuple(
            self.Decimal(_k, increase_precision=True) for _k in key_string.split(self.sep)
        )
        return ikey

    # End of customizable methods
    ######################################################################

    def __contains__(self, key):
        return self.key_to_ikey(key) in self.ikeys()

    def ikeys(self, update=False):
        """Return a sorted list of available ikeys.

        Arguments
        ---------
        update : bool
           If True, then reset the list of previously loaded files,
           and reread data from disk.  If False, then only previously
           detected frames and computed frames will be seen - any
           additional frames (i.e. saved by another process) will not
           be detected.
        """
        if update and "_files" in self.__dict__:
            del self.__dict__["_files"]
        return sorted(set(self._files).union(self._data))

    def keys(self, update=False):
        """Return a list of available keys.

        Arguments
        ---------
        update : bool
           If True, then reset the list of previously loaded files,
           and reread data from disk.  If False, then only previously
           detected frames and computed frames will be seen - any
           additional frames (i.e. saved by another process) will not
           be detected.
        """
        return list(map(self.ikey_to_key, self.ikeys(update=update)))

    def __getitem__(self, key):
        """Return the frame for `key`, loading it from disk if needed."""
        ikey = self.key_to_ikey(key)
        if ikey not in self._data and ikey in self._files:
            filename = self._files[ikey]
            if not os.path.exists(filename):
                # Join with real newlines so the hint is printed on
                # separate lines.
                raise LookupError(
                    "\n".join(
                        [
                            "File {} for key={} has gone missing!",
                            "Do you need to run git annex?",
                            "",
                            "   git annex get {}",
                        ]
                    ).format(filename, key, os.path.dirname(filename))
                )
            value = np.load(filename)
            self._data[ikey] = value
        value = self._data[ikey]
        if self.immediate:
            self.limit_memory()
        return value

    def checkpoint(self, key, value):
        """Save data to file, but as a checkpoint."""
        self.__setitem__(key, value, checkpoint=True)

    def _is_checkpoint(self, ikey):
        """Return True if key is saved as a checkpoint."""
        return ikey in self._files and (
            os.path.basename(self._files[ikey]).startswith(self.checkpoint_prefix)
        )

    def _convert_checkpoint(self, ikey):
        """Convert a saved checkpoint to a real frame.

        Return True if the checkpoint file was converted.
        """
        if not self._is_checkpoint(ikey=ikey):
            return False

        filename = self._get_filename(ikey=ikey, data_dir=self.data_dir, checkpoint=False)
        os.rename(self._files[ikey], filename)
        self._files[ikey] = filename
        return True

    def __setitem__(self, key, value, checkpoint=False):
        """Store the frame `value` under `key`.

        Raises `LookupError` if the key already exists in memory or on disk
        (unless it exists as a checkpoint, which is converted to a frame).
        """
        ikey = self.key_to_ikey(key)
        mode = self.mode
        if not set("wm").intersection(mode):
            warnings.warn(f"Setting Frame[{key}] in {mode=} will not get saved to disk!")

        if self._convert_checkpoint(ikey=ikey):
            # Special case of key already being a checkpoint file:
            assert ikey in self._data and ikey in self._files
        elif ikey in self._data:
            raise LookupError(
                "Data for key={} already set. (Call del first.)".format(key)
            )
        elif ikey in self._files:
            raise LookupError(
                "File for key={} already exists. (Call del first.)".format(key)
            )
        else:
            self._data[ikey] = np.asarray(value)

            # NOTE(review): saving is assumed to apply only to newly stored
            # data (a converted checkpoint already has its file) - confirm
            # against the original indentation.
            # REV: Make sure all cases here are tested!
            immediate = self.immediate or (self.immediate is None and not self._context)
            if immediate or checkpoint:
                if "w" in mode:
                    self._save(ikey, value, checkpoint=checkpoint)
                self.limit_memory()

    def __delitem__(self, key, ikey=None):
        """Remove the frame from memory and delete its file (if any).

        If `ikey` is provided it is used directly and `key` is only used in
        the error message.
        """
        if ikey is None:
            ikey = self.key_to_ikey(key)
        if ikey not in self._files and ikey not in self._data:
            raise LookupError(key)

        filename = self._files.get(
            ikey, self._get_filename(ikey=ikey, data_dir=self.data_dir)
        )
        if filename and os.path.exists(filename):
            os.remove(filename)
        if ikey in self._files:
            del self._files[ikey]
        if ikey in self._data:
            del self._data[ikey]

    def flush(self):
        """Write current data to disk."""
        mode = self.mode
        if "w" not in self.mode:
            raise ValueError(f"Cannot flush data in {mode=}")

        data_dir = self._get_data_dir()
        self._remove_checkpoints()
        for ikey in self._data:
            if ikey in self._files and os.path.exists(self._files[ikey]):
                # Already on disk.
                continue
            self._save(ikey, value=self._data[ikey], data_dir=data_dir)
        self.keys(update=True)

    def get_mem_size(self):
        """Return the size of the frames stored in memory in bytes.

        Note that this only calls `sys.getsizeof()` on each value of
        `self._data`; the actual size of a Frames instance will be slightly
        bigger.
        """
        out = 0
        seen = set()
        for val in self._data.values():
            if id(val) in seen:
                # Don't double count frames that are the same object
                continue
            seen.add(id(val))
            out += sys.getsizeof(val)
        return out

    def limit_memory(self, update_files=True):
        """Drop the oldest frames from memory to fit under the memory limit.

        Arguments
        ---------
        update_files : bool
           If True, then refresh the list of files on disk before deciding
           whether a frame can be dropped.

        Raises
        ------
        MemoryError
           If the oldest frame is not saved on disk (so it cannot be dropped
           without losing data).
        """
        if self.mem_limit_bytes is None:
            return

        while self.get_mem_size() > self.mem_limit_bytes:
            if update_files and "_files" in self.__dict__:
                # Refresh list of files on disk
                del self.__dict__["_files"]

            # Since python 3.7, dict_keys are ordered according to when they
            # were added, so we just need the first key to get the oldest frame.
            oldest_ikey = next(iter(self._data.keys()))
            if oldest_ikey not in self._files:
                # Report the key in the form the user supplied it.
                key = self.ikey_to_key(oldest_ikey)
                raise MemoryError(
                    f"`mem_limit_bytes` (={self.mem_limit_bytes}) exceeded, and Frame[{key}] cannot be dropped because it is not saved on disk"
                )
            self._data.pop(oldest_ikey)

    ######################################################################
    # Private methods
    @property
    def _prefixes(self):
        """Return a tuple of the possible filename prefixes."""
        return (self.prefix, self.checkpoint_prefix)

    def _filename_to_ikey(self, filename):
        """Extract the time key from the filename.

        Returns `()` (with a warning) if the key part of the filename cannot
        be parsed.  Raises `ValueError` if the filename does not start with
        one of the known prefixes.
        """
        if not filename:
            return ()

        framename = os.path.basename(filename)
        key_string = None
        for prefix in self._prefixes:
            if framename.startswith(prefix):
                # Strip the prefix and the 4-character ".npy" extension.
                key_string = framename[len(prefix) : -4]
                break

        if key_string is None:
            raise ValueError(f"Unknown file prefix for frame {framename}")

        try:
            ikey = self.str_to_ikey(key_string=key_string)
        except (decimal.InvalidOperation, ValueError):
            ikey = ()
            warnings.warn(
                "Skipping malformed frame name {} (should be {}).".format(
                    framename, prefix + "###.####.npy"
                )
            )
        return ikey

    @property
    def _files(self):
        """Dictionary mapping ikeys to the filenames of the frames on disk.

        The dictionary is built from the files in the data directory on first
        access and cached in `self.__dict__["_files"]`; delete that entry to
        force a refresh.  When the dictionary is (re)built, any frame in
        memory whose key is not among the files found is deleted.
        """
        mode = self.mode
        if "m" in mode:
            return {}  # Only in-memory use, so no files.

        if "_files" not in self.__dict__:
            all_files = []
            for _prefix in self._prefixes:
                all_files.extend(
                    glob.glob(os.path.join(self._get_data_dir(), _prefix + "*.npy"))
                )
            self.__dict__["_files"] = {
                self._filename_to_ikey(filename=_f): _f for _f in all_files
            }
            for _ikey in list(self._data):
                if _ikey not in self._files:
                    del self[_ikey]
        return self.__dict__["_files"]

    def _get_data_dir(self):
        """Return `data_dir`, checking that it exists, is a directory, and making it if
        needed and 'w' in mode.
        """
        mode = self.mode
        if mode == "m":
            raise ValueError(f"No get data_dir in {mode=}.")

        data_dir = self.data_dir
        if not os.path.exists(data_dir):
            if "w" in mode:
                os.makedirs(data_dir)
            else:
                raise ValueError(
                    f"Cannot read ({mode=}) since directory {data_dir=} does not exist."
                )
        elif not os.path.isdir(data_dir):
            raise IOError(f"Specified {data_dir=} is not a directory.")
        return data_dir

    def get_filename(self, key, checkpoint=False):
        """Return the filename associated with key."""
        return self._get_filename(
            ikey=self.key_to_ikey(key=key),
            data_dir=self.data_dir,
            checkpoint=checkpoint,
        )

    def _get_filename(self, ikey, data_dir, checkpoint=False):
        """Return the filename associated with the internal key `ikey`."""
        if checkpoint:
            prefix = self.checkpoint_prefix
        else:
            prefix = self.prefix
        name = prefix + self.ikey_to_str(ikey=ikey)
        return os.path.join(data_dir, name + ".npy")

    def _save(self, ikey, value, data_dir=None, checkpoint=False):
        """Save `value` to disk and record the file under `ikey`.

        Raises `IOError` if the file already exists.
        """
        if data_dir is None:
            data_dir = self._get_data_dir()
        filename = self._get_filename(ikey=ikey, data_dir=data_dir, checkpoint=checkpoint)
        if os.path.exists(filename):
            raise IOError("File {} already exists!".format(filename))
        if checkpoint:
            # Remove old checkpoints, leaving room for the one about to be saved.
            self._remove_checkpoints(checkpoints_to_retain=self.checkpoints_to_retain - 1)
        np.save(filename, value)
        self._files[ikey] = filename

    def _remove_checkpoints(self, checkpoints_to_retain=None):
        """Remove old checkpoints in key order.

        The `checkpoints_to_retain` (default `self.checkpoints_to_retain`)
        checkpoints with the largest keys are kept.
        """
        if checkpoints_to_retain is None:
            checkpoints_to_retain = self.checkpoints_to_retain
        checkpoints = sorted([ikey for ikey in self._files if self._is_checkpoint(ikey)])

        # Clamp at zero: a negative count would be interpreted by the slice as
        # an offset from the end and delete checkpoints that should be kept.
        n_remove = max(0, len(checkpoints) - checkpoints_to_retain)
        for ikey in checkpoints[:n_remove]:
            self.__delitem__(key=None, ikey=ikey)
@implementer(IExperiment)
class ExperimentBase(Archivable):
"""Base for Experiment classes.
Inherit from this class to provide an interface to the problem. Its main role
should be to accept experimentally relevant parameters, and then produce an
appropriate initial state that represents the experimental protocol.
See `gpe.utils.ExperimentExample` for a demonstration of how to use
this class.
Note: All times are expressed in terms of `t_unit` such that `t_ =
t/t_unit`. We try to consistently use the name `t_` for such dimensionless
quantities except in class variables which are assumed to be in the
specified `t_unit`.
1. Also provides a mechanism for recording time-dependent parameters.
These should be provided through a set of methods names `*_t_()` which
take the normalized time `t_` as an argument and return the
time-dependent value of this parameter. For plotting purposes an
accompanying method `*_info()` should be defined which returns the
corresponding unit value and a label.
2. Internal methods should use the `get(param, t_)` method which will
delegate to the appropriate time-dependent function if it exists (or
fall back to the basic parameter access).
Simulations and Imaging
=======================
The idea of an "Experiment" is some sort of simulation run defined by a set
of parameters (attributes of this class) that is evolved through a set of
`image_ts_` under a set of "normal" experimental conditions. These states would be
what is observed in "in situ imaging". Typically, however, from these states, one
evolves for an additional time `t__image` without any interactions or traps to allow
the clouds to "expand", after which an "expansion image" is taken, usually resolving
better details like vortices and domain walls.
"""
State = NotImplemented # State class used
t_unit = 1.0 # Conversion factor for time units
t_name = "" # Optional name
image_ts_ = () # Times at which experiment starts imaging.
t__image = 0 # Length of expansion imaging.
_initializing = False
@property
def t__final(self):
"""Return `t__final = t_final/t_unit`."""
if not self._initializing:
# We do some inspection when initializing, so don't warn then.
warnings.warn(
"t__final is deprecated: please use image_ts_ or state.t_final",
DeprecationWarning,
stacklevel=2,
)
if len(self.image_ts_) > 0:
return max(self.image_ts_)
return None
@t__final.setter
def t__final(self, t__final):
warnings.warn(
"experiment.t__final is deprecated: please use state.t_final",
DeprecationWarning,
stacklevel=2,
)
self.image_ts_ = tuple(sorted(set(self.image_ts_).union({t__final})))
# These control dir_name.
max_key_length = 3 # Limit key lengths (if not directory_per_key)
sort_keys = True # Sort keys for unique dir_name (issue #4).
# only set False to emulate old behavior.
key_order = () # If provided, these keys come first.
directory_per_key = True # Use a directory for each key (shorter filenames)
# For debugging purposes we keep track of which keys are used so that
# spelling mistakes can be checked for (see `_unused_keys()`). Here we
# pre-populate this set with some keys that are only used on saving etc.
_used_keys = ["State", "t_name", "t_unit"]
# We also have a set of special keys that should not be used in
# dir_name.
_special_keys = set(
[
"image_ts_",
"t__image",
"max_key_length",
"directory_per_key",
"sort_keys",
"key_order",
]
)
_deprecated_keys = set(["t__final"])
# Deprecated keys are only allowed to be defined in the class with this
# full name (currently `'gpe.utils.ExperimentBase'`)
_base_class_name = __module__ + "." + __qualname__
def __init__(self, _local_dict=None, **kw):
"""Constructor.
The constructor takes a dictionary of the local variables passed
to the subclass. These will be assigned to variables of the same name in
`__dict__` and stored to generate the `dir_name` where data will be
stored. Names starting with `_` will be ignored.
The `kw` argument is provided to allow subclasses to pass in additional
parameters.
Note: We generally recommend that users DO NOT overload the
constructor for several reasons. Although one can use this approach to
define additional parameters, this use is generally discouraged for the
following reasons:
1. These will override any parameters defined at the class level -
EVEN IF OVERRIDDEN IN SUBCLASSES. If you want subclasses to
override these parameters, they MUST do so in their `__init__()`
method.
2. These parameters will ALWAYS be included in the directory name,
even if the default values are used. This behavior can be achieved
anyway by simply passing the default value to the default
constructor.
If you proceed with this approach, use the following model::
def __init__(self, a=0.0, **_kw):
super().__init__(locals(), **_kw)
Make sure you pass through `_kw` with an underscore so it does not get
set as an attribute. The following approach of using `locals()` allows
you to forgo repeating the keyword arguments you specify in the signature.
Do not do any initialization here - only set parameters. Initialization
should be performed in the `init()` method which will be called by
`ExperimentBase.__init__()`.
"""
self._initializing = True
# This is a list of all parameters known to the class. For debugging
# purposes, we do not allow the user to set additional parameters in
# the constrctor, which are usually "spelling" mistakes (see issue 1).
# This is a little different from self._keys which are all the "active"
# keys that are used in the file name etc. Only active keys can be set
# after construction.
_used_keys = set()
for kls in inspect.getmro(self.__class__):
_used_keys.update(getattr(kls, "_used_keys", []))
self._known_keys = sorted(
[
_key
for _key, _value in inspect.getmembers(
self, lambda a: not (inspect.isroutine(a))
)
if not _key.startswith("_") and _key not in self._deprecated_keys
]
)
# Set this *after* doing the inspect stuff because it uses all the keys!
self._used_keys = _used_keys
if _local_dict is None:
args = dict(kw)
else:
args = dict(_local_dict, **kw)
args = {_k: args[_k] for _k in args if (not _k.startswith("_") and _k != "self")}
keys = set(args).union(self._special_keys)
# Check if any child classes re-define a deprecated key
family = inspect.getmro(self.__class__)
family_names = [kls.__module__ + "." + kls.__qualname__ for kls in family]
base_class = family[family_names.index(self._base_class_name)]
children = [c for c in family if issubclass(c, base_class) and c != base_class]
for child in children:
redefined_deprecated_keys = set(child.__dict__).intersection(
self._deprecated_keys
)
if redefined_deprecated_keys:
warnings.warn(
("{} redefines deprecated key(s) {}.").format(
child.__name__,
sorted(redefined_deprecated_keys),
),
DeprecationWarning,
stacklevel=2,
)
# Check if any arguments are deprecated
deprecated_keys = set(args).intersection(self._deprecated_keys)
if deprecated_keys:
warnings.warn(
("{} got deprecated keyword argument(s) {}.").format(
self.__class__.__name__,
sorted(deprecated_keys),
),
DeprecationWarning,
stacklevel=2,
)
# Check that all arguments are indeed attributes
unknown_keys = set(args).difference(
set(self._known_keys).union(self._deprecated_keys)
)
if unknown_keys:
raise ValueError(
(
"{} got unexpected keyword argument(s) {}.\n" + "(Known keys: {})"
).format(self.__class__.__name__, sorted(unknown_keys), self._known_keys)
)
self._keys = sorted(keys)
for _k in args:
setattr(self, _k, args[_k])
# Placing this before setting _initializing=False allows
# attributes to be set by self.init()
self.init()
self._initializing = False
def _unused_keys(self):
    """Return the set of known keys (parameters) that were never used.

    This is a debugging aid for spotting misspelled parameters: a key is
    reported when it is in `_known_keys` but in neither `_used_keys` nor
    `_special_keys`.
    """
    accounted_for = set.union(self._used_keys, self._special_keys)
    return {key for key in self._known_keys if key not in accounted_for}
def items(self):
    """Return `(key, value)` pairs for every active key (for Archivable)."""
    pairs = []
    for name in self._keys:
        pairs.append((name, getattr(self, name)))
    return pairs
def get_diff(self, experiment):
    """Return a sorted list of differing keys (excludes `_special_keys`).

    Arguments
    ---------
    experiment : ExperimentBase
        Experiment to compare with.  It must have the same known keys as
        `self` (this is asserted).

    Returns
    -------
    keys : list(str)
        Active keys of either experiment whose values do not compare equal.
        The list is sorted so that the result (and any message built from
        it) is reproducible rather than depending on set iteration order.
    """
    if experiment is self:
        return []
    assert set(self._known_keys) == set(experiment._known_keys)
    candidates = (
        set(self._keys).union(experiment._keys).difference(self._special_keys)
    )
    return [
        key
        for key in sorted(candidates)
        if not getattr(self, key) == getattr(experiment, key)
    ]
def init(self):
    """Overload this to perform any initial computations.

    This base version collects every callable member whose name ends in
    `_t_` into `self.time_dependent_parameters` (keyed by the name without
    that suffix) and, for each such parameter `name`, sets a default
    `name_info = (1.0, name)` attribute if none exists.
    """
    # inspect.getmembers() touches every attribute, which would mark them
    # all as used, so snapshot `_used_keys` first and restore it after.
    snapshot = set(self._used_keys)
    callables = inspect.getmembers(self, predicate=callable)
    self._used_keys = snapshot

    suffix = "_t_"
    time_dependent = {}
    for member_name, method in callables:
        if member_name.endswith(suffix):
            time_dependent[member_name[: -len(suffix)]] = method
    self.time_dependent_parameters = time_dependent

    for name in self.time_dependent_parameters:
        info_name = name + "_info"
        if hasattr(self, info_name):
            continue
        setattr(self, info_name, (1.0, name))
def __getattribute__(self, key):
    """Return attribute `key`, recording public non-routine keys as used.

    The record is kept in the instance's `_used_keys` set (created on
    demand) and is consumed by `_unused_keys()`.
    """
    val = super(ExperimentBase, self).__getattribute__(key)
    if key.startswith("_") or inspect.isroutine(val):
        # Private names and methods are not parameters: do not track them.
        return val
    self.__dict__.setdefault("_used_keys", set()).add(key)
    return val
def __setattr__(self, key, value):
    """Prohibit setting of non-key attributes.

    Setting is allowed for private names (leading underscore), for any name
    while `self._initializing` is true, and for active keys in `self._keys`.
    Anything else raises `AttributeError`.
    """
    permitted = key.startswith("_") or self._initializing or key in self._keys
    if not permitted:
        raise AttributeError("Cannot set {}: only {}".format(key, self._keys))
    self.__dict__[key] = value
@property
def dir_name(self):
    """Return the name of the directory in which to store checkpoint data.

    This name should encode all of the parameter values.  The default
    version here uses a nested structure with the class name and then the
    parameters: `<Experiment>/<param1=...>/<param2=...>/...` if
    `directory_per_key` is true, otherwise
    `<Experiment>/<param1=...,param2=...,...>`.

    To ensure that the directories are unique, keys listed in `key_order`
    come first (in that order) and the remaining keys are sorted
    alphabetically (if `sort_keys` is true).  Special keys are never
    included.

    When `directory_per_key` is false, keys longer than `max_key_length`
    are abbreviated.  If an abbreviation collides with one already in use,
    its last character is replaced by a digit or letter.

    Raises
    ------
    ValueError
        If no unique abbreviation can be found for some parameter.
    """
    keys = set(self._keys).difference(self._special_keys)
    ordered_keys = [_key for _key in self.key_order if _key in keys]
    remaining = keys.difference(ordered_keys)
    if self.sort_keys:
        remaining = sorted(remaining)
    ordered_keys.extend(remaining)

    if self.directory_per_key:
        short_keys = ordered_keys
    else:
        # Was "0123456780": the "9" was missing and "0" was repeated.
        suffixes = (
            "0123456789"
            + "abcdefghijklmnopqrstuvwxyz"
            + "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        )
        short_keys = []
        for key in ordered_keys:
            k = key
            if len(k) > self.max_key_length:
                k = k[: self.max_key_length].strip(" -_")
            if k in short_keys:
                stem = k[:-1]
                # NOTE(review): there is no `break` here, so the LAST free
                # suffix wins rather than the first.  This looks unintended,
                # but changing it would rename existing data directories, so
                # the behavior is kept.
                for _s in suffixes:
                    if stem + _s not in short_keys:
                        k = stem + _s
            if k in short_keys:
                # Report the full parameter name, not its abbreviation.
                raise ValueError(
                    "Could not find unique {}-letter name for parameter {}.".format(
                        self.max_key_length, key
                    )
                    + " (Collides with {})".format(short_keys)
                )
            short_keys.append(k)

    key_pairs = [
        "{}={}".format(_k, getattr(self, _key)).replace(" ", "")  # No inner spaces.
        for _k, _key in zip(short_keys, ordered_keys)
    ]
    dirs = [self.__class__.__name__]
    if self.directory_per_key:
        dirs.extend(key_pairs)
    else:
        dirs.append(",".join(key_pairs))
    return os.path.join(*dirs)
def copy(self):
    """Return a copy of the experiment.

    The copy is a new instance of the same class constructed from the
    pairs returned by `self.items()`.
    """
    params = dict(self.items())
    return self.__class__(**params)
######################################################################
# Parameter access
def get(self, name, t_):
    """Return the value of the parameter `name` at time `t_`.

    If `name` is registered in `self.time_dependent_parameters` (i.e. a
    method `name_t_()` exists), that method is called with `t_`; otherwise
    the plain attribute `name` is returned.
    """
    time_dependent = self.time_dependent_parameters.get(name)
    if time_dependent is not None:
        return time_dependent(t_)
    return getattr(self, name)
@implementer(IStateForABMEvolvers)
class StateExample(StateMixin, ArrayStateMixin):
    """Minimal State class that implements the ODE `dy_dt = a*y + b`.

    If an `experiment` is provided, then an "imaging procedure" is simulated:
    the offset `b` is used only while `t <= t_final` and is set to zero
    afterwards.  Without an experiment, `b` is used at all times.
    """

    @property
    def t_scale(self):
        """Return the natural time-scale for evolvers.  This is used by
        simulations to determine the evolver timestep.  For quantum problems,
        this is usually something like `hbar/E_max` where `E_max` is the
        largest energy-scale appearing (for example,
        `E_max = hbar**2*k_max**2/2/m` where `k_max = np.pi * Nx/Lx` is the
        largest momentum representable in the system).

        Here it is simply `1/a`.
        """
        return 1.0 / self.a

    def __init__(self, a, b, experiment=None):
        """Store the coefficients and initialize the data to `y = 1`."""
        self.a = a
        self.b = b
        self.experiment = experiment
        self.data = np.ones(1, dtype=float)

    def compute_dy_dt(self, dy):
        """Set `dy` to `a*y + b` and return it.

        Previously `b` was only assigned when an experiment was present, so
        states without an experiment raised `UnboundLocalError` here.  Now
        `b = self.b` is used in that case, consistent with `answer()`.
        """
        if self.experiment is None or self.t <= self.t_final:
            b = self.b
        else:
            # Simulates an "imaging procedure": after t_final we set b=0.
            b = 0.0
        dy.set_data(self.a * self.get_data() + b)
        return dy

    def answer(self, t=None):
        """Return the analytic solution starting from `y(0) = 1`.

        Arguments
        ---------
        t : float, None
            Time at which to evaluate the solution (default `self.t`).
        """
        if t is None:
            t = self.t
        if self.experiment is not None:
            t0 = min(t, self.t_final)
        else:
            t0 = t

        # With offset b up to time t0:
        #   y(t0) = (y(0) + b/a)*exp(a*t0) - b/a
        a, b = self.a, self.b
        y0 = 1.0
        y0 = (y0 + b / a) * np.exp(a * (t0 - 0.0)) - b / a
        # From t0 on, b = 0 so the solution is a pure exponential.
        y = y0 * np.exp(a * (t - t0))
        return y

    def plot(self, fig):
        """Plot the current value as a point at time `self.t`; return `fig`."""
        from matplotlib import pyplot as plt

        plt.plot(self.t, self[0], "o")
        return fig
class ExperimentExample(ExperimentBase):
    """Skeleton class demonstrating how to use ExperimentBase.

    Fixed experimental parameters that won't change can be specified as
    class members.  Specifying the parameters here is recommended and has
    several advantages over defining an `__init__()` method:

    1. They will not be part of the filename unless manually overridden by
       the user and passed to the constructor.
    2. By inheriting from this class, the default parameter values can be
       overridden, essentially encapsulated in the new class name.
    3. If all parameters are specified as class variables, there is no need
       to write the constructor `__init__()`.  `ExperimentBase.__init__()`
       will allow the user to override parameters if needed.
    4. If directory names are desired for these parameters, they can be
       explicitly passed to the default constructor and then the
       appropriate directories will be created.

    Note: If you need to perform any calculations for initialization, do
    this in the `init()` method which will be called by the default
    constructor `ExperimentBase.__init__()`.
    """

    ######################################################################
    # Attributes required by IExperiment
    t_unit = 1.2  # Specifies units for experiment times.
    image_ts_ = (10.0,)  # Times for imaging (in units of t_unit).
    t__image = 1.5  # Imaging for 1.5 units
    # End of attributes required by IExperiment
    ######################################################################

    # Experimental parameters
    amplitude = 1.2
    offset = 3.4

    # A debugging feature keeps track of the keys (parameters) used during
    # execution to make sure that parameters are not misspelled (see issue
    # 1).  A parameter that is only used conditionally can be excluded
    # from these checks by listing it in the class's _used_keys.
    _used_keys = []

    ######################################################################
    # Methods required by IExperiment
    def init(self):
        """Initializes the experiment.

        All initialization should be done here, not in a constructor.
        """
        # For example, one might translate the experimental parameters into
        # appropriate units, or perform conversions so that the State
        # object can be coded in the most natural set of units while the
        # experiment is specified in the most physically relevant units.

        # A state_args dict lets subclasses share `get_state()` and
        # `get_initial_state()` while adding additional parameters.
        # However, overloading `get_state()` might be simpler.
        self.state_args = {"a": self.amplitude, "b": self.offset}

        # Don't forget to call the parent init().  Call it first if you
        # need the results of the parent's init(), or afterwards (as here)
        # if you need to modify parameters before it runs.
        # `ExperimentBase.init()` collects the `*_t_` methods into
        # `time_dependent_parameters`.
        super().init()

    def get_state(self, initialize=True):
        """Quickly return a valid `State` object."""
        return StateExample(experiment=self, **self.state_args)

    def get_initial_state(self, _E_tol=1e-12, _psi_tol=1e-12):
        """Return the valid `t=0` state to initialize the simulations."""
        # It might be convenient while debugging to play with minimization
        # parameters here, but when running and checkpointing, the
        # parameters that work should be fixed here and this should be
        # called only without arguments.
        #
        # Here is how you might go about preparing the initial state.
        # Note: rely on `get_state()` to set all parameters - just set the
        # data after minimizing.  Do not use the state returned by the
        # minimizer.
        #
        # from minimize import MinimizeState
        # minimizer = MinimizeState(state, fix_N=True)
        # psi0 = minimizer.minimize(E_tol=_E_tol, psi_tol=_psi_tol).get_psi()
        # state = self.get_state()
        # state.set_psi(psi0)
        return self.get_state()

    def get_initialized_state(self, state):
        """Return a valid state initialized from `state`.

        This is used in chained simulations where a specified state of one
        simulation is used to initialize a state for further use.  For
        example, for expansion.
        """
        new_state = self.get_state()
        new_state.set_psi(state.get_psi())
        return new_state

    # End of methods required by IExperiment
    ######################################################################
class SimulationBase:
"""Manages a simulation, including checkpointing and restarts.
A simulation is a collection of runs associated with an experiment, together
with a collection of frames corresponding to different parameters. These could
be different times during a long-running simulation, or different parameter values
for a scan.
Note: this class currently defines no attributes or methods of its own.
<incomplete>
"""
class Simulation:
"""Manages a simulation, including checkpointing and restarts.
This class requires a State object which can be evolved forward (and
backward) in time. It will then initialize the state and evolve to the
desired time with the checkpoint requirements.
Persistence is managed through checkpoints on disk in the directory
`self.dir_name == os.path.join(self.data_dir, self.experiment.dir_name)`.
This directory will contain the following files:
<dir_name>/experiment.py
<dir_name>/run_*.txt
The first is an archive of the Experiment class which can be used to initialize
the experiment. The second are logs produced when `run()` is called.
The directory may also contain data for the checkpoints. The format
is up to the Experiment class, but will generally be of the form:
<dir_name>/frame_*.npy
The `self.frames` object is an instance of `Frames` with keys that map to
these frames. They can be length one or two tuples. For example:
ikey = (Decimal('0.2000'),) -> frame_0.2000.npy
ikey = (Decimal('0.2000'), Decimal('1.000')) -> frame_0.2000_image_1.000.npy
See `Frames` for a discussion of the distinction between `key` and `ikey`
(internal key).
These simulations will be run, and checkpoints made at a set of
times (controlled by the Simulation class). These checkpoints are
stored as data frames in a directory specified by the name of the
experiment and the parameter values that are different from the
default values. Additionally, one might make images of the state
at various times. When evolving, this is controlled by the
`t__image` parameter which is special (described below).
For example, consider a class `Experiment1` with parameter `a` with different
values `0` and `1.1`, specifying that imaging should take place at
`image_ts_ = (0.2, ...)` after expanding for `t__image = 0.05`, we might run a
`Simulation` with `dt_ = 0.1` ending up with following file-structure:
Experiment1/a=0/frame_0.1000.npy
/frame_0.2000.npy
/frame_0.2000_image_0.0500.npy
/frame_0.3000.npy
...
Experiment1/a=1.1/frame_0.1000.npy
/frame_0.2000.npy
/frame_0.2000_image_0.0500.npy
/frame_0.3000.npy
...
These data are generated in a "run" which is a sequence of times that a state is run
through. Typically we start with a single run which saves intermediate checkpoints
and checkpoints at all the times `experiment.image_ts_` where we will later perform
imaging. Then, a series of runs are performed from these saved `experiment.image_ts_`
up to time `t__final + t__image`. By setting `state.t_final`, the state can change
the potentials to perform imaging for times `t > state.t_final`.
Parameters
----------
experiment : Experiment, None
Experiment instance. Can be None if `dir_name` is specified and
contains an executable `experiment.py` file.
dir_name : str, None
Location of data. Can be `None` if an `experiment` is provided.
dt_ : float
Time interval between frames (in units of `experiment.t_unit`)
dt__image : float, None
Time interval between frames (in units of `experiment.t_unit`) for
imaging (default is to use dt_)
checkpoint_dt_ : float, None
Checkpoint interval (in units of `experiment.t_unit`). If `None`, then `dt_` is
used. During a simulation, checkpoints are made with a minimum step of this
interval -- the actual `dt_` might be larger depending on the underlying step
size (see `dt_t_scale`). When a new checkpoint is made, previous checkpoints are
removed according to the value of `checkpoints_to_retain`.
checkpoints_to_retain : int
Number of checkpoints to retain.
max_t_ : float
Maximum time for simulation (in units of `experiment.t_unit`). Can be `None` if
the experiment lists times `experiment.image_ts_`. If both are provided, then this
flag will be used to limit execution (but a warning will be emitted). If set,
this will also set the upper limit.
extends : (Simulation, t_), None
If provided, then this simulation will call
`self.experiment.get_initialized_state(simulation.get_state(t))` in
order to extend the previous simulation. This is useful when expanding
for example.
data_dir : str
Where to store the results.
checkpoint : bool
If `True`, then checkpoint the results to disk (along with information
about the run as returned by `record_computer_state()`).
image_ts_ : [float]
List of times to perform expansion imaging at. After the simulation is
run to these times, a separate simulation is started which will do the
expansion for imaging. Deprecated - use Experiment.image_ts_
The expansion is only done when the
run_images()
allow_negative_dt : bool
If True, then allow the evolver to evolve backwards from the
nearest state when computing.
mem_limit_bytes : None (default) or int
If not `None`, the maximum space in bytes that data stored in
`self.frames` may take up in memory.
The following parameters are provided in case they are not
specified by the experiment, but the experiment-defined values
will be preferred:
dt_t_scale : float
Timestep to use for evolver (in units of `state.t_scale`) if
not provided by the attribute `experiment.dt_t_scale`.
Evolver : IEvolver
Which time evolver to use if not provided as an attribute
`experiment.Evolver`.
"""
frame_prefix = "frame_"
def __init__(
    self,
    experiment=None,
    dir_name=None,
    dt_=1,
    dt__image=None,
    max_t_=None,
    extends=None,
    dt_t_scale=0.1,
    data_dir=_DATA_DIR,
    Evolver=EvolverABM,
    logging_level=logging.INFO,
    checkpoint=True,
    checkpoint_dt_=None,
    checkpoints_to_retain=1,
    image_ts_=None,
    allow_negative_dt=False,
    mem_limit_bytes=None,
):
    """Store the simulation parameters (see the class docstring).

    If `experiment` is `None`, it is loaded by executing
    `<dir_name>/experiment.py`, which must define a variable `experiment`.
    `logging_level` sets the level of the logger created here, which writes
    to a `logging.StreamHandler`.

    Raises
    ------
    ValueError
        If neither `experiment` nor `dir_name` is provided, or if
        `<dir_name>/experiment.py` does not exist.
    AssertionError
        If `dir_name` is inconsistent with the loaded experiment.
    """
    if experiment is None:
        if dir_name is None:
            raise ValueError("Must provide either an experiment or a dir_name")

        experiment_file = os.path.join(dir_name, "experiment.py")
        if not os.path.exists(experiment_file):
            raise ValueError(
                "experiment not provided and couldn't be loaded from {}.".format(
                    experiment_file
                )
            )
        # NOTE(security): this executes arbitrary code from the file.  Only
        # use `dir_name` with data directories you trust.
        d = {}
        with open(experiment_file) as f:
            exec(f.read(), d)
        experiment = d["experiment"]
        assert dir_name == os.path.join(data_dir, experiment.dir_name)

    self._experiment = experiment
    self._extends = extends
    self._dir_name = dir_name
    self.allow_negative_dt = allow_negative_dt
    self.data_dir = data_dir
    self.max_t_ = max_t_
    self.dt_ = dt_
    self.dt__image = dt__image
    self.checkpoint = checkpoint
    self.checkpoint_dt_ = checkpoint_dt_
    self.checkpoints_to_retain = checkpoints_to_retain

    self.logger = logging.Logger(self.experiment.__class__.__name__)
    self.logger.setLevel(logging_level)
    self.logger.addHandler(logging.StreamHandler())

    self.evolve_times = []
    self.mem_limit_bytes = mem_limit_bytes
    if image_ts_ is not None:
        warnings.warn(
            "image_ts_ is deprecated: please use experiment.image_ts_",
            DeprecationWarning,
            stacklevel=2,
        )
    self.image_ts_ = image_ts_

    # The Frames object is created lazily by `get_frames()`.
    self._frames = None
    # Fallbacks used only if the experiment does not define these itself.
    self._Evolver = Evolver
    self._dt_t_scale = dt_t_scale
@property
def experiment(self):
    """The experiment; `initialize()` is called first if it is not yet set."""
    not_loaded = self._experiment is None
    if not_loaded:
        self.initialize()
    return self._experiment
@property
def dir_name(self):
    """The data directory; computed by `initialize(create_dir=False)` if unset."""
    unknown = self._dir_name is None
    if unknown:
        self.initialize(create_dir=False)
    return self._dir_name
def info(self, msg):
    """Log `msg` at the INFO level using `self.logger`."""
    log = self.logger
    log.info(msg)
def warning(self, msg):
    """Log `msg` at the WARNING level using `self.logger`."""
    log = self.logger
    log.warning(msg)
def error(self, msg):
    """Log `msg` at the ERROR level using `self.logger`."""
    log = self.logger
    log.error(msg)
@property
def Evolver(self):
    """Evolver class: `experiment.Evolver` if defined, else the constructor value."""
    experiment = self.experiment
    fallback = self._Evolver
    return getattr(experiment, "Evolver", fallback)
@property
def dt_t_scale(self):
    """Timestep scale: `experiment.dt_t_scale` if defined, else the constructor value."""
    experiment = self.experiment
    fallback = self._dt_t_scale
    return getattr(experiment, "dt_t_scale", fallback)
@contextlib.contextmanager
def msg(self, msg):
    """Context manager that logs the start, outcome and duration of a task.

    Logs `"<msg>..."` on entry.  If the body raises an `Exception`,
    `"<msg>. Failed!"` is logged as an error and the exception is
    re-raised; otherwise `"<msg>. Done. (<elapsed>s)"` is logged.
    """
    self.info(msg + "...")
    started = time.time()
    try:
        yield
    except Exception:
        self.error(msg + ". Failed!")
        raise
    else:
        elapsed = time.time() - started
        self.info(msg + ". Done. ({:.2g}s)".format(elapsed))
def initialize(self, create_dir=True):
"""Initialize the simulation object, and perform sanity checks.
Arguments
---------
create_dir : bool
Data is only written to disk if both this and `self.checkpoint` are true.
In that case the simulation directory is created if needed, the experiment
is archived to `experiment.py` (if that file does not exist), and a test
frame is saved, re-loaded, and deleted.
Raises `ValueError` if there is no experiment and no `experiment.py`, or if
the archived experiment differs from `self._experiment`.
"""
write_to_disk = create_dir and self.checkpoint
# Check for existing data or make sure dir_name exists.
if self._experiment is not None:
self._dir_name = os.path.join(self.data_dir, self.experiment.dir_name)
if os.path.exists(self.dir_name):
self.warning("Existing simulation directory found: {}".format(self.dir_name))
with self.get_frames(tmp=not write_to_disk) as frames:
self.info("The following frames exist: {}".format(frames.ikeys()))
elif write_to_disk:
with self.msg("Creating simulation directory {}".format(self.dir_name)):
os.makedirs(self.dir_name)
experiment_file = os.path.join(self.dir_name, "experiment.py")
if not os.path.exists(experiment_file):
if self._experiment is None:
raise ValueError(
"No experiment provided or experiment_file={}".format(experiment_file)
)
if write_to_disk:
# Archive the experiment object
with self.msg("Archiving experiment.py"):
archive = Archive(scoped=False)
archive.insert(experiment=self.experiment)
with open(experiment_file, "w") as file:
file.write(str(archive))
# Now load the experiment object to make sure that the archival was
# successful
if os.path.exists(experiment_file):
with self.msg("Loading experiment from experiment.py"):
# NOTE(review): this executes the contents of experiment.py; only use
# with trusted data directories.
d = {}
with open(experiment_file) as f:
exec(f.read(), d)
experiment = d["experiment"]
assert self.dir_name == os.path.join(self.data_dir, experiment.dir_name)
if self._experiment is None:
self._experiment = experiment
else:
# Check that experiments are consistent
different_keys = experiment.get_diff(self._experiment)
if different_keys:
raise ValueError(
f"{experiment_file=} inconsistent with self._experiment: "
+ f"{different_keys=}"
)
else:
# This should only happen if not write_to_disk
assert not write_to_disk and self._experiment is not None
experiment = self._experiment
# Check that the state can be created and saved
with self.msg("Creating test state"):
state = self.experiment.get_state(initialize=False)
if write_to_disk:
# Use a process-specific prefix for the temporary test frame.
tmp_prefix = "_tmp" + str(os.getpid()) + self.frame_prefix
args = dict(
data_dir=self.dir_name,
prefix=tmp_prefix,
checkpoints_to_retain=self.checkpoints_to_retain,
)
with self.msg("Saving test state"):
with Frames(mode="w", **args) as frames:
frames[0.0] = state.get_data()
statinfo = os.stat(frames.get_filename(0.0))
frame_size = statinfo.st_size
with self.msg("Loading test state"):
with Frames(**args) as frames:
assert np.allclose(state.get_data(), frames[0.0])
del frames[0.0]
# Size on disk of a single frame, in MB (only set when writing to disk).
self.frame_size_MB = frame_size / 1024.0**2
self.info("Frame size = {}".format(mem_str(frame_size)))
# Compute number of frames to store and file size
ts_ = self.get_ts_()
Nframes = len(ts_)
if write_to_disk:
self.info(
"All {} frames will take {} of space on disk.".format(
Nframes, mem_str(Nframes * frame_size)
)
)
def get_hostname(self):
    """Return the output of the `hostname` command as bytes.

    Returns `b"Unknown"` if the command cannot be run.
    """
    try:
        hostname = subprocess.check_output(["hostname"])
    except Exception:
        hostname = b"Unknown"
    return hostname
def get_cpuinfo(self):
    """Return info about the CPU as bytes.

    Several commands are tried in turn and the output of the first one
    that succeeds is returned.  If none works, `b"Unknown"` is returned.
    """
    commands = (
        # For Linux
        ["lsinfo"],
        ["cpuinfo"],
        # Standard util-linux tool.  Tried after the historical commands so
        # that results are unchanged wherever one of those already worked.
        ["lscpu"],
        # For Mac OS X
        ["sysctl", "-n", "machdep.cpu.brand_string"],
    )
    for command in commands:
        try:
            return subprocess.check_output(command)
        except Exception:
            # Best effort: command missing or failed, so try the next one.
            continue
    return b"Unknown"
def get_hginfo(self):
    """Return mercurial info about the repo as bytes.

    This is the output of `hg summary` followed by that of `hg status`, or
    `b"Unknown"` if either command fails.
    """
    try:
        summary = subprocess.check_output(["hg", "summary"])
        status = subprocess.check_output(["hg", "status"])
    except Exception:
        return b"Unknown"
    return summary + status
def get_gitinfo(self):
    """Return info about a git repository, if one exists, as bytes.

    Combines the output of `git show -s` and `git status` under headings.
    Returns `b"Unknown"` if either command fails.
    """
    try:
        head = subprocess.check_output(["git", "show", "-s"])
        status = subprocess.check_output(["git", "status"])
    except Exception:
        return b"Unknown"
    pieces = [
        b"HEAD (`git show -s`):\n",
        head,
        b"\nStatus (`git status`):\n",
        status,
    ]
    return b"".join(pieces)
def get_jjinfo(self):
    """Return info about a JJ (Jujutsu) repository, if one exists, as bytes.

    Combines the output of `jj show -s` (working copy) and `jj show -s @-`
    (parent) under headings.  Returns `b"Unknown"` if either command fails.
    """
    try:
        working_copy = subprocess.check_output(["jj", "show", "-s"])
        parent = subprocess.check_output(["jj", "show", "-s", "@-"])
    except Exception:
        return b"Unknown"
    pieces = [
        b"Working copy (`jj show -s`):\n",
        working_copy,
        b"\nParent (`jj show -s @-`):\n",
        parent,
    ]
    return b"".join(pieces)
def get_time(self):
    """Return the current local time as `YYYYMMDD-HHMMSS`.

    The string can be used in a filename.
    """
    stamp_format = "%Y%m%d-%H%M%S"
    return time.strftime(stamp_format)
def get_computer_state(self, timestr=None):
    """Return a report on the computer, repository, etc. before a run.

    Arguments
    ---------
    timestr : str, None
        Time-stamp for the header line (default `self.get_time()`).

    Returns
    -------
    report : str
        Header lines (time, `sys.argv`, hostname) followed by underlined
        sections with the CPU, mercurial, git and jj information.
    """
    hostname = self.get_hostname().decode()
    sections = [
        ("cpuinfo", self.get_cpuinfo().decode()),
        ("hginfo", self.get_hginfo().decode()),
        ("gitinfo", self.get_gitinfo().decode()),
        ("jjinfo", self.get_jjinfo().decode()),
    ]
    if timestr is None:
        timestr = self.get_time()

    lines = [
        "Starting run at time {}".format(timestr),
        "sys.argv = {}".format(sys.argv),
        "hostname = {}".format(hostname),
    ]
    for n, (title, body) in enumerate(sections):
        if n > 0:
            lines.append("")  # Blank line between sections.
        lines.extend([title, "=" * len(title), body])
    return "\n".join(lines)
@staticmethod
def get_human_duration(seconds):
    """Return a nice human-readable duration such as `"1h, 1m, 1s"`.

    Arguments
    ---------
    seconds : float, array_like
        Duration in seconds.  If array-like, the first element is used.
        The value is truncated to an integer.

    Units with a zero count are omitted.  A duration of zero gives `"0s"`
    and negative durations are prefixed with `-`.

    Previously the comparison was strict (`>`), so exact multiples of a
    unit were misreported (60 gave `"60s"`) and durations of 0 or 1 second
    gave an empty string.
    """
    seconds = int(np.ravel(seconds)[0])
    sign = "-" if seconds < 0 else ""
    seconds = abs(seconds)
    parts = []
    for unit, seconds_per in (
        ("d", 24 * 60 * 60),
        ("h", 60 * 60),
        ("m", 60),
        ("s", 1),
    ):
        count, seconds = divmod(seconds, seconds_per)
        if count:
            parts.append("{}{}".format(count, unit))
    if not parts:
        return "0s"
    return sign + ", ".join(parts)
def get_frames(self, tmp=False):
    """Return the `Frames` instance.

    Arguments
    ---------
    tmp : bool
        If `True` and no `Frames` instance has been stored yet, return a
        new instance without storing it.  This is for accessing properties
        like Decimal.

    Otherwise the stored instance is returned (created on first use) with
    its `mode` set: `"rw"` when checkpointing (initializing the simulation
    directory first if it does not exist), else `"rm"` if the directory
    exists and `"m"` if it does not.
    """
    if self._frames is None:
        new_frames = Frames(
            data_dir=self.dir_name,
            prefix=self.frame_prefix,
            checkpoints_to_retain=self.checkpoints_to_retain,
            mem_limit_bytes=self.mem_limit_bytes,
        )
        if tmp:
            return new_frames
        self._frames = new_frames

    frames = self._frames
    if not self.checkpoint:
        frames.mode = "rm" if os.path.exists(self.dir_name) else "m"
    else:
        if not os.path.exists(self.dir_name):
            self.initialize(create_dir=True)
        frames.mode = "rw"
    return frames


frames = property(get_frames)
@property
def Decimal(self):
    """The `Decimal` attribute of `self.get_frames(tmp=True)`."""
    frames = self.get_frames(tmp=True)
    return frames.Decimal
def delete_data(self):
    """Remove all saved data from disk.

    The simulation directory `self.dir_name` is removed, followed by any
    containing directories that are left empty, up to and including
    `self.data_dir`.

    Two problems with the previous version are fixed:

    * If the containing directory did not exist (e.g. no data had ever
      been written), `os.listdir()` raised an error.
    * Empty directories above `self.data_dir` could also be removed.
    """
    target = self.dir_name
    if os.path.exists(target):
        shutil.rmtree(target)

    data_root = os.path.abspath(self.data_dir)
    parent = os.path.dirname(target)
    while parent and os.path.isdir(parent) and not os.listdir(parent):
        # Remove containing folders if they are empty
        os.rmdir(parent)
        if os.path.abspath(parent) == data_root:
            # Never climb above the data directory.
            break
        parent = os.path.dirname(parent)
def _unique(self, ts_, max_t_=None, ikeys=None):
    """Return a sorted array of unique times <= max_t_.

    Arguments
    ---------
    ts_ : iterable
        Times.  Each is converted with `self.Decimal` before comparison.
    max_t_ : float, None
        Upper limit (default `self.max_t_`).  Times above it are discarded
        with a warning.  If `None`, no limit is applied.
    ikeys : list, None
        Optional keys, one for each entry of `ts_`.

    Returns
    -------
    ts_ : array(float)
        Sorted unique times.  If `ikeys` is given, `(ts_, ikeys)` is
        returned instead, where `ikeys` lists the key of the first
        occurrence of each unique time.

    Previously, discarding times above `max_t_` left `ikeys` unfiltered,
    so the returned keys could belong to the wrong times.  Times and keys
    are now filtered together.
    """
    if max_t_ is None:
        max_t_ = self.max_t_
    Decimal = self.Decimal
    times_ = [Decimal(t_) for t_ in ts_]
    if ikeys is not None:
        ikeys = list(ikeys)

    if times_ and max_t_ is not None and max_t_ < max(times_):
        self.warning(f"Some times exceed {max_t_=}. Discarding.")
        limit_ = Decimal(max_t_)
        keep = [n for n, t_ in enumerate(times_) if t_ <= limit_]
        times_ = [times_[n] for n in keep]
        if ikeys is not None:
            # Keep the keys aligned with the surviving times.
            ikeys = [ikeys[n] for n in keep]

    if ikeys is None:
        return np.unique(times_).astype(float)

    times, inds = np.unique(times_, return_index=True)
    return times.astype(float), [ikeys[i] for i in inds]
def get_image_ts_(self):
    """Return the sorted list of times at which imaging will be run.

    These are `experiment.image_ts_` plus the (deprecated) `self.image_ts_`
    if set, passed through `self._unique()`.
    """
    times_ = list(self.experiment.image_ts_)
    extra_ts_ = self.image_ts_
    if extra_ts_ is not None:
        times_.extend(extra_ts_)
    return self._unique(times_)
def get_ts_(self, image_t_=None):
"""Return a sorted list of times at which we want simulation data.
Arguments
---------
image_t_ : float, None
If `None`, then this is the list of times without imaging. Otherwise, it is
the list of times up to `image_t_ + experiment.t__image` for the imaging run
for `image_t_`
Returns
-------
ts_ : array(float)
Result of `self._unique()` applied to the times assembled here.
"""
Decimal = self.Decimal
if image_t_ is None:
# No imaging: run to the larger of `self.max_t_` and the last image time.
max_t_ = self.max_t_
if max_t_ is None:
max_t_ = 0
image_ts_ = self.get_image_ts_().tolist()
if image_ts_:
max_t_ = max(max(image_ts_), max_t_)
else:
image_ts_ = [0]
T_ = Decimal(max_t_)
dt_ = Decimal(self.dt_)
N = int(np.ceil(T_ / dt_))
# We don't use np.arange because we want the last element if it is <= T_
ts_ = [n * dt_ for n in range(N + 1) if n * dt_ <= T_]
# Explicitly include the image times so that we can directly evolve up to
# these without restarting. This can improve the accuracy slightly with the
# ABM method which introduces a small error on restarts.
ts_.extend(image_ts_)
else:
# Imaging run: step by `dt_` up to `image_t_`, then by `dt__image`
# (default `dt_`) up to `image_t_ + experiment.t__image`.
max_t_ = image_t_ + self.experiment.t__image
T_ = Decimal(max_t_)
dt__image = self.dt__image
if dt__image is None:
dt__image = self.dt_
ts_ = (
np.arange(Decimal(0), Decimal(image_t_), Decimal(self.dt_)).tolist()
+ np.arange(Decimal(image_t_), T_, Decimal(dt__image)).tolist()
)
# np.arange excludes the end point, so add the final time explicitly.
ts_.append(T_)
return self._unique(ts_, max_t_=max_t_)
def get_saved_ts_(self, image_t_=None, t__image=None):
    """Return `(ts_, ikeys)`, lists of times and keys with simulation data.

    Arguments
    ---------
    t__image : None, float
        If `None`, then return the times along a saved solution path,
        otherwise return all states that have been imaged for this length
        of time.
    image_t_ : None, float
        If `None`, return the times along the saved solution path without
        any imaging, otherwise return the saved solution path with imaging
        starting at `image_t_`.  Not used if `t__image` is specified.

    Returns
    -------
    ts_ : list(float)
        Sorted array of unique times.
    ikeys : list(ikey)
        Corresponding list of ikeys.

    If `t__image` is given but no frames have been imaged for that length
    of time, empty results are returned.  (Previously this case raised
    `ValueError` from `max()` of an empty list.)
    """
    with self.frames as frames:
        Decimal = self.Decimal
        max_t_ = self.max_t_
        if max_t_ is None:
            max_t_ = np.inf
        else:
            max_t_ = Decimal(max_t_)

        if t__image is not None:
            # All states imaged for exactly `t__image`.
            t__image = Decimal(t__image)
            ikeys = [
                ikey
                for ikey in frames.ikeys()
                if len(ikey) > 1 and ikey[1] == t__image and ikey[0] <= max_t_
            ]
            ts_ = [self.get_t_(ikey=ikey, frames=frames) for ikey in ikeys]
            if ts_:
                max_t_ = max(ts_)
        elif image_t_ is None:
            # Solution path without imaging.
            ikeys = [ikey for ikey in frames.ikeys() if len(ikey) == 1]
            ts_ = [self.get_t_(ikey=ikey, frames=frames) for ikey in ikeys]
            max_t_ = self.max_t_
        else:
            # Solution path up to `image_t_` plus the imaging started there.
            ts_ = []
            ikeys = []
            for ikey in frames.ikeys():
                t_ = self.get_t_(ikey=ikey, frames=frames)
                if (len(ikey) > 1 and ikey[0] == Decimal(image_t_)) or (
                    len(ikey) == 1 and Decimal(t_) <= Decimal(image_t_)
                ):
                    ts_.append(t_)
                    ikeys.append(ikey)
            if ts_:
                max_t_ = max(ts_)
    return self._unique(ts_, ikeys=ikeys, max_t_=max_t_)
def get_wanted_ts_(self, image_t_=None):
    """Return a list of wanted ts.

    These are the times from `get_ts_()` that are not among the saved
    times from `get_saved_ts_()` (for the same `image_t_`).
    """
    requested_ts_ = self.get_ts_(image_t_=image_t_)
    already_saved = set(self.get_saved_ts_(image_t_=image_t_)[0])
    missing_ts_ = []
    for t_ in requested_ts_:
        if t_ in already_saved:
            continue
        missing_ts_.append(t_)
    return missing_ts_
def get_solution_path(self, image_t_=None):
    """Return the solution path.

    This is the optimal ordering of times to minimize execution assuming
    that the time to evolve is proportional to the difference in times.

    Returns
    -------
    solution_path : list((t0_, t_))
        Steps from an already available time `t0_` to a wanted time `t_`.
        Each computed `t_` becomes available for later steps.

    Raises
    ------
    NotImplementedError
        If `self.allow_negative_dt` is set (and there are wanted times).
    """
    saved_ts_, _ikeys = self.get_saved_ts_(image_t_=image_t_)
    available_ts_ = list(saved_ts_)
    pending_ts_ = self.get_wanted_ts_(image_t_=image_t_)
    path = []
    while pending_ts_:
        # gaps[i, j] = pending[i] - available[j]
        gaps = np.asarray(pending_ts_)[:, None] - np.asarray(available_ts_)[None, :]
        # Find the desired time that is closest to a computed state
        n_next = abs(gaps).min(axis=1).argmin()
        if self.allow_negative_dt:
            # Evolving backwards from the nearest state (smallest absolute
            # gap) is not supported until tested: the computation of steps
            # fails.
            raise NotImplementedError()
        # Only start from states at or before the wanted time.
        forward_gaps = gaps[n_next, :]
        forward_gaps = np.where(forward_gaps < 0, np.inf, forward_gaps)
        t_start_ = available_ts_[forward_gaps.argmin()]
        t_next_ = pending_ts_.pop(n_next)
        path.append((t_start_, t_next_))
        available_ts_.append(t_next_)
    return path
def get_t_(self, ikey, frames):
    """Return `t_` for the corresponding `ikey` from `frames`.

    This is the sum of the entries of `ikey` as a float; `frames` is not
    used here.
    """
    total = sum(ikey)
    return float(total)
def get_ikey(self, t_, image_t_=None):
    """Return the ikey for time `t_`.

    If `image_t_` is given and `t_ > image_t_`, the key is the pair
    `(image_t_, t_ - image_t_)`, otherwise it is just `t_`.  The key is
    converted with `key_to_ikey()` of a temporary `Frames` instance.
    """
    before_imaging = image_t_ is None or t_ <= image_t_
    key = t_ if before_imaging else (image_t_, t_ - image_t_)
    return self.get_frames(tmp=True).key_to_ikey(key)
def set_frame(self, t_, data, image_t_, checkpoint=False):
    """Set the frame data and return key.

    Arguments
    ---------
    t_, image_t_ : float
        Passed to `get_ikey()` to determine the key.
    data :
        Frame data to store.
    checkpoint : bool
        If `True`, store with `frames.checkpoint()` instead of by item
        assignment.
    """
    with self.frames as frames:
        ikey = self.get_ikey(t_=t_, image_t_=image_t_)
        if not checkpoint:
            frames[ikey] = data
        else:
            frames.checkpoint(ikey, data)
    return ikey
def get_state(self, t_=None, t__image=False, compute=False):
    """Return the state at time `t_` (if it has been computed).

    Arguments
    ---------
    t_ : float
        Time specifying which state to get.
    t__image : float, bool
        If provided (truthy), `t_` is taken as the imaging time and the state
        requested is the one at `t_ + t__image`.  If simply `True`, then
        `self.experiment.t__image` is used.
    compute : bool
        Passed on to `get_state_from_ikey()`.
    """
    if not t__image:
        image_t_ = None
    else:
        if t__image is True:
            t__image = self.experiment.t__image
        # The requested time becomes the imaging time; the state wanted lies
        # `t__image` after it.
        image_t_, t_ = t_, t_ + t__image
    return self.get_state_from_ikey(
        ikey=self.get_ikey(t_=t_, image_t_=image_t_),
        image_t_=image_t_,
        compute=compute,
    )
def get_state_from_ikey(self, ikey, image_t_=None, compute=False, check_exists=True):
    """Return the state stored under `ikey` (if it has been computed).

    Arguments
    ---------
    ikey : ikey
        Key of the frame to load.
    image_t_ : float
        Imaging time.  This controls `state.t_final` (infinite when `None`).  If
        `len(ikey) > 1` (i.e. you are requesting a state after some imaging), then
        the imaging time is taken from `ikey[0]`, and a supplied `image_t_` must
        agree with it.
    compute : bool
        Computing a missing state is not implemented: if the state is missing
        this raises `NotImplementedError` instead of `IndexError`.
    check_exists : bool
        Check that the state at `ikey` is among the saved states.  Set this to
        `False` if you already know the state at `ikey` exists on disk.

    Raises
    ------
    ValueError
        If `image_t_` disagrees with the imaging time encoded in `ikey`.
    IndexError
        If the state is not saved and `compute` is false.
    NotImplementedError
        If the state is not saved and `compute` is true.
    """
    if len(ikey) > 1:
        ikey_image_t_ = ikey[:1]
        inconsistent = (
            image_t_ is not None
            and self.frames.key_to_ikey(image_t_) != ikey_image_t_
        )
        if inconsistent:
            raise ValueError(
                f"Imaging time from ikey (={ikey_image_t_}) is different than `image_t_`(={image_t_})"
            )
        image_t_ = float(ikey_image_t_[0])

    if check_exists and ikey not in self.get_saved_ts_(image_t_=image_t_)[1]:
        if compute:
            raise NotImplementedError
        raise IndexError(f"No data for state with {ikey=}.")

    with self.frames as frames:
        t_ = self.get_t_(ikey, frames=frames)
        data = frames[ikey]

    t__final = np.inf if image_t_ is None else image_t_

    experiment = self.experiment
    state = experiment.get_state(initialize=False)
    state.t_final = float(t__final) * experiment.t_unit
    state.set_data(data)
    state.t = t_ * experiment.t_unit
    state.initializing = False
    return state
def get_states(self, image_t_=None, t__image=None, progress=False):
    """Return a list of the saved states.

    Arguments
    ---------
    image_t_ : float
        If provided, then return all states up to this time and all
        images after this.  Otherwise, return all non-imaging
        states.
    t__image : float
        If provided, then get all states that have been imaged
        after expanding to this time.
    progress : bool, float
        If truthy, then print progress messages every `progress` seconds while
        loading.  Unless `progress` is `None`, a summary message with the number
        of states is passed to `self.msg()` before loading.
    """
    _ts_, ikeys = self.get_saved_ts_(image_t_=image_t_, t__image=t__image)

    if progress is not None:
        msg = "About to load {} states".format(len(ikeys))
        with self.frames:
            image_ikeys = [_ikey for _ikey in ikeys if len(_ikey) > 1]
        if image_t_ is not None:
            msg = " ".join([msg, f"({len(image_ikeys)} during imaging)"])
        self.msg(msg)

    loaded = []
    started = time.time()
    last_report = 0
    for ikey in ikeys:
        loaded.append(
            self.get_state_from_ikey(ikey=ikey, image_t_=image_t_, check_exists=False)
        )
        if progress and time.time() > last_report + progress:
            now = last_report = time.time()
            time_per_state = (now - started) / len(loaded)
            remaining_states = len(ikeys) - len(loaded)
            time_remaining = time_per_state * remaining_states
            msg = "... {:.1f}s to go ({} states at {:.1f}s/state)".format(
                time_remaining, remaining_states, time_per_state
            )
            self.msg(msg)
            print(msg)
    return loaded
def run(self, image_t_=None):
    """Run the simulation, optionally checkpointing the data to disk.

    After `self.initialize()`, the text returned by `get_computer_state()` is
    either written to a `run_*.txt` file in `self.dir_name` (if `self.checkpoint`)
    or passed to `self.info()`.  If no states are saved yet, an initial state is
    obtained and stored.  The remaining states are then computed in the order
    given by `get_solution_path()`, each one being stored with `set_frame()`, until
    the path is exhausted or the run is interrupted.

    Arguments
    ---------
    image_t_ : float, None
        If provided, then run the imaging procedure for this time.
        I.e. we will set `state.t_final = image_t_` and then evolve up to
        `image_t_ + experiment.t__image`.  These results will be stored with
        `key=(image_t_, t_)` where `t_` varies in steps of `dt_` up to
        `experiment.t__image`.
    """
    self.initialize()
    timestr = self.get_time()
    computer_state = self.get_computer_state(timestr=timestr)
    if self.checkpoint:
        # The file name encodes the imaging time (if any) and the time stamp.
        if image_t_ is None:
            info_file = f"run_{timestr}.txt"
        else:
            info_file = f"run_{self.frames.Decimal(image_t_)}_image_{timestr}.txt"
        info_file = os.path.join(self.dir_name, info_file)
        with self.msg(f"Writing {info_file=}"):
            with open(info_file, "w") as f:
                f.write(computer_state)
    else:
        self.info(computer_state)

    # Map ikey -> t_ for every state that is already saved.
    saved_ts_, saved_ikeys = self.get_saved_ts_(image_t_=image_t_)
    saved = dict(zip(saved_ikeys, saved_ts_))
    t_unit = self.experiment.t_unit
    if not saved:
        # No existing data.  Find initial state
        with self.msg("Finding initial state"):
            if self._extends is None:
                with NoInterrupt():
                    state = self.experiment.get_initial_state()
            else:
                # This run extends another simulation: start from its state.
                # NOTE(review): `get_state()` as defined in this file takes
                # `(t_, t__image, compute)` and has no `image_t_` keyword, so this
                # call looks like it would raise TypeError if `simulation` is of
                # the same class -- confirm (see the REV marker below).
                simulation, t_ = self._extends
                state = self.experiment.get_initialized_state(
                    state=simulation.get_state(
                        t_=t_, image_t_=image_t_
                    )  # REV: t__image
                )

        # Time of the initial state in units of t_unit.
        t_ = state.t / t_unit
        with self.msg("Saving initial state (t_={})".format(t_)):
            self.set_frame(t_=t_, data=state.get_data().copy(), image_t_=image_t_)
        ikey = self.get_ikey(t_=t_, image_t_=image_t_)
        saved[ikey] = t_
    else:
        state = self.experiment.get_state()

    solution_path = self.get_solution_path(image_t_=image_t_)

    # Nominal time step (dt_ is in units of t_unit).  In this method it is only
    # used below to estimate the number of steps remaining.
    dt = self.dt_t_scale * state.t_scale
    dt_ = dt / t_unit
    self.info(
        "Solution path: {}".format(
            " ".join(
                [
                    "{}->{}".format(
                        self.get_ikey(t_=_t0_, image_t_=image_t_),
                        self.get_ikey(t_=_t_, image_t_=image_t_),
                    )
                    for (_t0_, _t_) in solution_path
                ]
            )
        )
    )

    # `interrupted` becomes truthy when the user interrupts; the loop then stops
    # before starting the next step.
    with NoInterrupt(ignore=True) as interrupted:
        while solution_path and not interrupted:
            if self.evolve_times:
                # Estimate remaining time
                # Each entry of self.evolve_times is (tic, steps, elapsed) as
                # appended by compute_state().
                dts_ = np.diff(solution_path, axis=1)
                steps = sum(np.ceil(dts_ / dt_).astype(int))
                tics_, steps_, times_ = np.asarray(self.evolve_times).T
                time_per_step = times_ / steps_
                time_remaining = steps * time_per_step.mean()
                time_remaining_err = steps * time_per_step.std()
                self.info(
                    "Estimated time to completion: {}+-{}".format(
                        self.get_human_duration(time_remaining),
                        self.get_human_duration(time_remaining_err),
                    )
                )
            t0_, t_ = solution_path.pop(0)
            state = self.compute_state(t_=t_, image_t_=image_t_)
            with self.msg("Saving state t_={}".format(t_)):
                # set_frame() returns the ikey under which the state was stored.
                saved[
                    self.set_frame(
                        t_=t_, data=state.get_data().copy(), image_t_=image_t_
                    )
                ] = t_
def compute_state(self, t_, image_t_=None):
    """Compute the state at the specified time from the nearest existing
    state, checkpointing if needed.

    If the state at `t_` is already saved it is simply loaded and returned.
    Otherwise the closest saved state that is not later than `t_` is loaded and
    evolved forward with `self.Evolver`.  If `self.checkpoint_dt_` is set and is
    shorter than the interval to cover, intermediate states are stored with
    `set_frame(..., checkpoint=True)` along the way.  The state at `t_` itself is
    returned (via `evolver.get_y()`) but is not stored here.

    Each call that evolves appends `(tic, steps, elapsed)` to
    `self.evolve_times`.

    Raises `NotImplementedError` if `self.allow_negative_dt` is set.
    """
    saved_ts_, ikeys = self.get_saved_ts_(image_t_=image_t_)
    ikey = self.get_ikey(t_=t_, image_t_=image_t_)
    if ikey in ikeys:
        return self.get_state_from_ikey(ikey=ikey, image_t_=image_t_)

    # Find the closest existing state
    dts_ = t_ - np.asarray(saved_ts_)
    if self.allow_negative_dt:
        raise NotImplementedError()
        # Don't do this until tested!  computation of steps
        # below fails.
        dts_ = abs(dts_)
    else:
        # Saved states later than t_ are excluded by giving them an infinite
        # distance.
        dts_ = np.where(dts_ < 0, np.inf, dts_)
    i0 = dts_.argmin()
    t0_, ikey0 = saved_ts_[i0], ikeys[i0]
    state = self.get_state_from_ikey(ikey=ikey0, image_t_=image_t_)
    t_unit = self.experiment.t_unit

    # Nominal time step; dt_ is the same step in units of t_unit.
    dt = self.dt_t_scale * state.t_scale
    dt_ = dt / t_unit
    with self.msg("Evolving from t0_={} to t_={}".format(t0_, t_)):
        # Evolvers need at least 2 steps
        if not self.allow_negative_dt:
            assert t_ > t0_
        steps = max(2, int(np.ceil((t_ - t0_) / dt_)))

        # Shrink the step so that exactly `steps` steps cover [t0_, t_].
        dt_ = (t_ - t0_) / steps
        dt = dt_ * t_unit
        evolver = self.Evolver(state, dt=dt)

        # Evolvers without a `fixed_dt` attribute are treated as fixed-step.
        if getattr(evolver, "fixed_dt", True):
            # Break evolution up for checkpoints.
            if self.checkpoint_dt_ is not None and self.checkpoint_dt_ < (t_ - t0_):
                checkpoint_steps = max(2, int(np.floor(self.checkpoint_dt_ / dt_)))
                stepss = [checkpoint_steps] * (steps // checkpoint_steps)
                last_steps = steps % checkpoint_steps
                # A remainder of fewer than 2 steps is merged into the last chunk
                # so that every chunk has at least 2 steps.
                if last_steps < 2:
                    stepss[-1] += last_steps
                else:
                    stepss.append(last_steps)
                assert min(stepss) >= 2
                # The final chunk is evolved separately below without saving a
                # checkpoint after it.
                stepss, last_steps = stepss[:-1], stepss[-1]
                assert sum(stepss) + last_steps == steps
            else:
                stepss = []
                last_steps = steps

            tic = time.time()
            for _steps in stepss:
                evolver.evolve(_steps)
                self.set_frame(
                    t_=evolver.t / t_unit,
                    ##### REV Should use get_psi() and set_psi() here.
                    data=evolver.y[...].copy(),
                    checkpoint=True,
                    image_t_=image_t_,
                )
            assert last_steps > 0
            evolver.evolve(last_steps)
        else:
            # Break evolution up for checkpoints.
            if self.checkpoint_dt_ is not None and self.checkpoint_dt_ < (t_ - t0_):
                # NOTE(review): np.arange starts at t0_, so the first checkpoint
                # time is the starting time itself -- confirm this is intended.
                checkpoint_ts_ = np.arange(t0_, t_, self.checkpoint_dt_)
                if checkpoint_ts_[-1] >= t_:
                    checkpoint_ts_ = checkpoint_ts_[:-1]
            else:
                checkpoint_ts_ = []

            tic = time.time()
            for _t_ in checkpoint_ts_:
                _t = _t_ * t_unit
                evolver.evolve_to(_t)
                assert np.allclose(_t, evolver.t)
                self.set_frame(
                    t_=_t_,
                    ##### REV Should use get_psi() and set_psi() here.
                    data=evolver.y[...].copy(),
                    checkpoint=True,
                    image_t_=image_t_,
                )
            evolver.evolve_to(t_ * t_unit)

        # Record (start time, number of steps, elapsed seconds) for this evolution.
        self.evolve_times.append((tic, steps, time.time() - tic))
        assert np.allclose(evolver.t, t_ * t_unit)
    return evolver.get_y()
def run_images(self):
    """Make the images from the simulation.

    Calls `self.run(image_t_=...)` for each time returned by
    `self.get_image_ts_()`.  If that returns `None`, a warning is issued and
    nothing else is done.
    """
    image_ts_ = self.get_image_ts_()
    if image_ts_ is not None:
        self.initialize()
        for image_t_ in image_ts_:
            self.run(image_t_=image_t_)
        return

    self.warning(
        "No image_ts_ specified (got {}). Doing nothing.".format(self.image_ts_)
    )
def view(self, plot_state=None, t__image=None, **kw):
    """Run this to load the generated data and plot it.

    Each saved state is loaded in turn, plotted, and displayed; the output is
    cleared before the next one is shown.  The loop stops after the current
    frame if interrupted.

    Arguments
    ---------
    plot_state : callable, None
        Called as `plot_state(state, fig=fig)`; the return value is displayed and
        passed back as `fig` for the next state.  By default
        `state.plot(fig=fig, **kw)` is used, falling back to `state.plot()` on the
        current figure if that raises `TypeError`.
    t__image : float, bool, None
        Passed to `get_saved_ts_()` and `get_state()`.
    **kw :
        Extra arguments for `state.plot()` (only used by the default
        `plot_state`).
    """
    from IPython.display import display, clear_output
    from matplotlib import pyplot as plt

    def _default_plot_state(state, fig=None):
        try:
            return state.plot(fig=fig, **kw)
        except TypeError:
            # state.plot() does not take these arguments: select the figure
            # ourselves and call it with no arguments.
            fig = plt.gcf() if fig is None else plt.figure(fig.number)
            state.plot()
            return fig

    if plot_state is None:
        plot_state = _default_plot_state

    fig = None
    with NoInterrupt(ignore=True) as interrupted:
        saved_ts_, _ikeys = self.get_saved_ts_(t__image=t__image)
        for t_ in saved_ts_:
            plt.clf()
            state = self.get_state(t_=t_, t__image=t__image)
            fig = plot_state(state, fig=fig)
            plt.draw()
            plt.pause(0.01)
            display(fig)
            if interrupted:
                break
            clear_output(wait=True)
class SimulationManager:
    """Provides access to a bunch of simulations stored on disk.

    After `load_simulations()` has been called, each set of simulations is
    accessible through an attribute with the same name as the experiment class.
    This attribute holds a `Simulations` instance with the corresponding
    simulations.
    """

    def __init__(self, data_dir, simulations=None):
        # `simulations` is accepted but not used.
        self._data_dir = data_dir

    def find_simulations(self):
        """Return (and remember) the directories under the data directory that
        contain a file called `experiment.py`."""
        found = []
        for dir_name, _dirnames, filenames in os.walk(self._data_dir):
            if "experiment.py" in filenames:
                found.append(dir_name)
        self._simulation_dirs = found
        return found

    def load_simulations(self, simulation_dirs=None):
        """Load a `Simulation` from each directory and group them by the class
        name of their experiment.

        If `simulation_dirs` is `None`, the directories are located with
        `find_simulations()`.  Directories that fail to load are skipped after
        printing the traceback.
        """
        if simulation_dirs is None:
            simulation_dirs = self.find_simulations()

        loaded = []
        for dir_name in simulation_dirs:
            try:
                loaded.append(Simulation(dir_name=dir_name))
            except Exception:
                traceback.print_exc()
        self._simulations = Simulations(loaded)

        by_class_name = {}
        for simulation in self._simulations:
            class_name = simulation.experiment.__class__.__name__
            by_class_name.setdefault(class_name, []).append(simulation)

        # Expose each group as an attribute named after the experiment class.
        for class_name, members in by_class_name.items():
            self.__dict__[class_name] = Simulations(members)
class Simulations(abc.Sequence):
    """Represents a list of simulations where one can access the simulations
    through attribute access with tab completion.

    For each name listed in the `_keys` of the simulations' experiments, the
    attribute of that name is a dictionary mapping each value found for that key
    to the `Simulations` whose experiments have that value.
    """

    def __init__(self, simulations):
        self._simulations = simulations

        # Maps each experiment key to the set of values found among the simulations.
        self._keys = {}
        for s in simulations:
            for key in s.experiment._keys:
                value = getattr(s.experiment, key)
                self._keys.setdefault(key, set()).add(value)

    def __dir__(self):
        """Custom attribute access for tab completion.

        Only keys that take more than one value are listed.
        """
        keys = sorted(_k for _k in self._keys if len(self._keys[_k]) > 1)
        res = []
        for key in keys:
            new = ["{}[{}]".format(key, value) for value in self._keys[key]]
            if res:
                res.extend(["{}.{}".format(_res, _new) for _res in res for _new in new])
            else:
                res.extend(new)
        return res

    def __getattr__(self, key):
        """Return `{value: Simulations}` for the experiment key `key`.

        Raises `AttributeError` for unknown names, as the attribute protocol
        requires, so that `hasattr()` and `getattr(..., default)` work.
        """
        # Go through __dict__ so that a missing `_keys` (e.g. on an instance
        # created without __init__) cannot recurse back into __getattr__.
        try:
            values = self.__dict__["_keys"][key]
        except KeyError:
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(type(self).__name__, key)
            ) from None
        return {value: self.get(**{key: value}) for value in values}

    def __len__(self):
        return len(self._simulations)

    def __getitem__(self, key):
        return self._simulations[key]

    def keys(self):
        """Return the sorted list of all experiment keys of the simulations."""
        keys = set()
        for s in self._simulations:
            keys.update(s.experiment._keys)
        return sorted(keys)

    def get(self, **kw):
        """Return the simulations satisfying the specified criteria."""
        results = [
            s
            for s in self._simulations
            if all(getattr(s.experiment, _k) == _v for _k, _v in kw.items())
        ]
        return Simulations(results)

    def __repr__(self):
        return "{}():\n{}".format(
            self.__class__.__name__,
            "\n".join([_s.dir_name for _s in self._simulations]),
        )
# try:
#
# from . import visualize
#
# class Simulation2(visualize.SimulationMixin2, Simulation):
# pass
#
# except ImportError:
# pass
######################################################################
# GPU Support
def i_know_this_is_slow(func):
    """Decorator to suppress PerformanceWarning.

    Marks `func` by setting `func._i_know_this_is_slow = True` and returns the same
    object.
    """
    setattr(func, "_i_know_this_is_slow", True)
    return func
class _GPU:
    """Various decorators for helping with GPU and other accelerator support."""

    @staticmethod
    def _asnumpy(result, instance=None):
        """Convert result to a numpy array or a tuple of arrays.

        Uses the following in order:

        * `instance.asnumpy(result)` if `instance` is provided
        * `result.get()`
        * `cupy.asnumpy(result)` (`numpy.asarray(result)` if cupy cannot be imported)

        If `result` is a tuple, each element is converted and a tuple is returned.
        """
        # Compare with None rather than testing truthiness: an instance that
        # defines `__len__` or `__bool__` may be falsy yet still provide `asnumpy`.
        if instance is not None:
            assert hasattr(instance, "asnumpy")
            asnumpy = instance.asnumpy
            try:
                if isinstance(result, tuple):
                    return tuple(map(asnumpy, result))
                else:
                    return instance.asnumpy(result)
            except ValueError:
                # Special case triggered if the result is a list of inhomogeneous arrays.
                # Perhaps we should force the user to return a tuple instead?  We enforce
                # this in our code tests.
                if "PYTEST_CURRENT_TEST" in os.environ:
                    raise
                try:
                    return tuple(map(asnumpy, result))
                except ValueError:
                    pass
            except AttributeError:
                pass

        # Objects with a `get()` method (as GPU arrays provide).
        try:
            if isinstance(result, tuple):
                return tuple(_r.get() for _r in result)
            else:
                return result.get()
        except AttributeError:
            pass

        # Last resort: cupy if available, otherwise plain numpy.
        try:
            from cupy import asnumpy
        except ImportError:
            from numpy import asarray as asnumpy

        if isinstance(result, tuple):
            return tuple(map(asnumpy, result))
        else:
            return asnumpy(result)

    @wrapt.decorator
    @classmethod
    def _return_asnumpy(_GPU, wrapped, instance, args, kwargs):
        """Decorate the specified function to return a numpy array.

        To be used on functions that return GPU arrays for performance to
        generate a method for users to get these as numpy arrays.

        Uses the following in order:

        * `state.asnumpy(result)` if 'state' in kwargs
        * `self.asnumpy(result)` if provided
        * `result.get()`
        * `cupy.asnumpy(result)`
        """
        result = wrapped(*args, **kwargs)
        if hasattr(kwargs.get("state", None), "asnumpy"):
            instance = kwargs["state"]
        elif not hasattr(instance, "asnumpy"):
            instance = None
        return _GPU._asnumpy(result, instance=instance)

    @classmethod
    def from_GPU(_GPU, method_GPU):
        """Decorator for _GPU methods to fetch the result as a numpy array."""
        wrapped = _GPU._return_asnumpy(method_GPU)
        wrapped._GPU = True
        return wrapped

    @classmethod
    def GPU_delegate(_GPU, method):
        """Decorator for GPU methods to delegate to the non-GPU method."""
        wrapped = _GPU._performance_warning(method)
        return wrapped

    @wrapt.decorator
    @classmethod
    def _performance_warning(_GPU, wrapped, instance, args, kwargs):
        """Decorate the specified function warning of performance.

        The warning is skipped if `wrapped` has the `_i_know_this_is_slow`
        attribute (see `i_know_this_is_slow`).
        """
        if not hasattr(wrapped, "_i_know_this_is_slow"):
            warnings.warn(
                f"""Default {wrapped.__qualname__}_GPU() delegates to non-GPU version.
This is a potential performance issue, even if not using the GPU. To ensure optimal
performance, you should make sure that you overload {wrapped.__name__}_GPU() instead.
""",
                category=PerformanceWarning,
            )
        return wrapped(*args, **kwargs)

    @classmethod
    def noop_GPU(_GPU, method_GPU):
        """Decorator to protect _GPU methods from being converted.

        The method is wrapped as in `from_GPU` and additionally marked with
        `_noop_GPU = True`.
        """
        wrapped = _GPU._return_asnumpy(method_GPU)
        wrapped._GPU = True
        wrapped._noop_GPU = True
        return wrapped

    @staticmethod
    def default(method_GPU):
        "Mark with `._default=True` denoting that it delegates to the non-GPU version."
        method_GPU._default = True
        return method_GPU

    @classmethod
    def add_non_GPU_methods(_GPU, cls):
        """Decorator that defines non-GPU methods for cls from _GPU methods.

        Deprecated: inherit from `GPUHelper` instead.

        Generally, it is an error to define both the non-gpu and GPU methods, and this
        will raise an AttributeError in this case.  The following exceptions are
        provided.

        1. If the method has the attribute `._default = True` (which is set with the
           `_GPU.default` decorator), then we allow the user to define the non-GPU
           version even though the GPU version is defined because the default GPU
           versions delegate to the non-GPU versions.
        """
        warnings.warn(
            "_GPU.add_non_GPU_methods. Inherit from GPUHelper instead.",
            DeprecationWarning,
        )
        # Don't dig into subclasses.
        # CHECK: This is a little suspect, at least with staticmethods.
        # getattr(cls, name) is the function, but cls.__dict__[name] is the
        # staticmethod wrapped function.  Can we safely wrap a staticmethod, or
        # do we need to wrap the function, then wrap with staticmethod outside.
        # The tests work... but should check wrapt docs.
        names = {name for name in cls.__dict__ if inspect.isfunction(getattr(cls, name))}
        members = {name: cls.__dict__[name] for name in names}
        GPU_names = set(name for name in members if name.endswith("_GPU"))
        non_GPU_names = set(_name[:-4] for _name in GPU_names)
        duplicates = set(
            name
            for name in non_GPU_names.intersection(names)
            if (
                name in members
                and not hasattr(members[name], "_GPU")
                and not getattr(members[name], "_default", False)
            )
        )
        if duplicates:
            # It is an error to define both a _GPU and a non_GPU method.  This avoids
            # accidentally overloading the regular method when the _GPU method should be
            # overloaded, but allows classes with no _GPU methods to still function
            # normally.
            raise AttributeError(
                f"Class {cls.__name__} has _GPU methods but also defines {duplicates}"
            )
        methods_to_add = non_GPU_names.difference(names)
        for name in methods_to_add:
            setattr(cls, name, _GPU.from_GPU(members[f"{name}_GPU"]))
        return cls

    @classmethod
    def add_GPU_and_non_GPU_methods(_GPU, cls):
        """Decorator that defines the GPU/non-GPU pairs for cls from _GPU methods.

        Generally, it is an error to define both the non-gpu and GPU methods, and this
        will raise an AttributeError in this case.  The following exceptions are
        provided.

        1. If the method has the attribute `._default = True` (which is set with the
           `_GPU.default` decorator), then we allow the user to define the non-GPU
           version even though the GPU version is defined because the default GPU
           versions delegate to the non-GPU versions.
        """
        # CHECK: This is a little suspect, at least with staticmethods.
        # getattr(cls, name) is the function, but cls.__dict__[name] is the
        # staticmethod wrapped function.  Can we safely wrap a staticmethod, or
        # do we need to wrap the function, then wrap with staticmethod outside.
        # The tests work... but should check wrapt docs.
        names = {name for name in cls.__dict__ if inspect.isfunction(getattr(cls, name))}
        members = {name: cls.__dict__[name] for name in names}

        # These are the *_GPU methods defined by the user.  These need non-GPU methods
        # defined.
        GPU_names = set(name for name in members if name.endswith("_GPU"))
        non_GPU_names_needed = set(_name[:-4] for _name in GPU_names)

        # These are the non-GPU methods defined by the user where there is a _GPU method
        # somewhere.  These need GPU methods defined.
        non_GPU_names = set(
            name
            for name in members
            if not name.endswith("_GPU") and hasattr(cls, name + "_GPU")
        )
        GPU_names_needed = set(name + "_GPU" for name in non_GPU_names)

        duplicates = set(
            name
            for name in non_GPU_names_needed.union(GPU_names_needed).intersection(names)
            if (
                name in members
                and not hasattr(members[name], "_GPU")
                and not getattr(members[name], "_default", False)
            )
        )
        if duplicates:
            # It is an error to define both a _GPU and a non_GPU method.  This avoids
            # accidentally overloading the regular method when the _GPU method should be
            # overloaded, but allows classes with no _GPU methods to still function
            # normally.
            raise AttributeError(
                f"Class {cls.__name__} defines both non-GPU and GPU methods {sorted(duplicates)}"
            )

        for name in non_GPU_names_needed.difference(names):
            setattr(cls, name, _GPU.from_GPU(members[f"{name}_GPU"]))
        for GPU_name in GPU_names_needed.difference(names):
            setattr(cls, GPU_name, _GPU.GPU_delegate(members[f"{GPU_name[:-4]}"]))
        return cls

    @staticmethod
    def _check_class(self):
        """Check the class for consistency.

        Raises `ValueError` if `self` has any member whose name ends in `_GPU` that
        is not marked with a truthy `_GPU` attribute.
        """
        bad = {
            method.__qualname__: method
            for name, method in inspect.getmembers(self)
            if name.endswith("_GPU") and not getattr(method, "_GPU", False)
        }
        names = list(bad)
        classes = set([name.rsplit(".", 1)[0] for name in names])
        if bad:
            cls = self.__class__.__qualname__
            raise ValueError(
                "\n".join(
                    [
                        f"{names} uninitialized for GPU in {cls}.",
                        "Did you forget @_GPU.add_non_GPU_methods? I.e.:\n",
                    ]
                    + [f"@_GPU.add_non_GPU_methods\nclass {cls}...\n" for cls in classes]
                )
            )
[docs]
class GPUHelper:
    """Class for helping with GPU and other accelerator support.

    Subclasses automatically get the missing partner of each `*`/`*_GPU` method
    pair defined (see `_GPU.add_GPU_and_non_GPU_methods`).

    Examples
    --------
    >>> class B(GPUHelper):
    ...     def asnumpy(self, result):
    ...         return result + ['from B.asnumpy']  # Simulate a custom asnumpy method

    >>> class A(GPUHelper):
    ...     def asnumpy(self, result):
    ...         return result + ['from A.asnumpy']  # Simulate a custom asnumpy method
    ...
    ...     @staticmethod
    ...     def get_list1_GPU():
    ...         "Return a list.  Non-GPU method will return an array."
    ...         return [1, 2]
    ...
    ...     def get_list2_GPU(self):
    ...         "Return a list.  Non-GPU will use self.asnumpy."
    ...         return [1, 2]
    ...
    ...     def get_list3_GPU(self, state):
    ...         "Return a list.  Non-GPU will use state.asnumpy."
    ...         return [1, 2]

    >>> a = A()
    >>> a.get_list1_GPU()  # GPU method returns list (representing GPU array)
    [1, 2]
    >>> a.get_list2_GPU()
    [1, 2]
    >>> a.get_list3_GPU(state=B())
    [1, 2]
    >>> a.get_list1()  # Non GPU method returns an array.
    array([1, 2])
    >>> a.get_list2()  # Non GPU instance method uses A.asnumpy.
    [1, 2, 'from A.asnumpy']
    >>> a.get_list3(state=B())  # Non GPU instance method uses B.asnumpy.
    [1, 2, 'from B.asnumpy']
    >>> set(['get_list1', 'get_list2', 'get_list3']).issubset(dir(a))
    True
    """

    def __init_subclass__(cls, **kwargs):
        """Prepare the class by defining all non-GPU/GPU pairs "*" and "*_GPU".

        Any class keyword arguments are forwarded up the MRO so that this class
        cooperates with other bases that define `__init_subclass__`.
        """
        super().__init_subclass__(**kwargs)
        _GPU.add_GPU_and_non_GPU_methods(cls)
[docs]
class AsNumpyMixin(StateMixin, GPUHelper):
    """Mixin providing asnumpy method."""

    # Conversion used when non-GPU methods are generated from `*_GPU` methods.
    asnumpy = staticmethod(np.asarray)

    def get_data_GPU(self):
        """Partner to set_data().

        This allows access to the data via self.get_data() which is
        guaranteed to be a numpy array.
        """
        return self.data[...]
[docs]
class StateWithExperimentMixin(StateMixin, GPUHelper):
    """Mixin to delegate to self.experiment.

    Each method below calls the method of the same name on `self.experiment`,
    passing this state as `state`.
    """

    def __init__(self, experiment, **kw):
        self.experiment = experiment
        return super().__init__(**kw)

    def get_Vext_GPU(self):
        """Return `self.experiment.get_Vext_GPU(state=self)`."""
        return self.experiment.get_Vext_GPU(state=self)

    def get_Vint_GPU(self):
        """Return `self.experiment.get_Vint_GPU(state=self)`."""
        return self.experiment.get_Vint_GPU(state=self)

    def get_Eint(self):
        """Return `self.experiment.get_Eint(state=self)`."""
        return self.experiment.get_Eint(state=self)

    def get_n_TF(self, V_TF):
        """Return `self.experiment.get_n_TF(state=self, V_TF=V_TF)`."""
        return self.experiment.get_n_TF(state=self, V_TF=V_TF)

    def get_ns_TF(self, Vs_TF):
        """Return `self.experiment.get_ns_TF(state=self, Vs_TF=Vs_TF)`."""
        return self.experiment.get_ns_TF(state=self, Vs_TF=Vs_TF)