import numpy as np
import matplotlib.pyplot as plt
import warnings
from zope.interface import implementer
from IPython.display import clear_output, display
import gpe.bec
from gpe.utils import AsNumpyMixin, get_smooth_transition
from gpe.interfaces import IStateDFT
from pytimeode.mixins import ArrayStateMixin
from mmfutils.containers import ObjectBase
from mmfutils.contexts import FPS
# Smallest positive normal float64. It is added to denominators (e.g. ``h + _TINY``)
# so that divisions stay finite where the density or viscosity vanishes.
_TINY = np.finfo(float).tiny
@implementer(IStateDFT)
class StateFVBase(ArrayStateMixin, AsNumpyMixin, ObjectBase):
    r"""Finite-volume hydrodynamic state on a uniform 1D grid with ghost points.

    Units corresponding to $1\mu m = 1$, $\hbar=1$, $1 amu = 1$.

    ``self.data`` stores the stacked density ``n`` (called ``h`` in the
    shallow-water-style methods) and current ``j`` (called ``hu``).  Time
    evolution alternates between an advection operator
    (:meth:`get_dU_dt_adv`) and a diffusion operator (:meth:`get_dU_dt_diff`);
    see :meth:`compute_dy_dt` and :meth:`evolve_to`.
    """

    # NOTE(review): the default arguments below and `plot` use a units object
    # `u` (`u.micron`, `u.Hz`, `u.ms`) that is not among this module's visible
    # imports -- confirm it is defined/imported at module level.

    def __init__(
        self,
        experiment=None,
        Nx=2**10,
        Lx=100 * u.micron,
        nu=25 * u.micron**2 * u.Hz,
        t=0,
        x_TF=100 * u.micron,
        nj=None,
        m=1,
        mu=None,
        cfl=0.5,
        ntol=1e-6,
        t_final=np.inf,
        **kw,
    ):
        """Hydro state using finite differences.

        Parameters
        ----------
        experiment : object
            Supplies the physics used by this state: ``get_Vext``, ``get_nu``,
            ``get_f`` and ``g``; optionally ``V_TF`` (initial TF potential
            value) and ``t__final``/``t_unit`` (used by :meth:`evolve`).
        Nx : int
            Number of interior grid points (one ghost point is added per side).
        Lx : float
            Size of box in microns.
        nu : float
            Stored as ``self.nu``; presumably the diffusion coefficient read by
            ``experiment.get_nu(state)`` -- confirm against the experiment.
        t : float
            Initial time.
        x_TF : float
            Thomas-Fermi radius (used only when ``mu`` is None).
        nj : array-like or None
            Initial ``(n, j)``.  If None, the Thomas-Fermi density with zero
            current is used.
        m : float
            Mass of the particles.
        mu : float or None
            Chemical potential; when set it takes precedence over ``x_TF``.
        cfl : float
            Courant number.
        ntol : float
            Density regulator (density units): the velocity is set to zero
            wherever ``n <= ntol``.
        t_final : float
            Final evolution time used by :meth:`evolve`.  A non-finite value
            (the default) means "use the experiment's final time".
        **kw
            Forwarded to the base classes.
        """
        self.experiment = experiment
        self.Nx = Nx
        self.Lx = Lx
        self.nu = nu
        self.t = t
        self.x_TF = x_TF
        self.m = m
        self.mu = mu
        self.cfl = cfl
        self.ntol = ntol
        self.t_final = t_final
        self.dx = dx = Lx / Nx
        # Nx interior points covering [-Lx/2, Lx/2) plus one ghost point on each
        # side.  Built from an integer range so the length is exactly Nx + 2;
        # np.arange with a float step can gain an extra point from rounding.
        self.xyz = (-Lx / 2.0 - dx + dx * np.arange(Nx + 2),)
        # Operator-splitting flag toggled by compute_dy_dt():
        # < 0 -> advection sub-step next, otherwise diffusion.
        self.operator = -1
        # Set after the defaults so that keyword arguments may override them.
        super().__init__(**kw)
        if nj is None:
            # The original getattr(self, "experiment.V_TF") used a dotted name and
            # therefore always returned None.
            V_TF = getattr(self.experiment, "V_TF", None)
            nj = self.get_initial_nj(V_TF=V_TF)
        self.set_nj(nj)
        # Save a *copy* of the initial n and j for plotting; get_nj() returns the
        # live data array, which apply_boundary_condition() mutates in place.
        self._nj0 = self.get_nj().copy()
        self.dt = self.get_dt()

    @property
    def x(self):
        """Flat x abscissa (including both ghost points) as a numpy array."""
        return self.xyz[0]

    def get_nj(self):
        """Return ``self.data``, the stacked ``(n, j)`` array (not a copy)."""
        return self.data

    def set_nj(self, nj, pos=False):
        """Set ``self.data`` from ``nj = (n, j)``.

        If ``pos`` is true, negative densities are clipped to zero first.
        """
        n, j = nj
        if pos:
            n = np.maximum(0, n)
        self.data = np.asarray((n, j))

    def get_n_TF(self, V_TF=None, V_ext=None, g=None):
        """Return the Thomas Fermi density profile n_1D from mu.

        Arguments
        ---------
        V_TF : float
            Value of V(x_TF) where the density should vanish in the TF limit.
            Ignored if ``self.mu`` is set; computed with :meth:`get_V_TF` if
            neither is available.
        V_ext : array or None
            External potential; defaults to :meth:`get_Vext`.
        g : float or None
            Coupling constant; defaults to ``self.experiment.g`` (with a warning).
        """
        if self.mu is not None:
            V_TF = self.mu
        elif V_TF is None:
            V_TF = self.get_V_TF()
        if g is None:
            warnings.warn("Coupling constant g should have the correct dimensions.")
            g = self.experiment.g
        if V_ext is None:
            V_ext = self.get_Vext()
        n = (V_TF - V_ext) / g
        # Density vanishes where the potential exceeds V_TF.
        return np.where(n < 0, 0, n)

    def get_V_TF(self, x_TF=None, V_ext=None):
        """Return the Thomas Fermi chemical potential at x_TF.

        Arguments
        ---------
        x_TF : float
            Position defining the Thomas Fermi "radius". (The external potential is
            evaluated at this position and this is used to get `mu`.)
        V_ext : array or None
            External potential; defaults to :meth:`get_Vext`.
        """
        if x_TF is None:
            x_TF = self.x_TF
        if V_ext is None:
            V_ext = self.get_Vext()
        # Broadcast against the grid, then minimize along all axes except x
        # (the 0th axis).
        V = V_ext + np.zeros(self.x.shape)
        while V.ndim > 1:
            V = np.min(V, axis=-1)
        x = self.x.ravel()
        # Find the closest lattice points and fit a quadratic through three of
        # them, so V can be interpolated even if x_TF is not a lattice point.
        i = np.argmin(abs(x - x_TF))
        i = min(max(i, 1), len(x) - 2)  # keep the 3-point stencil in range
        inds = slice(i - 1, i + 2)
        return np.polyval(np.polyfit(x[inds], V[inds], 2), x_TF)

    def get_initial_nj(self, **kw):
        """Return initial ``(n, j)``: the Thomas-Fermi density and zero current.

        Keyword arguments are forwarded to :meth:`get_n_TF`.
        """
        n0 = self.get_n_TF(**kw)
        return (n0, 0 * n0)

    def get_Vext(self, d=0):
        """Return the external potential (``d=0``) or its ``d``-th x derivative."""
        return self.experiment.get_Vext(state=self, d=d)

    def get_N(self):
        """Return the particle number, floored.

        Sums the density over every grid point (ghost points included) times dx.
        """
        return np.floor(np.sum(self.get_density()) * self.dx)

    def apply_boundary_condition(self):
        """Apply zero-gradient (copy-neighbour) boundary conditions in place.

        Each ghost point is set equal to its neighbouring interior point, and the
        density is made non-negative everywhere via ``abs``.  (A periodic variant
        would copy from the opposite interior edge instead.)
        """
        nj = self.get_nj()
        nj[:, 0] = nj[:, 1]
        nj[:, -1] = nj[:, -2]
        nj[0] = abs(nj[0])
        self.set_nj(nj)

    def get_density(self):
        """Return the density row ``n`` of the data."""
        return self.get_nj()[0]

    def get_current(self):
        """Return the current row ``j`` of the data."""
        return self.get_nj()[1]

    def get_dt(self):
        """Return a stable time step.

        The result is ``cfl`` times the smaller of the diffusive limit
        ``dx**2 / (2 nu)`` and the advective CFL limit ``dx / c_max``.  Here
        ``c_max = max(|u| + sqrt(f'(n)/m))`` and ``f'`` comes from
        ``experiment.get_f(..., d=1)``.
        """
        nu = self.experiment.get_nu(self)
        h, hu = self.get_nj()
        # Velocity, regulated to zero where the density is below ntol.
        vel = (h > self.ntol) * hu / (h + _TINY)
        cmax = (
            abs(vel) + np.sqrt(self.experiment.get_f(state=self, n=h, d=1) / self.m)
        ).max()
        mudiff = self.dx**2 / (2 * (_TINY + nu))  # Unit = T
        return self.cfl * min(mudiff, self.dx / cmax)

    def get_fluxes(self, h, hu, u):
        """Return the fluxes ``(hu, hu*u + f(h)/m)`` of the SWE-like system.

        Arguments:
            h: float
                height (density)
            hu: float
                height times the speed (current)
            u: float
                speed
        ``f`` is ``experiment.get_f(..., d=0)``.
        """
        return hu, hu * u + self.experiment.get_f(state=self, n=h, d=0) / self.m

    def get_dU_dt_diff(self):
        """Return the time derivative of ``(h, hu)`` for the diffusion operator.

        Only the current is diffused: ``d(hu)/dt = nu * d^2(hu)/dx^2``.  It uses
        a second-order central difference on interior points.  The density
        derivative and both ghost-point entries are zero.
        """
        nu = self.experiment.get_nu(self)
        hu = self.get_current()
        d2hu = nu * (hu[:-2] - 2 * hu[1:-1] + hu[2:]) / self.dx**2
        # Zero-pad one ghost point at each end.
        dhu_dt = np.pad(d2hu, 1)
        return np.array([np.zeros_like(dhu_dt), dhu_dt])

    def get_dU_dt_adv(self):
        """Return the advective time derivative ``(dh_dt, dhu_dt)``.

        Interface fluxes use the local Lax-Friedrichs (Rusanov) form
        ``F = (f_L + f_R)/2 + alpha (U_L - U_R)/2``.  Here ``alpha`` is the larger
        of the neighbouring ``|u| + sqrt(f'(h)/m)``.  The potential force
        ``-h V'(x) / m`` is added to the current equation.  Ghost-point entries
        are zero.

        The returned arrays are views into a workspace reused across calls,
        so copy them before calling this method again.
        """
        h, hu = h_hu = self.get_nj()
        vel = (h > self.ntol) * hu / (h + _TINY)
        f12 = np.array(self.get_fluxes(h, hu, vel))
        alpha = abs(vel) + np.sqrt(self.experiment.get_f(state=self, n=h, d=1) / self.m)
        alpha_cell = np.maximum(alpha[:-1], alpha[1:])
        # Interface flux between cells i and i+1.
        f_cell = (f12[:, :-1] + f12[:, 1:]) / 2
        delta_cell = (h_hu[:, :-1] - h_hu[:, 1:]) / 2
        F = f_cell + alpha_cell * delta_cell
        # Reuse a workspace to avoid reallocating every step; rebuild it if the
        # grid shape changed.
        work = getattr(self, "_dh_hu_dt", None)
        if work is None or work.shape != h_hu.shape:
            work = self._dh_hu_dt = np.zeros_like(h_hu)
        work[:, 1:-1] = (-1.0 / self.dx) * (F[:, 1:] - F[:, :-1])
        dh_dt, dhu_dt = work
        source = self.get_Vext(d=1)
        dhu_dt[1:-1] -= h[1:-1] * source[1:-1] / self.m
        return (dh_dt, dhu_dt)

    def compute_dy_dt(self, dy=None):
        """Return ``dy`` holding the time derivative for the current sub-step.

        Alternates operators on successive calls.  When ``self.operator < 0``
        it uses advection, otherwise diffusion, then flips the sign of
        ``self.operator``.
        """
        if dy is None:
            # TODO: self.empty() might be faster than a full copy.
            dy = self.copy()
        if self.operator < 0:
            dnj_dt = self.get_dU_dt_adv()
        else:
            dnj_dt = self.get_dU_dt_diff()
        self.operator *= -1
        dy.set_nj((dnj_dt[0], dnj_dt[1]))
        return dy

    @property
    def t_scale(self):
        """Characteristic time scale; must be provided by subclasses."""
        raise NotImplementedError

    def evolve_to(self, t):
        """Evolve the state to time ``t`` using first-order operator splitting.

        Each time step computes ``dt`` from :meth:`get_dt`.  It then applies two
        forward-Euler sub-steps through :meth:`compute_dy_dt`, which alternates
        operators.  Every step therefore gets exactly one advection and one
        diffusion update over the full ``dt``.

        Raises
        ------
        ValueError
            If :meth:`get_dt` returns a non-positive or NaN step, which would
            otherwise stall or silently corrupt the loop.
        """
        while self.t < t:  # Float comparison; the clamp below guarantees exit.
            dt = self.get_dt()
            if not dt > 0:
                raise ValueError(f"Invalid time step dt={dt}")
            # Clamp the last step and assign t exactly, so rounding errors
            # cannot leave self.t a hair below t and loop forever.
            if self.t + dt >= t:
                dt = t - self.t
                self.t = t
            else:
                self.t += dt
            # Two sub-steps: compute_dy_dt() toggles self.operator, so this is
            # one advection and one diffusion update, each over the full dt.
            for _ in range(2):
                dy = self.compute_dy_dt()
                self.set_nj(self.get_nj() + dt * dy.get_nj(), pos=False)
                self.apply_boundary_condition()

    def evolve(
        self,
        hist=False,
        t_final=None,
        steps=100,
        skip=10,
        show_plot=True,
        JT=False,
        fname=None,
        **kw,
    ):
        """Evolve SWE in time using Forward Euler.

        Arguments:
        ---------
        hist: Bool
            If true, return the list of states at each of the ``steps`` save
            points (initial state included); otherwise return a copy of the
            final state.
        t_final: float
            Final time of the evolution.  Defaults to ``self.t_final``; if that
            is not finite, ``experiment.t__final * experiment.t_unit`` is used.
        steps : int
            Number of intermediate steps to save history at.
        skip : int
            Plot every ``skip``-th step (values < 1 are treated as 1).
        show_plot : bool
            If true, plot while evolving; otherwise show a progress bar.
        JT, fname :
            Currently unused; kept for interface compatibility.
        **kw :
            Forwarded to :meth:`plot`.

        Raises
        ------
        ValueError
            If ``t_final`` is not later than the current time.
        """
        if t_final is None:
            t_final = self.t_final
        if t_final is None or not np.isfinite(t_final):
            # The constructor default is np.inf, which would make evolve_to() run
            # forever; fall back to the experiment's final time instead.
            t_final = self.experiment.t__final * self.experiment.t_unit
        dt = (t_final - self.t) / steps
        if not dt > 0:
            raise ValueError(
                f"t_final={t_final} must be later than the current time t={self.t}"
            )
        skip = max(int(skip), 1)
        history = [self.copy()] if hist else None
        fps = FPS(steps)
        frames = fps
        if not show_plot:
            from tqdm import tqdm

            frames = tqdm(fps)
        for step_, _frame in enumerate(frames):
            if not fps:
                break
            self.evolve_to(self.t + dt)
            if hist:
                history.append(self.copy())
            if show_plot and step_ % skip == 0:
                # plot() creates a fresh figure; close old ones so they don't pile up.
                plt.close("all")
                self.plot(**kw)
                display(plt.gcf())
                clear_output(wait=True)
        plt.close("all")
        if hist:
            return history
        return self.copy()

    def plot(self, ax=None):  # pragma: no cover
        """Plot the density and (on a twin axis) the external potential.

        Draws into ``ax`` if given, otherwise into a new figure.  The figure
        title shows the time in ms and the particle number.
        """
        n = self.get_density()
        x = self.x
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(1, 1, 1)
        else:
            # Title the figure that owns ax rather than a stray new figure.
            fig = ax.figure
        ax.plot(x, n)
        ax.set(ylabel="Density (n)", xlabel="x")
        ax_ = ax.twinx()
        ax_.plot(x, self.get_Vext(), "C2")
        ax_.set(ylabel="Vext")
        fig.suptitle(f"t={self.t/u.ms} ms, N={self.get_N()}")
class StateFV_BEC(StateFVBase):
    """Finite-volume hydrodynamic BEC state.

    Currently adds nothing to `StateFVBase`; it exists as a named entry point.
    """