---
execution:
  timeout: 1
jupytext:
  cell_metadata_json: true
  encoding: '# -*- coding: utf-8 -*-'
  formats: md:myst,ipynb
  notebook_metadata_filter: execution
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.16.4
kernelspec:
  display_name: Python 3 (ipykernel)
  language: python
  name: python3
---

(sec:WritingTests)=
Writing Tests
=============
Here we walk through the process of writing tests for existing code.  Specifically, we
test the code in {py:mod}`gpe.tube` that implements the {ref}`sec:drGPE` and
{ref}`sec:NPSEQ`.

In this case, we are testing a rather complete system of code that generates a 1D
approximation of a 3D simulation.  A reasonable strategy to start with is to simulate
both the full 3D system and its 1D approximation, and check that they agree.

Of course, for tests, we want this to run quickly, so we will push the limits of the
calculation, but first we just get something working.  The most general code allows for
dynamic rescaling of the trapping frequencies as well as dynamics in the $x$ direction,
so we start with a simple 3D simulation that models this.

```{code-cell}
import numpy as np, matplotlib.pyplot as plt
import gpe.utils, gpe.bec, gpe.minimize


class StateMixin(gpe.bec.StateHOMixin):
    def __init__(self, experiment, **kw):
        self.experiment = experiment
        super().__init__(**kw)

    def get_ws(self, t=None):
        """Return the trapping frequencies."""
        if t is None:
            t = self.t
        e = self.experiment
        ws = [
            w * (1 + e.dw * np.sin(dw * t))
            for w, dw in zip(e.ws, e.dws)
        ]
        return ws

    def _get_Vext_(self, gpu=True):
        Vext = super()._get_Vext_()
        xyz = self._xyz_ if gpu else self.xyz
        # Could add pokey here.
        return Vext


class State(StateMixin, gpe.bec.StateBase):
    pass


class Experiment(gpe.utils.ExperimentBase):
    hbar = m = 1
    ws = (2.0, 10.0, 15.0)  # Trapping frequencies
    dws = (10.0, -12.0, 14.0)  # Frequencies at which the trapping frequencies are modulated
    dw = 0.1  # Fractional amplitude of the trapping-frequency modulation.
    
    # These choices are not obvious: see below for how we selected them.
    Nxyz = (80, 45, 45)
    Lxyz = (7.0, 4.0, 4.0)
    mu = 15
    g = 1.0
    
    State = State
    
    def init(self):
        self.state_args = dict(
            experiment=self,
            hbar=self.hbar,
            m=self.m,
            Lxyz=self.Lxyz,
            Nxyz=self.Nxyz,
            mu=self.mu,
            g=self.g)
        
    def get_state(self):
        return self.State(**self.state_args)
    
    def get_initialized_state(self):
        s0 = self.get_state()
        s = gpe.minimize.MinimizeState(s0, fix_N=False).minimize(use_scipy=True)
        state = self.get_state()
        state.set_psi(s.get_psi())
        return state
```
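
To see what `get_ws` does in isolation, the following sketch evaluates the same
modulation $\omega_i(t) = \omega_i[1 + \delta\sin(\Omega_i t)]$ with plain NumPy, using
the values from `Experiment`.  The helper name `modulated_ws` is hypothetical and is not
part of `gpe`.

```python
import numpy as np

# Values copied by hand from the Experiment class above.
ws = np.array([2.0, 10.0, 15.0])     # base trapping frequencies
dws = np.array([10.0, -12.0, 14.0])  # modulation frequencies
dw = 0.1                             # fractional modulation amplitude


def modulated_ws(t):
    """Return w_i * (1 + dw * sin(dw_i * t)) for each direction (hypothetical helper)."""
    return ws * (1 + dw * np.sin(dws * t))


ts = np.linspace(0, 1, 1001)
w_t = np.array([modulated_ws(t) for t in ts])  # shape (len(ts), 3)

# Each frequency stays within a fraction dw of its base value.
ratio = w_t / ws
print(ratio.min(), ratio.max())
```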

:::{admonition} Setting Parameters

We need to choose reasonable parameters to make sure the tests work and make sense.  I
want the tube to be long (quasi-1D), so choose $\omega_x \ll \omega_{y,z}$.  I also want
to make sure that at least one radial mode is occupied in each direction:
\begin{gather*}
  \mu > \hbar \frac{\omega_x + \omega_y + \omega_z}{2}, \qquad
  V_{TF} = \mu - \hbar \frac{\omega_x + \omega_y + \omega_z}{2} > 0.
\end{gather*}
For UV convergence, the lattice spacing should be at most half the healing length:
\begin{gather*}
  2\d{x} < h = \frac{\hbar}{\sqrt{2m\mu}}.
\end{gather*}
Finally, the box must be large enough to hold the cloud: we can estimate
this with the TF radius in each direction, but need to include a margin for the decay.
If we want the trap ground state to decay by a factor of $\epsilon \approx 10^{-12}$,
then we need to make sure that $e^{-x^2/a^2} < \epsilon$, or $x > \sqrt{-\ln\epsilon}\,a
\approx 5a$, where $a = \sqrt{\hbar/m\omega}$ is the corresponding trap length.  We should
also include several healing lengths:
\begin{gather*}
  R_i = \omega_i^{-1}\sqrt{\frac{2V_{TF}}{m}}, \qquad
  L_{i} > 2R_i + 5h + 5a_i
\end{gather*}

Starting with a numerical guess for the frequencies, we arrive at the following estimates:
:::

```{code-cell}
m, hbar = Experiment.m, Experiment.hbar
ws = Experiment.ws
E0 = hbar * sum(ws) / 2
V_TF = 0.1*E0
mu = V_TF + E0
h = hbar / np.sqrt(2 * m * mu)
dx = h / 2.0
Rxyz = np.divide(np.sqrt(2 * V_TF / m), ws)
axyz = np.sqrt(np.divide(hbar/m, ws))
Lxyz = 2*np.maximum(Rxyz + 10*h, 5*axyz)
Nxyz = list(map(gpe.utils.get_good_N, Lxyz/dx))
print(f"{Lxyz=}, {Nxyz=}, {mu=}, {V_TF=}")
```
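
As a cross-check, the sketch below verifies that the values hard-coded in `Experiment`
(`Nxyz`, `Lxyz`, `mu`) satisfy the resolution criterion $2\,dx < h$, and evaluates the
decay margin $\sqrt{-\ln\epsilon}$ explicitly.  It uses only NumPy, with the parameters
re-entered by hand; nothing here calls into `gpe`.

```python
import numpy as np

# Parameters copied by hand from the Experiment class.
hbar = m = 1.0
Lxyz = np.array([7.0, 4.0, 4.0])
Nxyz = np.array([80, 45, 45])
mu = 15.0

h = hbar / np.sqrt(2 * m * mu)   # healing length
dx = Lxyz / Nxyz                 # lattice spacing in each direction
resolved = 2 * dx < h            # UV criterion: spacing at most half of h

eps = 1e-12
margin = np.sqrt(-np.log(eps))   # x/a at which exp(-x**2/a**2) equals eps

print(f"{h=:.4f}, {dx=}, {resolved=}, {margin=:.3f}")
```

With these values the margin factor is about 5.26, slightly larger than the rounded
value of 5 used in the estimate.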

```{code-cell}
e = Experiment()
s0 = e.get_state()
s0.plot()
%time s = e.get_initialized_state()
s.plot(log=True)
```

This works, and we see we have a good estimate of the box size, but the lattice is
pretty big, and this might take too long for an actual test.  One strategy is to run a
long simulation once, save the results, and test against the saved data.

```{code-cell}
from tqdm import tqdm
from pytimeode.evolvers import EvolverABM

T = 2*np.pi / min(np.abs(e.dws)) / 4
Nt = 20
dT = T/Nt
dt = 0.1*s.t_scale
steps = int(max(np.ceil(dT/dt), 2))
dt = dT / steps
states = []
ev = EvolverABM(s, dt=dt)
for frame in tqdm(range(Nt)):
    ev.evolve(steps)
    states.append(ev.get_y())
ev.y.plot()
```

```{code-cell}
np.save('test_4.npy', ev.y.get_psi())
```
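
A saved array like this can serve as reference data for a regression test.  The sketch
below shows the pattern with a random stand-in array in place of `ev.y.get_psi()` and a
temporary directory in place of the repository.  The function name
`check_against_reference` and the tolerance are illustrative assumptions, not part of
`gpe`.

```python
import os
import tempfile

import numpy as np


def check_against_reference(psi, filename, rtol=1e-8):
    """Compare psi with the array stored in filename (hypothetical helper)."""
    psi_ref = np.load(filename)
    return psi.shape == psi_ref.shape and np.allclose(psi, psi_ref, rtol=rtol)


# Stand-in for the wavefunction produced by an expensive simulation.
rng = np.random.default_rng(seed=1)
psi = rng.normal(size=(8, 4, 4)) + 1j * rng.normal(size=(8, 4, 4))

with tempfile.TemporaryDirectory() as tmpdir:
    filename = os.path.join(tmpdir, "reference.npy")
    np.save(filename, psi)                                    # done once, offline
    same = check_against_reference(psi, filename)             # unchanged result
    changed = check_against_reference(1.01 * psi, filename)   # 1% perturbation

print(same, changed)
```

In an actual test suite the reference file would be committed alongside the tests, and
the tolerance chosen to absorb platform-dependent round-off.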

```{code-cell}
fig, ax = plt.subplots()
x = s.xyz[0].ravel()
t = [s.t for s in states]
n = np.array([s.get_density_x() for s in states])
mesh = ax.pcolormesh(t, x, (n - n[0]).T)
fig.colorbar(mesh, ax=ax)
```

## Tube Code

```{code-cell}
%load_ext autoreload
%autoreload 2
import gpe.tube
class StateTube(StateMixin, gpe.tube.StateGPEdrZ):
    pass

et = Experiment(State=StateTube)
st0 = et.get_state()
st0.plot()
%time st = et.get_initialized_state()
st.plot()
```

The first thing we should test is that the particle numbers are similar:

```{code-cell}
print(s.get_N()/st.get_N() - 1)
assert np.allclose(s.get_N(), st.get_N(), rtol=0.008)
```

The second test is that the integrated densities match:

```{code-cell}
fig, ax = plt.subplots()
x = st.xyz[0]
ax.plot(x, s.get_density_x(), "-", label='3D')
ax.plot(x, st.get_density_x(), ":", label='Tube')
ax.set(xlabel="x", ylabel="$n_{1D}$")
ax.legend()
assert np.allclose(s.get_density_x(), st.get_density_x(), atol=2e-4, rtol=0.008)
```
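
Both tolerances are needed here because the density decays to essentially zero in the
tails, where a purely relative comparison is meaningless.  `np.allclose(a, b)` checks
`|a - b| <= atol + rtol * |b|` elementwise.  The sketch below illustrates the effect with
synthetic Gaussian profiles (these are not output from `gpe`).

```python
import numpy as np

x = np.linspace(-5, 5, 101)
n_ref = np.exp(-x**2)             # reference profile, ~1e-11 in the tails
n_approx = 1.005 * n_ref + 1e-4   # 0.5% relative error plus a small offset

# Relative tolerance alone fails: the offset dominates where n_ref ~ 0.
rel_only = np.allclose(n_approx, n_ref, rtol=0.008)

# Adding an absolute tolerance absorbs the offset in the tails.
rel_and_abs = np.allclose(n_approx, n_ref, rtol=0.008, atol=2e-4)

print(rel_only, rel_and_abs)
```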

The third test is the central densities.  Note: this is an approximation unless we solve
exactly for the radial wavefunction.  One must choose whether the result is closer to the
TF approximation (large number of radial modes) or the HO approximation.  In this case, we
have only a single mode occupied, so the HO approximation is much better:

```{code-cell}
Nx, Ny, Nz = s.basis.Nxyz
x = s.xyz[0].ravel()
fig, ax = plt.subplots()
ax.plot(x, s.get_density()[:, Ny//2, Nz//2], '-', label="3D")
ax.plot(x, st.get_central_density(TF=True), '--', label="TF")
ax.plot(x, st.get_central_density(TF=False), ':', label="HO")
ax.set(xlabel="x", ylabel="central density")
ax.legend()
```
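
For orientation, the two limits can be written in closed form under textbook
assumptions: a Gaussian radial profile $n_0 e^{-y^2/a_y^2 - z^2/a_z^2}$ for HO, and an
inverted parabola $n_0(1 - y^2/R_y^2 - z^2/R_z^2)$ with $g n_0 = m\omega_i^2 R_i^2/2$ for
TF.  Whether `get_central_density` uses exactly these expressions is an assumption; the
sketch below only checks that each profile integrates back to the line density $n_{1D}$.
The parameter values are illustrative.

```python
import numpy as np

m = g = hbar = 1.0
w_y, w_z = 10.0, 15.0
n_1D = 2.0  # illustrative line density

# HO limit: integrating the Gaussian over y and z gives n0 * pi * a_y * a_z.
a_y, a_z = np.sqrt(hbar / (m * w_y)), np.sqrt(hbar / (m * w_z))
n0_HO = n_1D / (np.pi * a_y * a_z)

# TF limit: integrating the parabola gives n0 * pi * R_y * R_z / 2.
n0_TF = np.sqrt(m * w_y * w_z * n_1D / (np.pi * g))
R_y, R_z = np.sqrt(2 * g * n0_TF / m) / np.array([w_y, w_z])

# Numerical check on a fine grid in the (y, z) plane.
y = z = np.linspace(-1, 1, 1601)
dA = (y[1] - y[0]) * (z[1] - z[0])
Y, Z = np.meshgrid(y, z, indexing="ij")
n_HO = n0_HO * np.exp(-(Y / a_y) ** 2 - (Z / a_z) ** 2)
n_TF = n0_TF * np.maximum(0, 1 - (Y / R_y) ** 2 - (Z / R_z) ** 2)

# Both integrals should be close to n_1D.
print(n_HO.sum() * dA, n_TF.sum() * dA)
```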

```{code-cell}
statest = []
ev = EvolverABM(st, dt=dt)
for frame in tqdm(range(Nt)):
    ev.evolve(steps)
    statest.append(ev.get_y())
```

```{code-cell}
fig, axs = plt.subplots(2, 1)
x = s.xyz[0].ravel()
t = [s.t for s in states]
n = np.array([s.get_density_x() for s in states])
mesh = axs[0].pcolormesh(t, x, (n - n[0]).T)
fig.colorbar(mesh, ax=axs[0])

x = st.xyz[0].ravel()
t = [st.t for st in statest]
n = np.array([st.get_density_x() for st in statest])
mesh = axs[1].pcolormesh(t, x, (n - n[0]).T)
fig.colorbar(mesh, ax=axs[1])
```

```{code-cell}
from scipy.interpolate import InterpolatedUnivariateSpline

fig, axs = plt.subplots(2, 1)

x = s.xyz[0].ravel()
t = [s.t for s in states]
n = np.array([s.get_density_x() for s in states])
mesh = axs[0].pcolormesh(t, x, (n - n[0]).T)
fig.colorbar(mesh, ax=axs[0])

X = []
t = [st.t for st in statest]
st = statest[0]
x0 = st.xyz[0]
n0 = st.get_density_x()
get_n0 = InterpolatedUnivariateSpline(x0, n0)
n_n0 = []
for st in statest:
    x = st.xyz[0]
    n_n0.append(st.get_density_x() - get_n0(x))
    X.append(x)

n_n0 = np.array(n_n0)
T = np.array([t for x in st.xyz[0]])
X = np.transpose(X)
mesh = axs[1].pcolormesh(T, X, n_n0.T, shading='auto')
fig.colorbar(mesh, ax=axs[1])
```

```{code-cell}
st = statest[-1]
st.get_density() - st.get_density_x()
```

```{code-cell}
N0 = [np.trapz(st.get_density_x(), st.xyz[0]) for st in statest]
N1 = [np.trapz(st.get_density_x(), x0) for st in statest]
N0, N1
```

### Strange Issue

```{code-cell}
%load_ext autoreload
%autoreload 2
import gpe.tube
class StateTube(StateMixin, gpe.tube.StateGPEdrZ):
    pass

et = Experiment(State=StateTube)
st = et.get_state()
```

```{code-cell}
ts = np.linspace(0, 3, 200)
st._tsqs = []
lams = np.array([st.get_lambdas(t)[0] for t in ts])
ws = np.array([st.get_ws(t) for t in ts])
ts_, lams_ = ts, lams
plt.plot(ts_, lams_, ':')
```

## Previous Issues

+++

While developing the tests in this notebook, we discovered several issues in the code
(see `Notes.md`).  Here are some snippets of code that demonstrate these issues; they will
be included as unit tests.

```{code-cell}
import numpy as np, matplotlib.pyplot as plt
import gpe.utils, gpe.bec, gpe.minimize


class State(gpe.bec.StateHOMixin, gpe.bec.StateBase):
    def get_ws(self, t=None):
        return (2.0, 10.0, 15.0)

    def _get_Vext_(self, gpu=True):
        Vext = super()._get_Vext_()
        xyz = self._xyz_ if gpu else self.xyz
        x = xyz[0]
        
        # Previously this was true
        Vext.flags['WRITEABLE'] = True
        
        Vext += x**2  # This is dangerous because of mutation...  Now should raise exception.
        return Vext


s = State()
assert np.allclose(s.get_energy(), s.get_energy())
```
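
The danger with `Vext += x**2` is that it modifies whatever array
`super()._get_Vext_()` returned; if that array is cached, the potential changes on every
call.  One common safeguard, sketched below with a stand-in array, is to mark the cached
array read-only so that in-place updates raise an exception.  Whether `gpe` uses exactly
this mechanism is an assumption, and `cached_V` is a hypothetical name.

```python
import numpy as np

cached_V = np.zeros(4)
cached_V.flags.writeable = False   # protect the cached array

try:
    cached_V += 1.0                # in-place update of a read-only array
    raised = False
except ValueError:
    raised = True

V = cached_V + 1.0                 # out-of-place: allocates a new, writeable array
print(raised, V.flags.writeable, cached_V)
```

Writing `Vext = Vext + x**2` in an overridden method follows the second pattern and
leaves the cached array untouched.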

