Source code for gpe.Examples.tutorial

"""Example code from the tutorial."""

import numpy as np
import matplotlib.pyplot as plt

from gpe.bec import StateGPEBase, HOMixin
from gpe.minimize import MinimizeState


class StateHOConvergence1(HOMixin, StateGPEBase):
    """State with tools to check convergence for a 1D HO GPE.

    This is a simple state with default parameters forming a trapped gas in a
    1D harmonic oscillator.  It has some functions for checking and plotting
    convergence.

    The state is specified by the parameter `dmu = mu / (hbar*w/2) - 1`.  If
    `dmu = 0`, we will have no interactions (but don't do this: use a small,
    non-zero value), while larger values will approach the TF limit.

    Any of the class-level parameters below can be overridden per instance by
    passing it as a keyword argument to the constructor.
    """

    # Forwarded to the base class as `hbar` and `m`.
    hbar = m = 1.0
    # Forwarded to the base class as `Nxyz=(Nx,)`.
    Nx = 64
    # Forwarded to the base class as `Lxyz=(Lx,)`.
    Lx = 14.0
    # Sets the chemical potential: `mu = hbar*w/2 * (1 + dmu)`.
    dmu = 0.74
    # Forwarded to the base class as `ws=(w,)`.
    w = 1.0
    # Forwarded to the base class as `g`.
    g = 1.0

    def __init__(self, **kw):
        """Initialize the state.

        Arguments
        ---------
        **kw : dict
            Keys naming an attribute defined directly on
            `StateHOConvergence1` are set on the instance and removed from
            `kw`.  Everything else is forwarded to the base class, overriding
            the defaults computed here (`hbar`, `m`, `g`, `mu`, `Lxyz`,
            `Nxyz`, `ws`).
        """
        # Iterate over a copy of the keys since we pop from `kw` in the loop.
        for key in list(kw):
            if key in StateHOConvergence1.__dict__:
                setattr(self, key, kw.pop(key))

        mu = self.hbar * self.w / 2 * (1 + self.dmu)
        defaults = dict(
            hbar=self.hbar,
            m=self.m,
            g=self.g,
            mu=mu,
            Lxyz=(self.Lx,),
            Nxyz=(self.Nx,),
            ws=(self.w,),
        )
        defaults.update(kw)
        super().__init__(**defaults)

    def init(self):
        """Delegate to the base-class `init()`."""
        super().init()

    def get_convergence(self, full_output=False):
        """Return `(ir, uv)` convergence factors.

        These are the ratios of the density at the edge vs maximum density in
        position (IR) and momentum (UV) space.

        Arguments
        ---------
        full_output : bool
            If `True`, return `(ir, uv, x, nx, k, nk)`.
        """
        psi = self.get_psi()
        nx = abs(psi) ** 2
        # fftshift so that the first/last entries are the largest |k|.
        nk = np.fft.fftshift(abs(self.basis.fftn(psi)) ** 2)
        ir = max(nx[0], nx[-1]) / nx.max()
        uv = max(nk[0], nk[-1]) / nk.max()
        if not full_output:
            return (ir, uv)
        x = self.get_xyz()[0].ravel()
        k = np.fft.fftshift(self.basis.kx.ravel())
        return (ir, uv, x, nx, k, nk)

    def plot_convergence(self, ax=None, legend=True, **kw):
        """Plot the convergence in axis `ax`.

        The position (IR) and momentum (UV) densities are drawn on a
        logarithmic scale, each normalized by its maximum, against the
        abscissa scaled by its maximum.

        Arguments
        ---------
        ax : Axes
            Plot convergence in ax if provided, otherwise create a new figure.
        legend : bool
            If `True`, include the legend.
        **kw : dict
            All other arguments are passed to `ax.semilogy()`.

        Returns
        -------
        ax : Axes
        """
        if ax is None:
            fig, ax = plt.subplots()
        ir, uv, x, nx, k, nk = self.get_convergence(full_output=True)
        ax.semilogy(x / abs(x.max()), nx / nx.max(), label="IR", **kw)
        ax.semilogy(k / abs(k.max()), nk / nk.max(), label="UV", **kw)
        if legend:
            ax.legend()
        return ax

    def get_gtilde(self):
        """Return the dimensionless interaction parameter."""
        return self.g * self.get_N() * np.sqrt(self.m / self.hbar**3 / self.w)

    def plot(self, axs=None, plot_convergence=False, **kw):
        """Plot the density of the state in `axs[0]`.

        Arguments
        ---------
        axs : [Axes]
            Sequence of axes.  Plot in `axs[0]` if provided, otherwise create
            a new figure.  The convergence is plotted in `axs[1]` if provided
            and `plot_convergence` is `True`; if there is no `axs[1]`, an
            inset of `axs[0]` is created for it.  The sequence passed in is
            never modified.
        plot_convergence : bool
            If `True`, include the convergence plot.
        **kw : dict
            All other arguments are passed to `ax.plot()`.

        Returns
        -------
        axs : [Axes]
            The axes used.  This is a new list (with the inset appended) when
            an inset had to be created, otherwise `axs` itself.
        """
        if axs is None:
            fig, ax = plt.subplots(figsize=(4, 3))
            axs = [ax]
        else:
            ax = axs[0]

        x = self.get_xyz()[0].ravel()
        n = self.get_density()
        ax.plot(x, n, **kw)

        # The local names N, E, t, ir, uv appear verbatim in the title through
        # the f"{name=...}" specifiers below: do not rename them.
        N, E, t = self.get_N(), self.get_energy(), self.t
        g_ = self.get_gtilde()
        with np.errstate(all="ignore"):
            # log10 of a zero ratio gives -inf; silence the warning.
            ir, uv = map(np.log10, self.get_convergence())

        title = []
        if t != 0:
            title.append(f"{t=:.4f}")
        title.append(rf"$\tilde{{g}}$={g_:0.5g}")
        title.extend([f"{N=:.2g}", f"{E=:.2g}", f"{ir=:.2g}", f"{uv=:.2g}"])
        ax.set(title=", ".join(title))

        if plot_convergence:
            if len(axs) > 1:
                axc = axs[1]
            else:
                axc = ax.inset_axes((0, 0.75, 0.25, 0.25))
                # Build a new list rather than appending in place so that the
                # caller's sequence (possibly a tuple or array) is untouched.
                axs = [*axs, axc]
                # NOTE(review): indentation was lost in the extracted source;
                # tick_right() is assumed to apply to the inset only - confirm.
                axc.yaxis.tick_right()
            self.plot_convergence(ax=axc, legend=False)
        return axs

    def get_initialized_state(self, fix_N=False, minimize_kw=None):
        """Return a minimized state.

        Arguments
        ---------
        fix_N : bool
            Passed to `MinimizeState`.
        minimize_kw : dict, optional
            Keyword arguments passed to `MinimizeState.minimize()`.

        Returns
        -------
        s : state
            Copy of `self` whose wavefunction is that of the minimized state.
        """
        if minimize_kw is None:
            minimize_kw = {}

        s0 = self.copy()
        s0.set_initial_state()
        minimizer = MinimizeState(s0, fix_N=fix_N)
        minimizer.check()
        s1 = minimizer.minimize(**minimize_kw)

        s = self.copy()
        s.set_psi(s1.get_psi())
        s.pre_evolve_hook()  # Finish initializing
        return s
