"""Updated bases with GPU support."""
import math

import numpy as np
from mmfutils.math import bases, bessel
from mmfutils.math.bases.utils import get_xyz, get_kxyz

from . import cupy as cp, cupyx
class PeriodicBasisGPU(bases.PeriodicBasis):
    """Periodic basis whose FFTs are performed with CuPy when it is available.

    When `cp` is falsy (CuPy not available) this class adds nothing to
    `bases.PeriodicBasis`.
    """

    if cp:
        # NOTE(review): no `xp` is set here; confirm that `bases.PeriodicBasis`
        # builds arrays that the CuPy FFTs below accept (i.e. CuPy arrays).
        _fft = staticmethod(cp.fft.fft)
        _ifft = staticmethod(cp.fft.ifft)
        _fftn = staticmethod(cp.fft.fftn)
        _ifftn = staticmethod(cp.fft.ifftn)
        # Convert device arrays back to host (NumPy) arrays.
        asnumpy = staticmethod(cp.asnumpy)
class SphericalBasis(bases.SphericalBasis):
    """Spherically symmetric 3D basis built on discrete sine transforms.

    When `cp` is truthy (CuPy available) the DSTs and the arrays created in
    `init` use CuPy; otherwise NumPy and the `bases.utils` DSTs are used.
    """

    if cp:
        # The CuPy DSTs expect CuPy arrays, so everything built through
        # `self.xp` must be CuPy as well (mirrors `xp = np` below).
        xp = cp
        _dst = staticmethod(cupyx.scipy.fft.dst)
        _idst = staticmethod(cupyx.scipy.fft.idst)
    else:
        xp = np
        _dst = staticmethod(bases.utils.dst)
        _idst = staticmethod(bases.utils.idst)

    def init(self):
        """Tabulate the radial abscissa, momenta and integration metric.

        Sets `xyz = [r]` with `r = dx, 2*dx, ..., N*dx` (`dx = R/N`),
        `_pxyz = [k]` with `k = pi*(n + 1/2)/R` for `n = 0..N-1`, the 3D
        radial metric `4*pi*r**2*dx`, and `k_max = max(k)`.
        """
        dx = self.R / self.N
        r = self.xp.arange(1, self.N + 1) * dx
        k = self.xp.pi * (0.5 + self.xp.arange(self.N)) / self.R
        self.xyz = [r]
        self._pxyz = [k]
        self.metric = 4 * self.xp.pi * r**2 * dx
        self.k_max = k.max()

    def laplacian(self, y, factor=None, factors=None, exp=False):
        """Return the laplacian of `y` times `factor` or the
        exponential of this.

        Arguments
        ---------
        y : array
            Function tabulated on the radial abscissa `self.xyz[0]`.
        factor : float | None
            Additional factor(s) (mostly used with `exp=True`). The
            implementation must be careful to allow the factor to
            broadcast across the components.
        factors : [float] | None
            Tuple of scale factors for each dimension. Allows for independent scaling
            of each direction (used in expanding reference frames).
        exp : bool
            If `True`, then compute the exponential of the laplacian.
            This is used for split evolvers.
        """
        r = self.xyz[0]
        if factors is None:
            pxyz = self._pxyz
        else:
            pxyz = [f * p for f, p in zip(factors, self._pxyz)]
        sign = -1 if factor is None else -factor
        K = sign * pxyz[0] ** 2
        if exp:
            K = self.xp.exp(K)

        # Act on u = r*y with the DST (where the operator is diagonal with
        # eigenvalue K), then divide by r again.  Real and imaginary parts are
        # transformed separately.
        is_complex = self.xp.iscomplexobj(y)
        ys = [y.real, y.imag] if is_complex else [y]
        res = [self._idst(K * self._dst(r * _y)) / r for _y in ys]
        if is_complex:
            return res[0] + 1j * res[1]
        return res[0]

    def coulomb_kernel(self, k, factors=None):
        """Form for the truncated Coulomb kernel.

        Returns `4*pi*(1 - cos(k*D))/k**2` with `D = 2*R`, using the finite
        limit `4*pi*D**2/2` at `k == 0`.

        Raises
        ------
        NotImplementedError
            If `factors` is not `None`.
        """
        if factors is not None:
            raise NotImplementedError(f"Convolution with {factors=} not yet supported")
        D = 2 * self.R
        # `where` evaluates both branches, so suppress the k == 0 division
        # warning (np.errstate only governs NumPy floating-point warnings).
        with np.errstate(divide="ignore", invalid="ignore"):
            return (
                4
                * self.xp.pi
                * self.xp.where(k == 0, D**2 / 2.0, (1.0 - self.xp.cos(k * D)) / k**2)
            )

    def convolve_coulomb(self, y, form_factors=(), factors=None):
        """Modified Coulomb convolution to include form-factors (if provided).

        This version implemented a 3D spherically symmetric convolution.

        Arguments
        ---------
        y : array
            Function tabulated on the radial abscissa `self.xyz[0]`.
        form_factors : iterable of callables
            Each is called with the padded momenta `k_` and its result is
            multiplied into the Coulomb kernel.
        factors : None
            Not supported; anything other than `None` raises.

        Raises
        ------
        NotImplementedError
            If `factors` is not `None`.
        """
        if factors is not None:
            raise NotImplementedError(f"Convolution with {factors=} not yet supported")
        y = self.xp.asarray(y)
        r = self.xyz[0]
        N, R = self.N, self.R
        # Padded arrays with trailing _: r*y is zero-padded out to radius 2*R
        # (presumably so the truncated kernel sees no image contributions).
        ry_ = self.xp.concatenate([r * y, self.xp.zeros(y.shape, dtype=y.dtype)], axis=-1)
        k_ = self.xp.pi * (0.5 + self.xp.arange(2 * N)) / (2 * R)
        # math.prod multiplies with `*`, so it works for NumPy and CuPy arrays.
        K = math.prod(_K(k_) for _K in (self.coulomb_kernel, *form_factors))
        return self._idst(K * self._dst(ry_))[..., :N] / r

    def convolve(self, y, C=None, Ck=None, factors=None):
        """Return the periodic convolution `int(C(x-r)*y(r),r)`.

        Note: this is the 3D convolution.

        Arguments
        ---------
        y : array
            Function tabulated on the radial abscissa `self.xyz[0]`.
        C : array | None
            Kernel tabulated on the radial abscissa (used if `Ck` is None).
        Ck : callable | None
            If provided, called with the momenta `self._pxyz[0]` to give the
            kernel in momentum space directly.
        factors : None
            Not supported; anything other than `None` raises.

        Raises
        ------
        NotImplementedError
            If `factors` is not `None`.
        """
        if factors is not None:
            raise NotImplementedError(f"Convolution with {factors=} not yet supported")
        r = self.xyz[0]
        k = self._pxyz[0]
        if Ck is None:
            R_N = self.R / self.N
            C0 = (self.metric * C).sum()
            with np.errstate(divide="ignore", invalid="ignore"):
                Ck = self.xp.where(
                    k == 0, C0, 2 * self.xp.pi * R_N * self._dst(r * C) / k
                )
        else:
            Ck = Ck(k)
        return self._idst(Ck * self._dst(r * y)) / r
