Source code for gpe.gpu.bases

"""Updated bases with GPU support."""

import numpy as np

from mmfutils.math import bases, bessel
from mmfutils.math.bases.utils import get_xyz, get_kxyz


from . import cupy as cp, cupyx


class PeriodicBasisGPU(bases.PeriodicBasis):
    """Periodic basis that uses CuPy arrays and FFTs when CuPy is available.

    If CuPy is unavailable (`cp` is falsy), the array module falls back to
    NumPy and the FFT helpers are inherited unchanged from
    `bases.PeriodicBasis`.
    """

    if cp:
        xp = cp
        _fft = staticmethod(cp.fft.fft)
        _ifft = staticmethod(cp.fft.ifft)
        _fftn = staticmethod(cp.fft.fftn)
        _ifftn = staticmethod(cp.fft.ifftn)
        asnumpy = staticmethod(cp.asnumpy)
    else:
        # Without CuPy, `cp` is falsy and cannot serve as the array module,
        # so use NumPy instead of leaving `xp` unusable.
        xp = np
class SphericalBasis(bases.SphericalBasis):
    """Spherically symmetric 3D basis with optional GPU (CuPy) support.

    The array module `xp` and the DST/IDST implementations are taken from
    CuPy/cupyx when available, otherwise from NumPy and `bases.utils`.
    """

    if cp:
        xp = cp
        _dst = staticmethod(cupyx.scipy.fft.dst)
        _idst = staticmethod(cupyx.scipy.fft.idst)
    else:
        xp = np
        _dst = staticmethod(bases.utils.dst)
        _idst = staticmethod(bases.utils.idst)

    def init(self):
        """Set up the radial abscissa, momenta, and integration metric.

        The radial grid is `r = dx, 2*dx, ..., N*dx` with `dx = R/N`; the
        momenta are `k = pi*(n + 1/2)/R` for `n = 0, ..., N-1`.
        """
        dx = self.R / self.N
        r = self.xp.arange(1, self.N + 1) * dx
        k = self.xp.pi * (0.5 + self.xp.arange(self.N)) / self.R
        self.xyz = [r]
        self._pxyz = [k]
        # Volume of a spherical shell of radius r and thickness dx.
        self.metric = 4 * self.xp.pi * r**2 * dx
        self.k_max = k.max()

    def laplacian(self, y, factor=None, factors=None, exp=False):
        """Return the laplacian of `y` times `factor` or the exponential of this.

        Arguments
        ---------
        factor : float | None
           Additional factor(s) (mostly used with `exp=True`).  The
           implementation must be careful to allow the factor to
           broadcast across the components.
        factors : [float] | None
           Tuple of scale factors for each dimension.  Allows for
           independent scaling of each direction (used in expanding
           reference frames).
        exp : bool
           If `True`, then compute the exponential of the laplacian.
           This is used for split evolvers.
        """
        r = self.xyz[0]
        if factors is None:
            pxyz = self._pxyz
        else:
            pxyz = [f * p for f, p in zip(factors, self._pxyz)]

        # In momentum space the laplacian is -k**2.
        sign = -1 if factor is None else -factor
        K = sign * pxyz[0] ** 2
        if exp:
            K = self.xp.exp(K)

        # Real and imaginary parts are transformed separately, since the
        # DST is applied to real data here.
        ys = [y.real, y.imag] if self.xp.iscomplexobj(y) else [y]
        res = [self._idst(K * self._dst(r * _y)) / r for _y in ys]
        if self.xp.iscomplexobj(y):
            return res[0] + 1j * res[1]
        return res[0]

    def coulomb_kernel(self, k, factors=None):
        """Return the truncated Coulomb kernel in momentum space.

        The kernel is `4*pi*(1 - cos(k*D))/k**2` with truncation length
        `D = 2*R`, using its `k -> 0` limit `4*pi*D**2/2` at `k == 0`.

        Raises
        ------
        NotImplementedError
           If `factors` is not `None`.
        """
        if factors is not None:
            raise NotImplementedError(f"Convolution with {factors=} not yet supported")
        D = 2 * self.R
        # `where` evaluates both branches, so silence the 0/0 at k == 0.
        with np.errstate(divide="ignore", invalid="ignore"):
            return (
                4
                * self.xp.pi
                * self.xp.where(k == 0, D**2 / 2.0, (1.0 - self.xp.cos(k * D)) / k**2)
            )

    def convolve_coulomb(self, y, form_factors=(), factors=None):
        """Modified Coulomb convolution to include form-factors (if provided).

        This version implements a 3D spherically symmetric convolution.  The
        data is zero-padded to twice its length before transforming, and
        the kernel is the product of `coulomb_kernel` with every callable in
        `form_factors`, all evaluated on the padded momenta.

        Raises
        ------
        NotImplementedError
           If `factors` is not `None`.
        """
        if factors is not None:
            raise NotImplementedError(f"Convolution with {factors=} not yet supported")
        y = self.xp.asarray(y)
        r = self.xyz[0]
        N, R = self.N, self.R

        # Padded arrays with trailing _
        ry_ = self.xp.concatenate(
            [r * y, self.xp.zeros(y.shape, dtype=y.dtype)], axis=-1
        )
        k_ = self.xp.pi * (0.5 + self.xp.arange(2 * N)) / (2 * R)

        # Elementwise product of the Coulomb kernel and all form factors.
        # (The original called an undefined `prod`, raising NameError.)
        K = self.coulomb_kernel(k_)
        for form_factor in form_factors:
            K = K * form_factor(k_)
        return self._idst(K * self._dst(ry_))[..., :N] / r

    def convolve(self, y, C=None, Ck=None, factors=None):
        """Return the periodic convolution `int(C(x-r)*y(r),r)`.

        Note: this is the 3D convolution.  Either `C` (tabulated on the
        radial abscissa) or `Ck` (a callable evaluated at the momenta) must
        be provided.

        Raises
        ------
        NotImplementedError
           If `factors` is not `None`.
        """
        if factors is not None:
            raise NotImplementedError(f"Convolution with {factors=} not yet supported")
        r = self.xyz[0]
        k = self._pxyz[0]
        N, R = self.N, self.R
        R_N = R / N
        if Ck is None:
            C0 = (self.metric * C).sum()
            # `where` evaluates both branches; silence the division at k == 0.
            with np.errstate(divide="ignore", invalid="ignore"):
                Ck = self.xp.where(
                    k == 0, C0, 2 * self.xp.pi * R_N * self._dst(r * C) / k
                )
        else:
            Ck = Ck(k)
        return self._idst(Ck * self._dst(r * y)) / r
