Source code for gpe.visualize

from __future__ import division, print_function, with_statement

import numpy as np
import xarray as xr

try:
    import holoviews as hv
    import param, paramnb
    from holoviews.operation.datashader import regrid  # , datashade
except (ImportError, RuntimeError):
[docs] hv = None
class SimulationMixin2:
    """Mixin for Simulation classes that provides holoview visualization.

    The host class must provide the attributes read by :meth:`_set_states`:
    ``frames`` (with ``keys()`` and ``isiterable(key)``), ``experiment``
    (with ``t__image``), ``ts_``, ``saved_ts_`` and ``get_state(t, image=...)``.
    Loaded states are cached on the instance the first time any data
    property is accessed.
    """

    # NOTE(review): not referenced inside this class; presumably default
    # plot options for callers -- confirm before removing.
    opts = dict(width=600, height=300)

    @property
    def psis(self):
        """Array of the loaded wavefunctions, one entry per loaded time.

        Axes are assumed to be (t, species, x), matching ``data_phase``.
        """
        self._set_states()
        return self._psis

    @property
    def phases(self):
        """Complex phase (``np.angle``) of :attr:`psis`, in radians."""
        self._set_states()
        return np.angle(self._psis)

    @property
    def psis_k(self):
        """FFT of :attr:`psis` along the last axis, shifted so k = 0 is centred."""
        self._set_states()
        return self._psis_k

    @property
    def ns(self):
        """Array of densities, one entry per loaded time (t, species, x)."""
        self._set_states()
        return self._ns

    def _set_states(self, image=None):
        """Load and cache every state needed for plotting.

        Cached states are reused unless nothing is cached yet, or a non-None
        ``image`` differs from the value used for the cached load.

        Parameters
        ----------
        image : bool or None
            If truthy, load the frames whose key is a tuple with
            ``key[1] == self.experiment.t__image``. Densities and the x grid
            then come from ``experiment.simulate_image(state)``, which is
            indexed here as ``(xs, ns)``. Otherwise load every time in
            ``self.ts_`` that also appears in ``self.saved_ts_``, and use
            ``state.get_density()`` and ``state.xyz[0]``.

        Sets ``_xs``, ``_ts_``, ``_states``, ``_psis``, ``_psis_k``, ``_ns``
        and ``_image`` on the instance. Returns None.
        """
        reload = not hasattr(self, "_states") or (
            image is not None and self._image != image
        )
        if not reload:
            return

        if image:
            keys = sorted(
                key
                for key in self.frames.keys()
                if self.frames.isiterable(key) and key[1] == self.experiment.t__image
            )
            ts_ = np.array([key[0] for key in keys])
        else:
            ts_ = [_t for _t in self.ts_ if _t in self.saved_ts_]

        states = [self.get_state(t_, image=image) for t_ in ts_]
        psis = np.asarray([_state[...] for _state in states])
        # Momentum-space wavefunctions along the last axis, k = 0 centred.
        psis_k = np.fft.fftshift(np.fft.fft(psis, axis=-1), axes=[-1])

        if image:
            # Simulate each image once; reuse the first one for the x grid.
            # (Do not rely on a comprehension variable leaking: that only
            # works on Python 2.)
            images = [_state.experiment.simulate_image(_state) for _state in states]
            ns = np.asarray([img[1] for img in images])
            xs = images[0][0]
        else:
            ns = np.asarray([_state.get_density() for _state in states])
            xs = states[0].xyz[0].ravel()

        self._xs = xs
        self._ts_ = ts_
        self._states = states
        self._psis = psis
        self._psis_k = psis_k
        self._ns = ns
        self._image = image

    def _as_xarray(self, values, name, kdim, kcoord):
        """Wrap a (t, species, kdim) array as a DataArray ordered (species, t, kdim)."""
        return xr.DataArray(
            values,
            name=name,
            dims=("t", "species", kdim),
            coords={"species": ["a", "b"], kdim: kcoord, "t": self._ts_},
        ).transpose("species", "t", kdim)

    @property
    def data(self, skip=1):
        """Return an xarray with the density data for plotting.

        Dims are ("species", "t", "x") and the name is "n". ``skip`` would
        subsample x, but because this is a property it is always 1.
        """
        # self.ns is evaluated first so the states (and _xs/_ts_) get loaded.
        return self._as_xarray(self.ns[:, :, ::skip], "n", "x", self._xs[::skip])

    @property
    def data_phase(self, skip=1):
        """Return an xarray with the phase data (radians) for plotting.

        Dims are ("species", "t", "x") and the name is "phase". ``skip`` is
        always 1 because this is a property.
        """
        return self._as_xarray(
            self.phases[:, :, ::skip], "phase", "x", self._xs[::skip]
        )

    @property
    def data_k(self, skip=1):
        """Return an xarray with the momentum data ``|psi_k|**2`` for plotting.

        Dims are ("species", "t", "k") and the name is "n_k". ``skip`` is
        always 1 because this is a property.
        """
        n_k = abs(self.psis_k[:, :, ::skip]) ** 2  # also loads the states
        # Shift first, then subsample, exactly like psis_k, so k stays aligned.
        k = np.fft.fftshift(self._states[0].kxyz[0][0].ravel())[::skip]
        return self._as_xarray(n_k, "n_k", "k", k)

    def _species_curves(self, data, t, kdim, normalize):
        """Return ([curve_a, curve_b], n_min, n_max); species "b" is drawn negated.

        The bounds start at 0 and span all times. With ``normalize`` set,
        each species is divided by its maximum absolute value (skipped when
        that maximum is zero, which would otherwise produce NaNs).
        """
        curves = []
        n_min = n_max = 0
        for sign, species in zip((1, -1), ("a", "b")):
            n = sign * data.sel(species=species)
            if normalize:
                scale = float(abs(n).max())
                if scale > 0:
                    n = n / scale
            curves.append(hv.Curve(n.sel(t=t, method="nearest"), [kdim]))
            n_min = min(n_min, float(n.min()))
            n_max = max(n_max, float(n.max()))
        return curves, n_min, n_max

    def get_densities(self, t, normalize):
        """Event handler that returns the density curves at time ``t``.

        Overlays species a, species b (negated) and the total density a + b.
        The total is never normalized. The y range covers all times.
        """
        data = self.data
        curves, n_min, n_max = self._species_curves(data, t, "x", normalize)
        n = data.sel(species="a") + data.sel(species="b")
        curves.append(hv.Curve(n.sel(t=t, method="nearest"), ["x"]))
        n_min = min(n_min, float(n.min()))
        n_max = max(n_max, float(n.max()))
        return (curves[0] * curves[1] * curves[2]).redim(n=dict(range=(n_min, n_max)))

    def get_phases(self, t):
        """Event handler that returns the phase curves at time ``t``.

        Phases are in units of 2*pi. Species b is shifted down by 1 so the
        two curves occupy separate bands of the fixed range (-1.5, 0.5).
        """
        data = self.data_phase
        curves = []
        for offset, species in zip((0, 1.0), ("a", "b")):
            phase = data.sel(species=species) / (2 * np.pi) - offset
            curves.append(hv.Curve(phase.sel(t=t, method="nearest"), ["x"]))
        return (curves[0] * curves[1]).redim(phase=dict(range=(-1.5, 0.5)))

    def get_momentum(self, t, normalize):
        """Event handler that returns the momentum curves at time ``t``.

        Overlays ``|psi_k|**2`` for species a and species b (negated),
        optionally normalized per species.
        """
        data = self.data_k
        curves, n_min, n_max = self._species_curves(data, t, "k", normalize)
        # The value dimension is named after the DataArray ("n_k"), not "n".
        return (curves[0] * curves[1]).redim(n_k=dict(range=(n_min, n_max)))

    def get_plots(self, t, normalize=False, momentum=False, curve="phase"):
        """Return the lower-panel plot for time ``t``.

        ``curve`` selects "momentum", "phase" or (any other value) density.
        ``momentum=True`` is accepted as a shorthand for ``curve="momentum"``.
        """
        if momentum or curve == "momentum":
            return self.get_momentum(t, normalize=normalize)
        if curve == "phase":
            return self.get_phases(t)
        return self.get_densities(t, normalize=normalize)

    def view_hv(self, species="ab", curve="density", image=False, **kw):
        """Return an interactive holoviews layout of the simulation.

        The top panel is a (t, x) map of the density of ``species``. The
        value "ab" plots the sum of "a" and "b". Moving the pointer over it
        selects the time shown in the bottom panel (see :meth:`get_plots`).
        The bottom panel's options come from a ``SelectionStream``,
        initialized with ``curve`` and exposed through paramnb widgets.

        Parameters
        ----------
        species : str
            "ab", "a" or "b".
        curve : str
            Initial bottom-panel curve: "density", "momentum" or "phase".
        image : bool
            Passed to :meth:`_set_states` to choose simulated images.
        **kw
            Extra keyword arguments for the top-panel holoviews element.

        Raises
        ------
        ImportError
            If holoviews/param/paramnb could not be imported.
        """
        if hv is None:
            raise ImportError(
                "view_hv() requires holoviews, param and paramnb to be installed."
            )
        self._set_states(image=image)
        data = self.data
        kdims = ["t", "x"]
        ts = data["t"]
        dt = np.diff(ts)
        if species == "ab":
            n = data.sel(species="a") + data.sel(species="b")
        else:
            n = data.sel(species=species)

        if dt.size > 0 and not np.allclose(dt[0], dt):
            # Uneven time grid: QuadMesh handles it, but datashader regridding
            # requires a regular Image.
            plot = hv.QuadMesh([data[_k] for _k in kdims] + [n.T], kdims, **kw)
        else:
            plot = regrid(hv.Image(n.T, kdims, **kw))

        # The pointer's x coordinate on the (t, x) map is the time to display.
        pos_t = hv.streams.PointerX(source=plot, x=0).rename(x="t")
        selection = SelectionStream(curve=curve)
        paramnb.Widgets(
            selection, continuous_update=True, callback=selection.event, on_init=True
        )
        dm = hv.DynamicMap(self.get_plots, streams=[pos_t, selection])
        return hv.Layout(plot + dm).cols(1)
if hv is not None:

    class SelectionStream(hv.streams.Stream):
        """Stream holding the user-selected options for the lower plot panel.

        ``SimulationMixin2.view_hv`` attaches an instance to its DynamicMap, so
        these parameters arrive as keyword arguments of ``get_plots``.
        """

        # Whether each species' curve should be normalized.
        normalize = param.Boolean(default=False)

        # Which kind of curve to draw.
        curve = param.ObjectSelector(
            objects=["density", "momentum", "phase"],
            default="density",
        )