from __future__ import division, print_function, with_statement
import numpy as np
import xarray as xr
try:
    import holoviews as hv
    import param
    import paramnb
    from holoviews.operation.datashader import regrid  # , datashade
except (ImportError, RuntimeError):
    # The plotting stack is optional.  Bind ``hv`` so the ``if hv:`` guard
    # later in the module works instead of raising NameError.
    hv = None
class SimulationMixin2:
    """Mixin for Simulation classes that provides holoview visualization.

    NOTE(review): the host class is presumably expected to provide (not
    visible here -- confirm against the Simulation classes): ``frames`` with
    ``keys()`` and ``isiterable(key)``, ``experiment`` with ``t__image``, the
    time sequences ``ts_`` and ``saved_ts_``, and ``get_state(t_, image=...)``.
    States are indexed with ``state[...]`` and use ``experiment``,
    ``get_density()``, ``xyz`` and ``kxyz``.
    """

    # Plot size options (not referenced by the methods of this class).
    opts = dict(width=600, height=300)

    @property
    def psis(self):
        """Wavefunctions of all loaded states, one entry per time on axis 0."""
        self._set_states()
        return self._psis

    @property
    def phases(self):
        """Phases ``np.angle(psis)`` of the loaded wavefunctions (radians)."""
        self._set_states()
        return np.angle(self._psis)

    @property
    def psis_k(self):
        """Wavefunctions FFT'd along the last axis, zero frequency centred."""
        self._set_states()
        return self._psis_k

    @property
    def ns(self):
        """Densities of all loaded states, one entry per time on axis 0."""
        self._set_states()
        return self._ns

    def _set_states(self, image=None):
        """Load all states and cache them, with derived arrays, on ``self``.

        Nothing is returned.  The cache is built on the first call and is
        rebuilt whenever ``image`` is given and differs from the value used
        to build the current cache.  Calling with ``image=None`` reuses
        whatever is already loaded.

        Parameters
        ----------
        image : bool or None
            If truthy, load the frames whose key ``(t, tag)`` has
            ``tag == self.experiment.t__image``, and take densities and the
            abscissa from ``experiment.simulate_image(state)``.  Otherwise
            use the times in ``self.ts_`` that are also in
            ``self.saved_ts_``, with ``state.get_density()`` and
            ``state.xyz[0]``.

        Sets ``_xs``, ``_ts_``, ``_states``, ``_psis``, ``_psis_k``, ``_ns``
        and ``_image``.
        """
        if hasattr(self, "_states") and (image is None or self._image == image):
            return

        if image:
            keys = sorted(
                key
                for key in self.frames.keys()
                if self.frames.isiterable(key)
                and key[1] == self.experiment.t__image
            )
            ts_ = np.array([key[0] for key in keys])
        else:
            ts_ = [t_ for t_ in self.ts_ if t_ in self.saved_ts_]

        states = [self.get_state(t_, image=image) for t_ in ts_]
        psis = np.asarray([state[...] for state in states])
        # FFT along the last axis; fftshift moves zero frequency to the centre.
        psis_k = np.fft.fftshift(np.fft.fft(psis, axis=-1), axes=[-1])

        if image:
            # Simulate each image once and reuse it for both the densities and
            # the abscissa.  The first state is referenced explicitly:
            # comprehension variables are not visible here in Python 3.
            images = [state.experiment.simulate_image(state) for state in states]
            ns = np.asarray([img[1] for img in images])
            xs = images[0][0]
        else:
            ns = np.asarray([state.get_density() for state in states])
            xs = states[0].xyz[0].ravel()

        self._xs = xs
        self._ts_ = ts_
        self._states = states
        self._psis = psis
        self._psis_k = psis_k
        self._ns = ns
        self._image = image

    @property
    def data(self, skip=1):
        """Return an xarray with the density data for plotting.

        Dimensions are ``("species", "t", "x")``.  Because this is a
        property, ``skip`` (the stride along ``x``) always has its default
        value of 1.
        """
        return xr.DataArray(
            self.ns[:, :, ::skip],
            name="n",
            dims=("t", "species", "x"),
            coords=dict(species=["a", "b"], x=self._xs[::skip], t=self._ts_),
        ).transpose("species", "t", "x")

    @property
    def data_phase(self, skip=1):
        """Return an xarray with the phase data (radians) for plotting.

        Dimensions are ``("species", "t", "x")``.  ``skip`` is always 1
        (see ``data``).
        """
        return xr.DataArray(
            self.phases[:, :, ::skip],
            name="phase",
            dims=("t", "species", "x"),
            coords=dict(species=["a", "b"], x=self._xs[::skip], t=self._ts_),
        ).transpose("species", "t", "x")

    @property
    def data_k(self, skip=1):
        """Return an xarray with the momentum data ``|psi_k|**2`` for plotting.

        Dimensions are ``("species", "t", "k")``.  ``skip`` is always 1
        (see ``data``).
        """
        # Access psis_k first: this loads the states, so _states exists below.
        psis_k = self.psis_k
        # Shift the full k grid before striding.  psis_k is shifted before it
        # is strided, and the coordinates must line up with it.
        ks = np.fft.fftshift(self._states[0].kxyz[0][0].ravel())
        return xr.DataArray(
            abs(psis_k[:, :, ::skip]) ** 2,
            name="n_k",
            dims=("t", "species", "k"),
            coords=dict(species=["a", "b"], k=ks[::skip], t=self._ts_),
        ).transpose("species", "t", "k")

    def get_densities(self, t, normalize):
        """Event handler that returns the density curves at time ``t``.

        Overlays species "a", species "b" (sign flipped, presumably to mirror
        it below the axis) and their unnormalized sum.  If ``normalize`` is
        true, each single-species curve is divided by its maximum absolute
        value over all times.
        """
        data = self.data
        curves = []
        n_min = n_max = 0
        for i, species in enumerate(("a", "b")):
            n = (-1) ** i * data.sel(species=species)
            if normalize:
                n = n / abs(n).max()
            curves.append(hv.Curve(n.sel(t=t, method="nearest"), ["x"]))
            n_min = min(n_min, float(n.min()))
            n_max = max(n_max, float(n.max()))
        total = data.sel(species="a") + data.sel(species="b")
        curves.append(hv.Curve(total.sel(t=t, method="nearest"), ["x"]))
        n_min = min(n_min, float(total.min()))
        n_max = max(n_max, float(total.max()))
        # The y-range spans all times, so the axis does not jump as t changes.
        return (curves[0] * curves[1] * curves[2]).redim(
            n=dict(range=(n_min, n_max))
        )

    def get_phases(self, t):
        """Event handler that returns the phase curves at time ``t``.

        Phases are in units of 2*pi, so they lie in [-0.5, 0.5].  Species "b"
        is offset by -1 so that the two curves do not overlap.
        """
        data = self.data_phase
        curves = []
        for i, species in enumerate(("a", "b")):
            phase = data.sel(species=species) / (2 * np.pi) - (1.0 if i == 1 else 0)
            curves.append(hv.Curve(phase.sel(t=t, method="nearest"), ["x"]))
        return (curves[0] * curves[1]).redim(phase=dict(range=(-1.5, 0.5)))

    def get_momentum(self, t, normalize):
        """Event handler that returns the momentum curves at time ``t``.

        Species "b" is plotted with its sign flipped.  If ``normalize`` is
        true, each species is divided by its maximum absolute value over all
        times, as in ``get_densities``.
        """
        data = self.data_k
        curves = []
        n_min = n_max = 0
        for i, species in enumerate(("a", "b")):
            n = (-1) ** i * data.sel(species=species)
            if normalize:
                n = n / abs(n).max()
            curves.append(hv.Curve(n.sel(t=t, method="nearest"), ["k"]))
            n_min = min(n_min, float(n.min()))
            n_max = max(n_max, float(n.max()))
        # The value dimension comes from the DataArray name, "n_k" (see data_k).
        return (curves[0] * curves[1]).redim(n_k=dict(range=(n_min, n_max)))

    def get_plots(self, t, normalize=False, momentum=False, curve="phase"):
        """Return the 1D plot at time ``t`` selected by ``curve``.

        ``curve`` is ``"momentum"``, ``"phase"`` or anything else for the
        densities.  ``momentum`` is unused; it is kept for backward
        compatibility.
        """
        if curve == "momentum":
            return self.get_momentum(t, normalize=normalize)
        if curve == "phase":
            return self.get_phases(t)
        return self.get_densities(t, normalize=normalize)

    def view_hv(self, species="ab", curve="density", image=False, **kw):
        """Return a density map over ``(t, x)`` with a linked 1D plot below.

        The pointer's x position over the map (renamed to ``t``) and the
        ``paramnb`` widgets of a ``SelectionStream`` drive ``get_plots``.

        Parameters
        ----------
        species : str
            ``"a"``, ``"b"`` or ``"ab"`` (sum of both) for the map.
        curve : str
            Initial bottom-panel selection: ``"density"``, ``"momentum"``
            or ``"phase"``.
        image : bool
            Passed to ``_set_states``.
        **kw
            Extra keyword arguments for the ``hv.Image``/``hv.QuadMesh``.
        """
        self._set_states(image=image)
        data = self.data
        kdims = ["t", "x"]
        if species == "ab":
            n = data.sel(species="a") + data.sel(species="b")
        else:
            n = data.sel(species=species)

        dt = np.diff(data["t"])
        if np.allclose(dt[0], dt):
            density = regrid(hv.Image(n.T, kdims, **kw))
        else:
            # Allows uneven grids, but no datashader regridding.
            density = hv.QuadMesh(
                tuple(data[k] for k in kdims) + (n.T,), kdims, **kw
            )

        pos_t = hv.streams.PointerX(source=density, x=0).rename(x="t")
        selection = SelectionStream(curve=curve)
        paramnb.Widgets(
            selection, continuous_update=True, callback=selection.event, on_init=True
        )
        dm = hv.DynamicMap(self.get_plots, streams=[pos_t, selection])
        return hv.Layout(density + dm).cols(1)
if hv:

    class SelectionStream(hv.streams.Stream):
        """Widget-controlled stream selecting what ``get_plots`` draws.

        ``normalize`` toggles per-species normalization.  ``curve`` picks the
        bottom panel: ``"density"``, ``"momentum"`` or ``"phase"``.
        """

        normalize = param.Boolean(default=False)
        curve = param.ObjectSelector(
            default="density", objects=["density", "momentum", "phase"]
        )