class CylindricalBasis(bases.CylindricalBasis):
    """Axially symmetric 3D basis in `(x, r)` with optional GPU (CuPy) support.

    The array module `xp` and the FFT/DST helpers are taken from CuPy/cupyx
    when available, otherwise from NumPy and the mmfutils helpers.
    """

    if cp:
        xp = cp
        _dst = staticmethod(cupyx.scipy.fft.dst)
        _idst = staticmethod(cupyx.scipy.fft.idst)
        _fft = staticmethod(cp.fft.fft)
        _ifft = staticmethod(cp.fft.ifft)
        _fftn = staticmethod(cp.fft.fftn)
        _ifftn = staticmethod(cp.fft.ifftn)
        asnumpy = staticmethod(cp.asnumpy)
    else:
        xp = np
        _dst = staticmethod(bases.utils.dst)
        _idst = staticmethod(bases.utils.idst)
        _fft = staticmethod(bases.bases.fft)
        _ifft = staticmethod(bases.bases.ifft)
        _fftn = staticmethod(bases.bases.fftn)
        _ifftn = staticmethod(bases.bases.ifftn)
        asnumpy = staticmethod(np.asarray)

    def init(self):
        """Initialize via the base class, then convert its arrays to `self.xp`."""
        super().init()
        self.kx = self.xp.asarray(self.kx)
        self._kx2 = self.xp.asarray(self._kx2)
        self.xyz = tuple(map(self.xp.asarray, self.xyz))
        self.metric = self.xp.asarray(self.metric)
        self.weights = self.xp.asarray(self.weights)
        self._Kr = self.xp.asarray(self._Kr)
        self._Kr_diag = tuple(map(self.xp.asarray, self._Kr_diag))
        self._Kx = self.xp.asarray(self._Kx)

    ######################################################################
    def apply_exp_K(self, y, factor, factors=None, kx2=None, **_kw):
        r"""Return `exp(K*factor)*y`.

        The exponentiated radial and axial factors for the default `self._Kx`
        are cached in `self._K_data` (at most three entries), keyed on
        `(factor_x, factor_r)`.

        Arguments
        ---------
        y : array
           State to act on.
        factor : complex | None
           Factor multiplying `K` in the exponent (`None` means 1).
        factors : (float, float) | None
           Independent scale factors for the x and r directions; the
           exponent in each direction is scaled by the square.
        kx2 : array | None
           Axial kinetic term in momentum space (default `self._Kx`).
        """
        assert bases.bases._raise_twist_err(self, _kw, name="apply_exp_K")
        assert bases.bases._raise_factors_err(factors, kx2=kx2)
        if factor is None:
            factor = 1.0
        if factors is not None:
            factor_x = factors[0] ** 2 * factor
            factor_r = factors[1] ** 2 * factor
        else:
            factor_x = factor_r = factor

        # Check if we have computed this already in the _K_data cache.  The
        # key holds host scalars, so compare with NumPy: CuPy functions do not
        # accept Python tuples, and a device comparison would force a sync.
        _K_data_max_len = 3
        key = (factor_x, factor_r)
        ind = None
        for _i, (_key, _d) in enumerate(self._K_data):
            if np.allclose(key, _key):
                ind = _i

        if ind is None:
            # If not, compute using the eigen-decomposition of the radial
            # kinetic matrix and the default axial kinetic term.
            _r1, _r2, V, d = self._Kr_diag
            exp_K_r = _r1 * self.xp.dot(V * self.xp.exp(factor_r * d), V.T) * _r2
            exp_K_x = self.xp.exp(factor_x * self._Kx)
            self._K_data.append((key, (exp_K_r, exp_K_x)))
            ind = -1
            while len(self._K_data) > _K_data_max_len:
                # Reduce storage
                self._K_data.pop(0)

        exp_K_r, exp_K_x = self._K_data[ind][1]
        if kx2 is not None:
            # The cache key does not include kx2, so a caller-supplied kx2
            # must neither reuse nor overwrite the cached axial factor.
            exp_K_x = self.xp.exp(factor_x * kx2)

        tmp = self.ifft(exp_K_x * self.fft(y))
        return self.xp.einsum("...ij,...yj->...yi", exp_K_r, tmp)

    def apply_K(self, y, factors=None, kx2=None, **_kw):
        r"""Return `K*y` where `K = k**2/2`.

        The axial part is applied in momentum space with `kx2` (default
        `self._Kx`), and the radial part as a matrix product with
        `self._Kr`.  `factors`, if given, scales the x and r parts by the
        squares of `factors[0]` and `factors[1]`.
        """
        assert bases.bases._raise_twist_err(self, _kw, name="apply_K")
        assert bases.bases._raise_factors_err(factors, kx2=kx2)
        if kx2 is None:
            kx2 = self._Kx
        if factors is not None:
            kx2 = kx2 * factors[0] ** 2
        yt = self.fft(y)
        yt *= kx2
        yt = self.ifft(yt)
        # Radial part: y @ Kr^T.  Note that _Kr is not assumed symmetric, so
        # symmetric BLAS routines (zSYMM/zHEMM) are not applicable here.
        if factors is None:
            yt += self.xp.dot(y, self._Kr.T)
        else:
            yt += self.xp.dot(y, self._Kr.T) * factors[1] ** 2
        return yt

    ######################################################################
    # FFT and DVR Helper functions.
    #
    # These are specific to the basis, defining the kinetic energy
    # matrix for example.

    # We need these wrappers because the state may have additional
    # indices for components etc. in front.
    def fft(self, x):
        """Perform the fft along the x axis."""
        # Normalize the (possibly negative) x axis index against x's rank.
        axis = (self.axes % len(x.shape))[0]
        return self._fft(x, axis=axis)

    def ifft(self, x):
        """Perform the inverse fft along the x axis."""
        axis = (self.axes % len(x.shape))[0]
        return self._ifft(x, axis=axis)

    def integrate1(self, n):
        """Return the integral of n over y and z."""
        n = self.xp.asarray(n)
        _, r = self.xyz
        x_axis, r_axis = self.axes
        # Broadcast the radial measure against n's (x, r) axes.
        bcast = [None] * len(n.shape)
        bcast[x_axis] = slice(None)
        bcast[r_axis] = slice(None)
        return ((2 * self.xp.pi * r * self.weights)[tuple(bcast)] * n).sum(
            axis=r_axis
        )

    def integrate2(self, n, y=None, Nz=100):
        """Return the integral of n over z (line-of-sight integral) at y.

        This is an Abel transform, and is used to compute the 1D
        line-of-sight integral as would be seen by a photographic image
        through an axial cloud.

        Arguments
        ---------
        n : array
           (Nx, Nr) array of the function to be integrated tabulated
           on the abscissa.  Note: the extrapolation assumes that `n =
           abs(psi)**2` where `psi` is well represented in the basis.
        y : array, None
           Ny points at which the resulting integral should be
           returned.  If not provided, then the function will be
           tabulated at the radial abscissa.
        Nz : int
           Number of points to use in z integral.
        """
        n = self.xp.asarray(n)
        x, r = self.xyz
        if y is None:
            y = r
        # Ensure user-supplied y lives on the same array module as r and z.
        y = self.xp.asarray(y).ravel()
        Ny = len(y)
        x_axis, r_axis = self.axes
        y_axis = r_axis
        bcast_y = [None] * len(n.shape)
        bcast_z = [None] * len(n.shape)
        bcast_y[y_axis] = slice(None)
        bcast_y.append(None)
        bcast_z.append(slice(None))
        bcast_y, bcast_z = tuple(bcast_y), tuple(bcast_z)

        z = self.xp.linspace(0, r.max(), Nz)
        shape_xyz = n.shape[:-1] + (Ny, Nz)
        rs = self.xp.sqrt(y[bcast_y] ** 2 + z[bcast_z] ** 2)
        n_xyz = (abs(self.Psi(self.xp.sqrt(n), (x, rs.ravel()))) ** 2).reshape(
            shape_xyz
        )
        # `trapezoid` replaced `trapz` in NumPy 2.0; fall back for older
        # NumPy/CuPy versions that only provide `trapz`.
        trapezoid = getattr(self.xp, "trapezoid", None)
        if trapezoid is None:
            trapezoid = self.xp.trapz
        # Factor of 2 accounts for the z < 0 half of the line of sight.
        n_2D = 2 * trapezoid(n_xyz, z, axis=-1)
        return n_2D