class StateHOConvergence2(StateHOConvergence1):
    """State with tools to check convergence for a 2D HO GPE.

    This is a simple state with default parameters forming a trapped gas in a
    2D harmonic oscillator.  It has some functions for checking and plotting
    convergence.

    The state is specified by the parameter `dmu = mu / (hbar*w) - 1`.  If
    `dmu = 0`, we will have no interactions (but don't do this: use a small,
    non-zero value), while larger values will approach the TF limit.

    Any of the class-level parameters below can be overridden per instance by
    passing it as a keyword argument to the constructor.
    """

    # These repeat the parent's values on purpose: `__init__` only recognizes
    # keyword overrides for names found in `StateHOConvergence2.__dict__`.
    # Forwarded to the base class as `hbar` and `m`.
    hbar = m = 1.0
    # Forwarded to the base class as `Nxyz=(Nx, Nx)`.
    Nx = 64
    # Forwarded to the base class as `Lxyz=(Lx, Lx)`.
    Lx = 14.0
    # Sets the chemical potential: `mu = hbar*w * (1 + dmu)`.
    dmu = 0.74
    # Forwarded to the base class as `ws=(w, w)`.
    w = 1.0
    # Forwarded to the base class as `g`.
    g = 1.0

    def __init__(self, **kw):
        """Initialize the state.

        Arguments
        ---------
        **kw : dict
            Keys naming an attribute defined directly on
            `StateHOConvergence2` are set on the instance and removed from
            `kw`.  Everything else is forwarded to the parent constructor,
            overriding the defaults computed here (`hbar`, `m`, `g`, `mu`,
            `Lxyz`, `Nxyz`, `ws`), as in the 1D class.
        """
        # Iterate over a copy of the keys since we pop from `kw` in the loop.
        for key in list(kw):
            if key in StateHOConvergence2.__dict__:
                setattr(self, key, kw.pop(key))

        mu = self.hbar * self.w * (1 + self.dmu)
        defaults = dict(
            hbar=self.hbar,
            m=self.m,
            g=self.g,
            mu=mu,
            Lxyz=(self.Lx, self.Lx),
            Nxyz=(self.Nx, self.Nx),
            ws=(self.w, self.w),
        )
        # Explicit keyword arguments win over the computed defaults.
        defaults.update(kw)
        super().__init__(**defaults)

    def get_convergence(self, full_output=False):
        """Return `(ir, uv)` convergence factors.

        These are the ratios of the density at the edge vs maximum density in
        position (IR) and momentum (UV) space.

        Arguments
        ---------
        full_output : bool
            If `True`, return `(ir, uv, r, nr, k, nk)` where the arrays are
            flattened and sorted by increasing `r` (resp. `k`).
        """
        psi = self.get_psi()
        nr = abs(psi) ** 2
        nk = np.fft.fftshift(abs(self.basis.fftn(psi)) ** 2)
        # Largest density on any of the four edges relative to the maximum.
        ir = max(nr[[0, -1], :].max(), nr[:, [0, -1]].max()) / nr.max()
        uv = max(nk[[0, -1], :].max(), nk[:, [0, -1]].max()) / nk.max()
        if not full_output:
            return (ir, uv)

        r = np.sqrt(sum(_x**2 for _x in self.get_xyz())).ravel()
        k = np.fft.fftshift(
            np.sqrt(sum(_k**2 for _k in self.basis._pxyz))
        ).ravel()
        inds_r = np.argsort(r)
        inds_k = np.argsort(k)
        return (
            ir,
            uv,
            r[inds_r],
            nr.ravel()[inds_r],
            k[inds_k],
            nk.ravel()[inds_k],
        )

    def plot_convergence(self, ax=None, legend=True, **kw):
        """Plot the convergence in axis `ax`.

        Arguments
        ---------
        ax : Axes
            Plot convergence in ax if provided, otherwise create a new figure.
        legend : bool
            If `True`, include the legend.
        **kw : dict
            All other arguments are passed to `ax.semilogy()`.

        Returns
        -------
        ax : Axes
        """
        # The parent implementation only relies on `get_convergence()`, which
        # this class overrides, so no separate 2D code is needed here.
        return super().plot_convergence(ax=ax, legend=legend, **kw)

    def plot(self, axs=None, plot_convergence=False, **kw):
        """Plot the density of the state against radius in `axs[0]`.

        Arguments
        ---------
        axs : [Axes]
            Sequence of axes.  Plot in `axs[0]` if provided, otherwise create
            a new figure.  The convergence is plotted in `axs[1]` if provided
            and `plot_convergence` is `True`; if there is no `axs[1]`, an
            inset of `axs[0]` is created for it.  The sequence passed in is
            never modified.
        plot_convergence : bool
            If `True`, include the convergence plot.
        **kw : dict
            All other arguments are passed to `ax.plot()`.

        Returns
        -------
        axs : [Axes]
            The axes used.  This is a new list (with the inset appended) when
            an inset had to be created, otherwise `axs` itself.
        """
        if axs is None:
            fig, ax = plt.subplots(figsize=(4, 3))
            axs = [ax]
        else:
            ax = axs[0]

        # Flatten the grid and order the points by distance from the origin.
        r = np.sqrt(sum(_x**2 for _x in self.get_xyz())).ravel()
        n = self.get_density().ravel()
        inds = np.argsort(r)
        r, n = r[inds], n[inds]
        ax.plot(r, n, **kw)

        # The local names N, E, t, ir, uv appear verbatim in the title through
        # the f"{name=...}" specifiers below: do not rename them.
        N, E, t = self.get_N(), self.get_energy(), self.t
        g_ = self.get_gtilde()
        with np.errstate(all="ignore"):
            # log10 of a zero ratio gives -inf; silence the warning.
            ir, uv = map(np.log10, self.get_convergence())

        title = []
        if t != 0:
            title.append(f"{t=:.4f}")
        title.append(rf"$\tilde{{g}}$={g_:0.5g}")
        title.extend([f"{N=:.2g}", f"{E=:.2g}", f"{ir=:.2g}", f"{uv=:.2g}"])
        ax.set(title=", ".join(title))

        if plot_convergence:
            if len(axs) > 1:
                axc = axs[1]
            else:
                axc = ax.inset_axes((0.75, 0.75, 0.25, 0.25))
                # Build a new list rather than appending in place so that the
                # caller's sequence (possibly a tuple or array) is untouched.
                axs = [*axs, axc]
                # NOTE(review): indentation was lost in the extracted source;
                # tick_right() is assumed to apply to the inset only - confirm.
                axc.yaxis.tick_right()
            self.plot_convergence(ax=axc, legend=False)
        return axs