---
jupytext:
  formats: md:myst,ipynb
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.16.7
kernelspec:
  display_name: Python 3 (ipykernel)
  language: python
  name: python3
---

```{code-cell} ipython3
:init_cell: true

import mmf_setup;mmf_setup.nbinit()
from gpe.imports import *
```

# 2D Turbulence

Here we attempt to reproduce the results of {cite}`Zhao:2025`, who stir a 2D BEC and claim
to measure a Kolmogorov decay spectrum. Their key parameters are:

* $N = 2\times 10^{5}$ atoms.
* $R = r_{\mathrm{tr}} = \qty{22}{micron}$

```{code-cell} ipython3
from gpe.imports import *
from gpe.bec import StateBase, u
from gpe.utils import ExperimentBase, get_good_N, _GPU
from gpe.minimize import MinimizeState

from mmfutils.math.special import mstep

@_GPU.add_non_GPU_methods
class State(StateBase):
    """GPE state for the stirred 2D BEC with high-momentum dissipation.

    During real-time evolution (not initializing and ``t >= 0``) the time
    derivative ``dy`` is multiplied, in momentum space, by
    ``exp(-1j*kappa0)`` for wavenumbers with ``|kx| > kcut`` and left
    unchanged otherwise.

    Parameters
    ----------
    experiment : Experiment
        Supplies the external potential through ``experiment.get_Vext(state)``.
    kcut : float
        Cutoff wavenumber beyond which modes are damped.
    kappa0 : float
        Dissipation strength (phase of the factor applied to ``dy``).
    **kw
        Forwarded to ``StateBase``.
    """

    def __init__(self, experiment, kcut, kappa0, **kw):
        self.experiment = experiment
        self.kcut = kcut
        self.kappa0 = kappa0
        super().__init__(**kw)

    def init(self):
        """Initialize the base state, then build the dissipation factor."""
        super().init()

        # Must use this state's own basis: this can run from __init__ before
        # any notebook-level state exists, and other states may differ.
        ks = self.basis.kx
        # NOTE(review): only the kx component is tested, so the damping is
        # anisotropic; an isotropic cut would use |k| -- confirm intent.
        outside = np.abs(ks) > self.kcut
        self._phase_k = np.exp(-1j * self.kappa0 * np.where(outside, 1, 0))

    def get_Vext_GPU(self):
        """Return the external potential from the experiment.

        While initializing (or for ``t < 0``) the chemical potential ``mu`` is
        subtracted when it is set (non-zero).
        """
        V_ext = self.experiment.get_Vext(state=self)
        if (self.initializing or self.t < 0) and getattr(self, "mu", None):
            # Out-of-place so an array cached by the experiment is never modified.
            V_ext = V_ext - self.mu
        return V_ext

    def _apply_dissipation(self, dy):
        """Apply the momentum-space dissipation factor to `dy` in place.

        ``dy``'s wavefunction is transformed to momentum space, multiplied by
        ``self._phase_k``, transformed back and stored with ``dy.set_psi``.
        Returns None.
        """
        fftn, ifftn = self.basis.fftn, self.basis.ifftn
        dy.set_psi(ifftn(self._phase_k * fftn(dy.get_psi())))

    def compute_dy_dt(self, dy, subtract_mu=True):
        """Compute the base-class ``dy/dt``, adding dissipation in real time."""
        super().compute_dy_dt(dy, subtract_mu=subtract_mu)
        if not (self.initializing or self.t < 0):
            self._apply_dissipation(dy)
        return dy


class Experiment(ExperimentBase):
    """Stirred 2D BEC in a disk-shaped trap, stirred by two counter-rotating rods.

    Attributes with unit suffixes (``_micron``, ``_ms``, ``_Hz``, ``_mu``,
    ``_healing_length``, ...) are user-facing parameters; :meth:`init`
    converts them into code units.  Exactly one of `mu` and
    `healing_length_micron` must be set (not None).
    """
    hbar = 1
    m = u.m_Rb87
    w_z = 2*np.pi * (220*u.Hz)  # NOTE(review): not used by this class.
    R_micron = 22.0             # Trap radius R.
    R_Rx = 0.9                  # Ratio R/Rx of trap radius to box half-width.
    g = 1                       # NOTE(review): unused; init() passes g = mu/n0.
    V0_mu = 100.0               # Trap wall height in units of mu.
    V_dr_healing_length = 1.0   # Width of the trap wall in healing lengths.

    mu = None
    healing_length_micron = 2.0
    dx_healing_length = 0.1     # Target grid spacing in healing lengths.
    n0_micron2 = 200.0   # Background density

    # Stirring Potential
    Vstir_Hz = 25               # Rod angular frequency / (2 pi).
    Vstir_R_micron = 3.5        # Radius of each stirring rod.
    Vstir_dr_healing_length = 1.0  # Width of the rod edge in healing lengths.
    Vstir_mu = 10.0             # Rod height in units of mu.
    t_stir_ms = 16              # Stirring potential is zero after this time.
    rng_seed = 0
    Rstir0_micron = 13.5        # Orbit radius of the rods for t < 0.
    Rstir_min_micron = 12.0     # The orbit radius is drawn uniformly from
    Rstir_max_micron = 15.0     # [Rstir_min, Rstir_max] ...
    Rstir_dt_change_ms = 0.4    # ... and redrawn at this interval.

    # Dissipation
    kcut_invmicron = 5
    kappa0 = 0.02

    t_final_ms = 56

    State = State

    def init(self):
        """Convert the parameters to code units and build `state_args`.

        Raises
        ------
        ValueError
            If not exactly one of `mu` and `healing_length_micron` is set.
        """
        if self.mu is not None and self.healing_length_micron is None:
            self.healing_length = np.sqrt(self.hbar**2 / 2 / self.m / self.mu)
            mu = self.mu

        elif self.mu is None and self.healing_length_micron is not None:
            self.healing_length = self.healing_length_micron * u.micron
            mu = self.hbar**2 / 2 / self.m / self.healing_length**2

        else:
            raise ValueError("Must specify (only) one of `mu` or `healing_length_micron`")

        dx = self.dx_healing_length * self.healing_length
        self.V_R = self.R_micron * u.micron
        Rx = self.V_R / self.R_Rx
        Nx = get_good_N(2*Rx/dx)
        n0 = self.n0_micron2 / u.micron**2

        kcut = self.kcut_invmicron / u.micron

        g = mu / n0
        self.state_args = dict(kcut=kcut, kappa0=self.kappa0, Nxyz=(Nx, Nx), Lxyz=(2*Rx, 2*Rx),
                               g=g, mu=mu, m=self.m, hbar=self.hbar)
        self.V0 = self.V0_mu * mu
        self.V_dr = self.V_dr_healing_length * self.healing_length
        self._Vtrap = None  # Cached by get_Vext() on first use.

        self.Vstir_w = 2*np.pi * self.Vstir_Hz * u.Hz
        self.Vstir_R = self.Vstir_R_micron * u.micron
        self.Vstir_dr = self.Vstir_dr_healing_length * self.healing_length
        self.Vstir_V0 = self.Vstir_mu * mu

        self.t_stir = self.t_stir_ms * u.ms
        self.t_final = self.t_final_ms * u.ms

        # Pre-draw one orbit radius per interval so the stirring schedule is a
        # pure function of t: it does not depend on how often, or for which
        # state, get_Vstir() is called, and it is identical for every run.
        self.rng = np.random.default_rng(seed=self.rng_seed)
        self.Rstir_dt_change = self.Rstir_dt_change_ms * u.ms
        n_intervals = int(np.ceil(self.t_stir / self.Rstir_dt_change)) + 1
        self._R_stirs = self.rng.uniform(
            self.Rstir_min_micron, self.Rstir_max_micron, size=n_intervals
        ) * u.micron
        self._R_stir0 = self.Rstir0_micron * u.micron
        self._R_stir = self._R_stir0  # Most recently used orbit radius.

    def get_Vext(self, state):
        """Return the external potential (static trap + stirring rods)."""
        if self._Vtrap is None:
            self._Vtrap = self.get_Vtrap(state=state)
        V_ext = self._Vtrap + self.get_Vstir(state=state)
        return V_ext

    def get_Vtrap(self, state):
        """Return the static trapping potential.

        Zero inside radius ``V_R`` rising smoothly (over width ``V_dr``) to the
        wall height ``V0`` outside.
        """
        x, y = state.get_xyz()
        r = np.sqrt(x**2 + y**2)
        return self.V0 * self._Vstep(r, R=self.V_R, dr=self.V_dr)

    def _Vstep(self, r, R, dr):
        """Smooth step at R of width dr."""
        return mstep(r - (R - dr/2), dr)

    def _Vrod(self, r):
        """Stirring rod profile: ~1 inside radius Vstir_R, ~0 outside."""
        return 1-self._Vstep(r, R=self.Vstir_R, dr=self.Vstir_dr)

    def _get_R_stir(self, t):
        """Return the rods' orbit radius at time `t` (piecewise constant)."""
        if t < 0:
            return self._R_stir0
        n = min(int(t // self.Rstir_dt_change), len(self._R_stirs) - 1)
        return self._R_stirs[n]

    def get_Vstir(self, state):
        """Return the dynamic stirring potential (0 after ``t_stir``).

        Two rods orbit at radius ``_get_R_stir(t)`` in opposite directions,
        starting from angle ``-pi/4`` with angular frequency ``Vstir_w``.
        """
        t = state.t
        if t > self.t_stir:
            return 0

        th0 = -np.pi/4
        dth = self.Vstir_w * t

        # TODO: Need to make these random radius transitions smooth?
        self._R_stir = self._get_R_stir(t)

        z0, z1 = self._R_stir * np.exp([1j*(th0 + dth), 1j*(th0 - dth)])
        x, y = state.get_xyz()
        z = x + 1j*y
        return self.Vstir_V0 * (self._Vrod(abs(z-z0)) + self._Vrod(abs(z-z1)))

    def get_state(self):
        """Return a new (unminimized) state built from `state_args`."""
        state = self.State(experiment=self, **self.state_args)
        return state

    def get_initialized_state(self):
        """Return a fresh state holding the minimized wavefunction.

        Prints ``MinimizeState.check()`` for the initial state, minimizes it
        with ``use_scipy=True``, and copies the result into a new state from
        :meth:`get_state`.
        """
        state0 = self.get_state()
        m = MinimizeState(state0)
        print(m.check())
        state1 = m.minimize(use_scipy=True)
        state = self.get_state()
        state.set_psi(state1.get_psi())
        return state

# Build the default experiment, compute its minimized initial state and plot it.
# `e` and `s` are reused by the cells below.
e = Experiment()
s = e.get_initialized_state()
s.plot()
```

