"""Example code from the tutorial."""
import numpy as np
import matplotlib.pyplot as plt
from gpe.bec import StateGPEBase, HOMixin
from gpe.minimize import MinimizeState
class StateHOConvergence1(HOMixin, StateGPEBase):
    """State with tools to check convergence for a 1D HO GPE.

    This is a simple state with default parameters forming a trapped gas in a 1D
    harmonic oscillator. It has some functions for checking and plotting convergence.

    The state is specified by the parameter `dmu = mu / (hbar*w/2) - 1`. If `dmu = 0`,
    we will have no interactions (but don't do this: use a small, non-zero value), while
    larger values will approach the TF limit.
    """

    hbar = m = 1.0

    # NOTE(review): `__init__` also reads `self.w`, `self.dmu`, `self.g`, `self.Lx`
    # and `self.Nx`, none of which are defined on this class in the visible source.
    # Presumably they are class-level defaults (or come from a base class) -- confirm.
    # Keyword overrides are only applied for names in this class's own `__dict__`.

    def __init__(self, **kw):
        """Construct the state, taking parameter overrides from `kw`.

        Keywords naming attributes defined on this class are set on the instance
        (and removed from `kw`); the remaining keywords are forwarded to the base
        constructor, where they take precedence over the defaults computed here.
        """
        for key in list(kw):
            if key in StateHOConvergence1.__dict__:
                setattr(self, key, kw.pop(key))

        # `dmu` is measured relative to hbar*w/2 (see class docstring).
        mu = self.hbar * self.w / 2 * (1 + self.dmu)
        defaults = dict(
            hbar=self.hbar,
            m=self.m,
            g=self.g,
            mu=mu,
            Lxyz=(self.Lx,),
            Nxyz=(self.Nx,),
            ws=(self.w,),
        )
        defaults.update(kw)  # Explicit keywords win over the computed defaults.
        super().__init__(**defaults)

    def init(self):
        """Initialize the state; currently just delegates to the base class."""
        super().init()

    def get_convergence(self, full_output=False):
        """Return `(ir, uv)` convergence factors.

        These are the ratios of the density at the edge vs maximum density in position
        (IR) and momentum (UV) space.

        Arguments
        ---------
        full_output : bool
            If `True`, return `(ir, uv, x, nx, k, nk)`.
        """
        psi = self.get_psi()
        nx = abs(psi) ** 2
        # Shift so that the momentum-space density is ordered from -k_max to k_max;
        # the edges nk[0], nk[-1] are then the largest |k|.
        nk = np.fft.fftshift(abs(self.basis.fftn(psi)) ** 2)
        ir = max(nx[0], nx[-1]) / nx.max()
        uv = max(nk[0], nk[-1]) / nk.max()
        if full_output:
            x = self.get_xyz()[0].ravel()
            k = np.fft.fftshift(self.basis.kx.ravel())
            return (ir, uv, x, nx, k, nk)
        return (ir, uv)

    def plot_convergence(self, ax=None, legend=True, **kw):
        """Plot the normalized position and momentum densities on a log scale.

        Arguments
        ---------
        ax : Axes
            Plot convergence in ax if provided; otherwise a new figure is created.
        legend : bool
            If `True`, include the legend.
        **kw : dict
            All other arguments are passed to `semilogy()`.

        Returns
        -------
        ax : Axes
        """
        if ax is None:
            fig, ax = plt.subplots()
        ir, uv, x, nx, k, nk = self.get_convergence(full_output=True)
        ax.semilogy(x / abs(x.max()), nx / nx.max(), label="IR", **kw)
        ax.semilogy(k / abs(k.max()), nk / nk.max(), label="UV", **kw)
        if legend:
            ax.legend()
        return ax

    def get_gtilde(self):
        """Return the dimensionless interaction parameter `g*N*sqrt(m/hbar**3/w)`."""
        return self.g * self.get_N() * np.sqrt(self.m / self.hbar**3 / self.w)

    def plot(self, axs=None, plot_convergence=False, **kw):
        """Plot the density of the state in axis `axs[0]`.

        Arguments
        ---------
        axs : [Axes]
            List of axes. Plot in axs[0] if provided. Plot the convergence in
            `axs[1]` if provided and `plot_convergence` is `True`; otherwise an
            inset axis is created and appended to `axs`.
        plot_convergence : bool
            If `True`, include the convergence plot.
        **kw : dict
            All other arguments are passed to `plot()`.

        Returns
        -------
        axs : [Axes]
            The list of axes used (including any inset that was added).
        """
        if axs is None:
            fig, ax = plt.subplots(figsize=(4, 3))
            axs = [ax]
        else:
            ax = axs[0]
        x = self.get_xyz()[0].ravel()
        n = self.get_density()
        ax.plot(x, n, **kw)

        N, E, t = self.get_N(), self.get_energy(), self.t
        g_ = self.get_gtilde()
        # log10 of a zero ratio would warn; the resulting -inf is acceptable here.
        with np.errstate(all="ignore"):
            ir, uv = map(np.log10, self.get_convergence())
        title = []
        if t != 0:
            title.append(f"{t=:.4f}")
        title.append(rf"$\tilde{{g}}$={g_:0.5g}")
        title.extend([f"{N=:.2g}", f"{E=:.2g}", f"{ir=:.2g}", f"{uv=:.2g}"])
        ax.set(title=", ".join(title))

        if plot_convergence:
            if len(axs) > 1:
                axc = axs[1]
            else:
                axc = ax.inset_axes((0, 0.75, 0.25, 0.25))
                axs.append(axc)
            axc.yaxis.tick_right()
            self.plot_convergence(ax=axc, legend=False)
        return axs

    def get_initialized_state(self, fix_N=False, minimize_kw=None):
        """Return a copy of this state whose wavefunction has been minimized.

        Arguments
        ---------
        fix_N : bool
            Passed to `MinimizeState`.
        minimize_kw : dict, optional
            Keyword arguments forwarded to `MinimizeState.minimize()`.
        """
        if minimize_kw is None:  # Avoid a shared mutable default argument.
            minimize_kw = {}
        s0 = self.copy()
        s0.set_initial_state()
        minimizer = MinimizeState(s0, fix_N=fix_N)
        minimizer.check()
        s1 = minimizer.minimize(**minimize_kw)
        s = self.copy()
        s.set_psi(s1.get_psi())
        s.pre_evolve_hook()  # Finish initializing
        return s