class CylindricalBasis(bases.CylindricalBasis):
    """Axially symmetric basis (periodic `x`, radial `r`) with optional GPU support.

    When `cp` is truthy (CuPy available) the FFTs/DSTs and the arrays
    converted in `init` use CuPy; otherwise NumPy is used.
    """

    if cp:
        # `init` converts the tabulated arrays with `self.xp`; they must be
        # CuPy arrays for the CuPy transforms below (mirrors `xp = np` below).
        xp = cp
        _dst = staticmethod(cupyx.scipy.fft.dst)
        _idst = staticmethod(cupyx.scipy.fft.idst)
        _fft = staticmethod(cp.fft.fft)
        _ifft = staticmethod(cp.fft.ifft)
        _fftn = staticmethod(cp.fft.fftn)
        _ifftn = staticmethod(cp.fft.ifftn)
        asnumpy = staticmethod(cp.asnumpy)
    else:
        xp = np
        _dst = staticmethod(bases.utils.dst)
        _idst = staticmethod(bases.utils.idst)
        _fft = staticmethod(bases.bases.fft)
        _ifft = staticmethod(bases.bases.ifft)
        _fftn = staticmethod(bases.bases.fftn)
        _ifftn = staticmethod(bases.bases.ifftn)
        asnumpy = staticmethod(np.asarray)

    def init(self):
        """Initialize via the base class, then convert its arrays with `self.xp`."""
        super().init()
        self.kx = self.xp.asarray(self.kx)
        self._kx2 = self.xp.asarray(self._kx2)
        self.xyz = tuple(map(self.xp.asarray, self.xyz))
        self.metric = self.xp.asarray(self.metric)
        self.weights = self.xp.asarray(self.weights)
        self._Kr = self.xp.asarray(self._Kr)
        self._Kr_diag = tuple(map(self.xp.asarray, self._Kr_diag))
        self._Kx = self.xp.asarray(self._Kx)

    ######################################################################
    def apply_exp_K(self, y, factor, factors=None, kx2=None, **_kw):
        r"""Return `exp(K*factor)*y`.

        The radial and axial exponentials for the default `self._Kx` are
        cached in `self._K_data` for the last few `(factor_x, factor_r)` pairs.

        Arguments
        ---------
        y : array
            State whose `x` axis is `self.axes[0]` and whose radial axis is
            the last axis (contracted with the radial exponential).
        factor : float | complex | None
            Multiplies `K`; `None` means 1.
        factors : [float] | None
            Scale factors `(f_x, f_r)`: the `x` and `r` parts of `factor` are
            multiplied by `f_x**2` and `f_r**2` respectively.
        kx2 : array | None
            Axial term to use instead of `self._Kx`.  It is not part of the
            cache key, so its exponential is always computed afresh.
        """
        assert bases.bases._raise_twist_err(self, _kw, name="apply_exp_K")
        assert bases.bases._raise_factors_err(factors, kx2=kx2)
        if factor is None:
            factor = 1.0
        if factors is not None:
            factor_x = factors[0] ** 2 * factor
            factor_r = factors[1] ** 2 * factor
        else:
            factor_x = factor_r = factor

        # Look for a cached entry.  Keys are host scalars, so compare with
        # NumPy: CuPy functions do not accept Python tuples.
        key = (factor_x, factor_r)
        ind = next(
            (
                _i
                for _i, (_key, _d) in enumerate(self._K_data)
                if np.allclose(key, _key)
            ),
            None,
        )
        if ind is None:  # If not, compute and cache.
            _K_data_max_len = 3
            _r1, _r2, V, d = self._Kr_diag
            exp_K_r = _r1 * self.xp.dot(V * self.xp.exp(factor_r * d), V.T) * _r2
            # Always cache the default axial exponential so that a
            # caller-supplied kx2 can never leak into later calls.
            exp_K_x = self.xp.exp(factor_x * self._Kx)
            self._K_data.append((key, (exp_K_r, exp_K_x)))
            ind = -1
            while len(self._K_data) > _K_data_max_len:
                # Reduce storage (the new entry stays at index -1).
                self._K_data.pop(0)
        exp_K_r, exp_K_x = self._K_data[ind][1]
        if kx2 is not None:
            exp_K_x = self.xp.exp(factor_x * kx2)
        tmp = self.ifft(exp_K_x * self.fft(y))
        return self.xp.einsum("...ij,...yj->...yi", exp_K_r, tmp)

    def apply_K(self, y, factors=None, kx2=None, **_kw):
        r"""Return `K*y` where `K = k**2/2`.

        The axial part (`kx2`, default `self._Kx`) is applied in Fourier
        space along the `x` axis; the radial part is the matrix `self._Kr`
        applied along the last axis.  `factors = (f_x, f_r)` scale the axial
        and radial parts by `f_x**2` and `f_r**2` respectively.
        """
        assert bases.bases._raise_twist_err(self, _kw, name="apply_K")
        assert bases.bases._raise_factors_err(factors, kx2=kx2)
        if kx2 is None:
            kx2 = self._Kx
        if factors is not None:
            kx2 = kx2 * factors[0] ** 2
        yt = self.fft(y)
        yt *= kx2
        yt = self.ifft(yt)
        # Radial part: dot(y, Kr.T) contracts the last axis of y with Kr.
        # A BLAS SYMM/HEMM call cannot be used because Kr is not symmetric.
        if factors is None:
            yt += self.xp.dot(y, self._Kr.T)
        else:
            yt += self.xp.dot(y, self._Kr.T) * factors[1] ** 2
        return yt

    ######################################################################
    # FFT and DVR Helper functions.
    #
    # These are specific to the basis, defining the kinetic energy
    # matrix for example.
    # We need these wrappers because the state may have additional
    # indices for components etc. in front.
    def fft(self, x):
        """Perform the FFT along the x axis (`self.axes[0]`)."""
        # Reduce the (possibly negative) axis modulo x.ndim so it refers to
        # the correct dimension even with extra leading indices.
        axis = (self.axes % len(x.shape))[0]
        return self._fft(x, axis=axis)

    def ifft(self, x):
        """Perform the inverse FFT along the x axis (`self.axes[0]`)."""
        axis = (self.axes % len(x.shape))[0]
        return self._ifft(x, axis=axis)

    def integrate1(self, n):
        """Return the integral of n over y and z.

        The radial integral is `sum(2*pi*r*weights*n)` over the radial axis;
        the result keeps all other axes (including `x`).
        """
        n = self.xp.asarray(n)
        x, r = self.xyz
        x_axis, r_axis = self.axes
        # Broadcast 2*pi*r*weights against n along the x and r axes.
        bcast = [None] * len(n.shape)
        bcast[x_axis] = slice(None)
        bcast[r_axis] = slice(None)
        return ((2 * self.xp.pi * r * self.weights)[tuple(bcast)] * n).sum(axis=r_axis)

    def integrate2(self, n, y=None, Nz=100):
        """Return the integral of n over z (line-of-sight integral) at y.

        This is an Abel transform, and is used to compute the 1D
        line-of-sight integral as would be seen by a photographic
        image through an axial cloud.

        Arguments
        ---------
        n : array
            (Nx, Nr) array of the function to be integrated tabulated
            on the abscissa. Note: the extrapolation assumes that `n =
            abs(psi)**2` where `psi` is well represented in the basis.
        y : array, None
            Ny points at which the resulting integral should be
            returned. If not provided, then the function will be
            tabulated at the radial abscissa.  Lists, scalars and host
            arrays are converted with `self.xp.asarray`.
        Nz : int
            Number of points to use in z integral.

        Returns
        -------
        n_2D : array
            Array of shape `n.shape[:-1] + (Ny,)`.
        """
        n = self.xp.asarray(n)
        x, r = self.xyz
        if y is None:
            y = r
        # Convert so y lives on the same device as z below.
        y = self.xp.asarray(y).ravel()
        Ny = len(y)
        x_axis, r_axis = self.axes
        y_axis = r_axis
        # Index tuples placing y along the radial axis (with a new trailing
        # z axis) and z along that trailing axis.
        bcast_y = [None] * len(n.shape)
        bcast_z = [None] * len(n.shape)
        bcast_y[y_axis] = slice(None)
        bcast_y.append(None)
        bcast_z.append(slice(None))
        bcast_y, bcast_z = tuple(bcast_y), tuple(bcast_z)
        # z runs over [0, r.max()]; the factor of 2 below accounts for the
        # symmetric negative-z half (rs depends only on z**2).
        z = self.xp.linspace(0, r.max(), Nz)
        shape_xyz = n.shape[:-1] + (Ny, Nz)
        rs = self.xp.sqrt(y[bcast_y] ** 2 + z[bcast_z] ** 2)
        # self.Psi presumably evaluates the basis expansion of sqrt(n) at the
        # points (x, rs); squaring recovers n there.
        n_xyz = (abs(self.Psi(self.xp.sqrt(n), (x, rs.ravel()))) ** 2).reshape(shape_xyz)
        n_2D = 2 * self.xp.trapezoid(n_xyz, z, axis=-1)
        return n_2D