Source code for gpe.gpu.bec
"""GPU versions of the code.
Requires cupy and the various NVIDIA dependencies.
"""
from . import cupy as cp
from pytimeode import mixins_gpu
from .bases import PeriodicBasisGPU
from .. import bec
[docs]
class StateBase(mixins_gpu.GPUArrayStateMixin, bec.StateBase):
    """Base state combining ``GPUArrayStateMixin`` with ``bec.StateBase``.

    When no basis is supplied, a ``PeriodicBasisGPU`` is constructed from
    the grid arguments.  The basis metric is converted with
    ``self.xp.asarray`` during ``init()`` and exposed through
    ``get_metric_GPU()``.
    """

    # Bind the conversion helper only when the package-level ``cupy`` object
    # is truthy (presumably falsy when cupy is unavailable -- TODO confirm).
    if cp:
        asnumpy = staticmethod(cp.asnumpy)

    def __init__(
        self,
        basis=None,
        # Specify either basis or the following
        Nxyz=(2**5, 2**5, 2**5),
        Lxyz=(30 * u.micron, 50 * u.micron, 50 * u.micron),
        symmetric_grid=False,
        twist=None,
        **kw,
    ):
        """Set up the state, building a GPU periodic basis if none is given.

        Parameters
        ----------
        basis :
            Basis object to use.  If ``None``, a ``PeriodicBasisGPU`` is
            created from `Nxyz`, `Lxyz` and `symmetric_grid`.
        Nxyz, Lxyz :
            Grid arguments.  Used to build the default basis; when `basis`
            is provided they are forwarded to the parent constructor.
        symmetric_grid :
            Passed to ``PeriodicBasisGPU`` as ``symmetric_lattice`` when the
            basis is built here, otherwise forwarded to the parent
            constructor under its own name.
        twist :
            Accepted but not used by this constructor.
            NOTE(review): confirm whether it should be forwarded.
        **kw :
            Extra keyword arguments forwarded to the parent constructor.
        """
        # NOTE(review): ``u`` is not imported by the visible part of this
        # module -- confirm where the units object is expected to come from.
        if basis is not None:
            # A basis was supplied: pass the grid arguments on to the parent.
            kw.update(Nxyz=Nxyz, Lxyz=Lxyz, symmetric_grid=symmetric_grid)
        else:
            basis = PeriodicBasisGPU(
                Nxyz=Nxyz,
                Lxyz=Lxyz,
                symmetric_lattice=symmetric_grid,
            )
        super().__init__(basis=basis, **kw)

    def init(self):
        """Store the basis metric as an ``xp`` array, then run parent init."""
        # Ensure that the metric lives on the GPU.
        metric = self.basis.metric
        self._metric = self.xp.asarray(metric)
        super().init()

    def get_metric_GPU(self):
        """Return the metric array stored by ``init()``."""
        return self._metric