class StateHOConvergence2(StateHOConvergence1):
    """State with tools to check convergence for a 2D HO GPE.

    This is a simple state with default parameters forming a trapped gas in a 2D
    harmonic oscillator. It has some functions for checking and plotting convergence.

    The state is specified by the parameter `dmu = mu / (hbar*w) - 1` (in 2D the
    non-interacting ground-state energy is `hbar*w`). If `dmu = 0`, we will have no
    interactions (but don't do this: use a small, non-zero value), while larger values
    will approach the TF limit.

    `plot_convergence()` is inherited unchanged from `StateHOConvergence1`.
    """

    hbar = m = 1.0

    # NOTE(review): `get_gtilde()` is inherited from the 1D class and computes
    # `g*N*sqrt(m/hbar**3/w)`. For a 2D coupling the dimensionless combination is
    # presumably `g*N*m/hbar**2`; the two agree only when `w == 1`. Confirm which
    # is intended before relying on the plotted value.

    def __init__(self, **kw):
        """Construct the 2D state, taking parameter overrides from `kw`.

        Keywords naming attributes defined on this class or on
        `StateHOConvergence1` are set on the instance *before* `mu` is computed, so
        overrides such as `dmu` are honoured. The remaining keywords take
        precedence over the 2D defaults computed here (as in the 1D class).
        """
        params = set(StateHOConvergence1.__dict__) | set(StateHOConvergence2.__dict__)
        for key in list(kw):
            if key in params:
                setattr(self, key, kw.pop(key))

        mu = self.hbar * self.w * (1 + self.dmu)
        defaults = dict(
            hbar=self.hbar,
            m=self.m,
            g=self.g,
            mu=mu,
            Lxyz=(self.Lx, self.Lx),
            Nxyz=(self.Nx, self.Nx),
            ws=(self.w, self.w),
        )
        # Explicit keywords win; these 2D values in turn override the 1D defaults
        # that the parent constructor computes.
        defaults.update(kw)
        super().__init__(**defaults)

    def get_convergence(self, full_output=False):
        """Return `(ir, uv)` convergence factors.

        These are the ratios of the density at the edge vs maximum density in position
        (IR) and momentum (UV) space.

        Arguments
        ---------
        full_output : bool
            If `True`, return `(ir, uv, r, nr, k, nk)`, with the densities flattened
            and sorted by radius `r` and wavenumber magnitude `k` respectively.
        """
        psi = self.get_psi()
        nr = abs(psi) ** 2
        nk = np.fft.fftshift(abs(self.basis.fftn(psi)) ** 2)
        # Largest value on any of the four edges of the 2D grid.
        ir = max(nr[[0, -1], :].max(), nr[:, [0, -1]].max()) / nr.max()
        uv = max(nk[[0, -1], :].max(), nk[:, [0, -1]].max()) / nk.max()
        if full_output:
            r = np.sqrt(sum(x**2 for x in self.get_xyz())).ravel()
            k = np.fft.fftshift(np.sqrt(sum(kk**2 for kk in self.basis._pxyz))).ravel()
            inds_r = np.argsort(r)
            inds_k = np.argsort(k)
            return (ir, uv, r[inds_r], nr.ravel()[inds_r], k[inds_k], nk.ravel()[inds_k])
        return (ir, uv)

    def plot(self, axs=None, plot_convergence=False, **kw):
        """Plot the density of the state against radius in axis `axs[0]`.

        Arguments
        ---------
        axs : [Axes]
            List of axes. Plot in axs[0] if provided. Plot the convergence in
            `axs[1]` if provided and `plot_convergence` is `True`; otherwise an
            inset axis is created and appended to `axs`.
        plot_convergence : bool
            If `True`, include the convergence plot.
        **kw : dict
            All other arguments are passed to `plot()`.

        Returns
        -------
        axs : [Axes]
            The list of axes used (including any inset that was added).
        """
        if axs is None:
            fig, ax = plt.subplots(figsize=(4, 3))
            axs = [ax]
        else:
            ax = axs[0]
        r = np.sqrt(sum(x**2 for x in self.get_xyz())).ravel()
        n = self.get_density().ravel()
        inds = np.argsort(r)
        r, n = r[inds], n[inds]
        ax.plot(r, n, **kw)

        N, E, t = self.get_N(), self.get_energy(), self.t
        g_ = self.get_gtilde()
        # log10 of a zero ratio would warn; the resulting -inf is acceptable here.
        with np.errstate(all="ignore"):
            ir, uv = map(np.log10, self.get_convergence())
        title = []
        if t != 0:
            title.append(f"{t=:.4f}")
        title.append(rf"$\tilde{{g}}$={g_:0.5g}")
        title.extend([f"{N=:.2g}", f"{E=:.2g}", f"{ir=:.2g}", f"{uv=:.2g}"])
        ax.set(title=", ".join(title))

        if plot_convergence:
            if len(axs) > 1:
                axc = axs[1]
            else:
                axc = ax.inset_axes((0.75, 0.75, 0.25, 0.25))
                axs.append(axc)
            axc.yaxis.tick_right()
            self.plot_convergence(ax=axc, legend=False)
        return axs