```{code-cell} ipython3
from pytimeode.evolvers import EvolverABM
from mmfutils.contexts import FPS

# Evolve `s` in real time with EvolverABM, redrawing the state every `steps` steps.
ev = EvolverABM(s, dt=0.1*s.t_scale)

steps = 200  # Evolver steps between displayed frames.
# Number of frames to reach e.t_final (truncated, so may stop slightly short).
loops = int(e.t_final / ev.dt / steps)

# NOTE(review): FPS(..., timeout=10) presumably ends the loop after ~10 s of
# wall-clock time, possibly before t_final is reached -- confirm.
for frame in FPS(loops, timeout=10):
    ev.evolve(steps)
    plt.clf()
    ev.y.plot()
    clear_output(wait=True)
    display(plt.gcf())
```

```{code-cell} ipython3
# Reference minimized state and a seeded generator; both are bound as the
# default arguments of get_psi() below.
s0 = e.get_initialized_state()
rng = np.random.default_rng(seed=2)
def get_psi(s0=s0, p=0.1, steps=200, rng=rng):
    """Evolve a phase-noise-perturbed copy of `s0` and return its wavefunction.

    An i.i.d. random phase ``p * N(0, 1)`` is imprinted at every grid point of
    ``s0``'s wavefunction and loaded into a fresh state from the experiment
    `e`, which is then evolved for `steps` steps of ``0.2 * t_scale``.

    Because `rng` defaults to the shared module-level generator, successive
    calls draw different noise realizations.
    """
    perturbed = e.get_state()
    psi0 = s0.get_psi()
    noise = rng.normal(size=s0.shape)
    perturbed.set_psi(psi0 * np.exp(1j * p * noise))

    evolver = EvolverABM(perturbed, dt=0.2 * perturbed.t_scale)
    evolver.evolve(steps)
    return evolver.y.get_psi()

# Ensemble of 10 independently perturbed evolutions.  `n_` is the density of the
# ensemble-averaged wavefunction; its integral is printed next to s0's atom number.
psis = [get_psi() for _ in range(10)]
n_ = np.abs(np.mean(psis, axis=0))**2
print(s0.get_N(), s0.integrate(n_))
# TODO: plot the averaged density, e.g. plt.imshow(n_).
```

```{code-cell} ipython3

```
