"""Chebyshev methods."""
import inspect
import warnings

import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.fft  # ensure `sp.fft` is loaded on SciPy versions without lazy loading
import sympy
_EPS = np.finfo(float).eps


def sinnx_sinx(x, n, d=0):
    """Return the dth derivative of ``sin(n*x)/sin(x)`` with respect to ``x``.

    Where ``|sin(x)| < _EPS`` (``x`` numerically a multiple of pi) the quotient
    is 0/0, so the L'Hopital limit is returned instead.  These limits assume an
    integer ``n`` (so that ``sin(n*x)`` also vanishes there).

    Parameters
    ----------
    x : array_like
        Evaluation points.
    n : int
        Frequency.
    d : int
        Order of the derivative; only 0 and 1 are supported.

    Raises
    ------
    NotImplementedError
        If `d` is not 0 or 1.
    """
    s = np.sin(x)
    sn = np.sin(n * x)
    c = np.cos(x)
    cn = np.cos(n * x)
    singular = abs(s) < _EPS
    # np.where evaluates both branches; silence the 0/0 warnings coming from the
    # branch that is discarded at the singular points.
    with np.errstate(divide="ignore", invalid="ignore"):
        if d == 0:
            # Limit: n cos(n x) / cos(x).
            return np.where(singular, n * cn / c, sn / s)
        if d == 1:
            # d/dx [sn/s] = (n cn s - sn c) / s**2.  L'Hopital's rule gives the
            # limit (1 - n**2) sn / (2 c) as s -> 0, which is 0 for integer n.
            return np.where(
                singular,
                (1 - n**2) * sn / (2 * c),
                (n * cn * s - sn * c) / s**2,
            )
    raise NotImplementedError(f"Only d=0 or d=1 supported (got {d=})")
class Base:
    """Simple base class.

    Allows class variables to be assigned in the constructor and then delegates to
    init() for initialization. See {class}`gpe.utils.ExperimentBase` for a more
    comprehensive example.

    The constructor only accepts keywords naming attributes that already exist
    (typically class-level defaults).  After construction, assigning any public
    attribute (name not starting with ``_``) calls `init()` again so that derived
    state stays consistent.
    """

    def __init__(self, **kw):
        """Set the given attributes, then call `init()`.

        Raises
        ------
        ValueError
            If a keyword does not name an existing attribute.
        """
        # Suppress re-initialization in __setattr__ while keywords are applied.
        self._initializing = True
        try:
            for key, value in kw.items():
                if not hasattr(self, key):
                    raise ValueError(f"Unknown {key=}")
                setattr(self, key, value)
            self.init()
        finally:
            self._initializing = False

    def __setattr__(self, key, value):
        """Set an attribute and re-run `init()` if the attribute is public."""
        super().__setattr__(key, value)
        if self._initializing or key.startswith("_"):
            return
        # Reinitialize if attributes change.  The flag is reset in `finally` so
        # that an exception from init() (e.g. a rejected value) does not leave the
        # object permanently skipping re-initialization on later assignments.
        self._initializing = True
        try:
            self.init()
        finally:
            self._initializing = False

    def init(self):
        """Perform any initializations here."""

    def __repr__(self):
        """Return a rudimentary representation for debugging."""
        members = inspect.getmembers(self, lambda f: not callable(f))
        args = ", ".join(f"{n}={v}" for n, v in members if not n.startswith("_"))
        return f"{self.__class__.__qualname__}({args})"
class Chebyshev(Base):
    r"""Chebyshev pseudo-spectral basis of order N.

    The Chebyshev polynomials form an orthogonal (but not normalized) basis for
    functions on [-1, 1] with respect to the weighted inner product:

    .. math::
        \braket{f, g} = \int_{-1}^{1} \frac{f^*(x)g(x)}{\sqrt{1-x^2}}\d{x}

    Parameters
    ----------
    N : int
        Order of the basis. The quadrature rule is exact for polynomials of order 2N-1
        (i.e. containing 2N terms $1, x, x^2, ..., x^{2N-1}$). This means that there are
        `N` interior abscissae or `N+1` abscissae including the endpoints
        (`interior = False`).
    interior : bool
        Type of collocation points. If `True`, then `N` interior abscissae will be used,
        the roots of `T_N(x)`, also called the Chebyshev-Gauss quadrature. If `False`,
        then `N+1` abscissae will be used, including the endpoints (extrema + endpoints),
        also called the Chebyshev-Gauss-Lobatto quadrature.
    xL, xR : float
        Endpoints of the interval.  Only ``(xL, xR) = (-1, 1)`` is supported.

    Attributes
    ----------
    x, w : ndarray
        Abscissae (in increasing order) and corresponding quadrature weights.
        They are recomputed by `init()` whenever a public attribute changes.
    """

    # Defaults are required: `Base.__init__` only accepts keywords naming existing
    # attributes, and `init()` reads all four.
    # NOTE(review): the default values of `N` and `interior` were chosen here;
    # confirm against callers.
    N = 32
    interior = True
    xL = -1.0
    xR = 1.0

    def init(self):
        """Compute the abscissae `x` and weights `w` for the current settings.

        Raises
        ------
        ValueError
            If ``(xL, xR) != (-1, 1)`` or ``N < 1``.
        """
        if self.xL != -1.0 or self.xR != 1.0:
            raise ValueError(
                f"{self.__class__.__name__} only supports (xL, xR) = (-1, 1) "
                + f"(got {(self.xL, self.xR)})"
            )
        if self.N < 1:
            raise ValueError(f"N must be at least 1 (got N={self.N})")
        if self.interior:
            # Gauss: theta = pi (k + 1/2) / N for k = 0..N-1 (roots of T_N).
            Nx, dk = self.N, 0.5
        else:
            # Gauss-Lobatto: theta = pi k / N for k = 0..N (extrema + endpoints).
            Nx, dk = self.N + 1, 0.0
        self._k = np.arange(Nx)
        # theta is reversed so that x = cos(theta) is in increasing order.
        self._theta = np.pi * (self._k[::-1] + dk) / self.N
        self.x = np.cos(self._theta)
        self.w = np.pi / self.N * np.ones_like(self.x)
        if not self.interior:
            # Endpoint weights are halved for the Lobatto rule.
            self.w[0] /= 2
            self.w[-1] /= 2
        super().init()

    @staticmethod
    def W(x):
        """Return the integration weight (metric) ``1/sqrt(1 - x**2)``."""
        return 1 / np.sqrt(1 - x**2)

    @staticmethod
    def Tn(x, n, d=0):
        """Return the dth derivative of the nth Chebyshev polynomial at `x`.

        These are orthogonal, but not normalized.  Uses ``T_n(cos(th)) =
        cos(n th)``, so `x` should lie in [-1, 1].  Only ``d`` in ``(0, 1)`` is
        supported.

        Raises
        ------
        NotImplementedError
            If `d` is not 0 or 1.
        """
        th = np.arccos(x)
        if d == 0:
            return np.cos(n * th)
        if d == 1:
            # dT_n/dx = (dT_n/dth) / (dx/dth) = n sin(n th) / sin(th).
            return n * sinnx_sinx(th, n)
        raise NotImplementedError(f"Only d=0 or d=1 supported (got {d=})")

    def Cn(self, x, n):
        """Return the nth cardinal function at `x`: zero on all abscissae except x_n.

        The cardinal function is the polynomial interpolant (on the current grid)
        of data equal to 1 at ``self.x[n]`` and 0 at every other abscissa.  It is
        built from `get_a()`, so it works for both the Gauss (``interior``) and
        Gauss-Lobatto grids.
        """
        delta = np.zeros_like(self.x)
        delta[n] = 1.0
        return np.polynomial.chebyshev.chebval(x, self.get_a(delta))

    def get_a(self, f):
        """Return the Chebyshev series coefficients of `f` using the DCT.

        Parameters
        ----------
        f : array_like
            Values at the abscissae `self.x` (increasing order).

        Returns
        -------
        a : ndarray
            Coefficients such that the interpolant is ``sum(a[k] * T_k(x))``.
        """
        # `f` is reversed so that it is ordered by increasing theta.  The DCT
        # normalization (`norm=`) is chosen so that `2 * ft` are the coefficients,
        # except the first (and, for Lobatto, last) ones, which are not doubled.
        if self.interior:
            ft = sp.fft.dct(f[::-1], type=2, norm="forward")
            a = 2 * ft
            a[0] = ft[0]
        else:
            ft = sp.fft.idct(f[::-1], type=1, norm="backward")
            a = 2 * ft
            a[0], a[-1] = ft[0], ft[-1]
        return a

    def get_f(self, a):
        """Evaluate the series at the abscissae using the DCT.

        This is the inverse of `get_a()`.  Shorter series are zero-padded and
        longer series are truncated (with a warning), which allows interpolation
        on different sized grids.
        """
        if len(a) > len(self.x):
            warnings.warn(
                f"Truncating Chebyshev series {len(a)} -> {len(self.x)}",
                stacklevel=2,
            )
        a_ = np.zeros_like(a, shape=self.x.shape)
        n = min(len(a), len(a_))
        a_[:n] = a[:n]
        ft = a_ / 2
        if self.interior:
            ft[0] = a_[0]
            return sp.fft.idct(ft, type=2, norm="forward")[::-1]
        ft[0], ft[-1] = a_[0], a_[-1]
        return sp.fft.dct(ft, type=1, norm="backward")[::-1]

    def diff(self, f, d=1):
        """Return the dth derivative of `f` evaluated at the abscissae.

        Parameters
        ----------
        f : array_like
            Values at the abscissae `self.x` (increasing order).
        d : int
            Order of the derivative (``d >= 0``).  Orders above 2 are obtained by
            repeatedly applying the second derivative.

        Raises
        ------
        ValueError
            If `d` is negative.
        """
        if d < 0:
            raise ValueError(f"Derivative order must be non-negative (got {d=})")
        if d == 0:
            return f
        # NOTE: unlike get_a(), `f` is transformed *without* reversal, so the
        # transforms below work in phi = pi - theta, which increases with the
        # index and satisfies x = -cos(phi).  Hence df/dx = (df/dphi) / sin(phi),
        # with sin(phi) = sin(theta) = s, and cot(phi) = -cot(theta).
        k = self._k
        s = np.sin(self._theta)
        if self.interior:
            ft = sp.fft.dct(f, type=2, norm="forward")
            # Shift the coefficients and pad with zero.
            df_dphi = sp.fft.dst(np.concatenate([-(ft * k)[1:], [0]]), type=3)
            if d == 1:
                return df_dphi / s
            d2f_dphi2 = sp.fft.idct(-ft * k**2, type=2, norm="forward")
            # d2f/dx2 = (f_phiphi - f_phi cot(phi)) / sin(phi)**2
            d2f_dx2 = (d2f_dphi2 + df_dphi / np.tan(self._theta)) / s**2
            if d == 2:
                return d2f_dx2
            return self.diff(d2f_dx2, d=d - 2)

        # Gauss-Lobatto grid.
        ft = sp.fft.idct(f, type=1, norm="backward")
        a = 2 * ft
        a[0] /= 2
        a[-1] /= 2  # TODO: Tests do not pickup on changes here!
        # Shift the coefficients - these are the interior points
        if len(ft) > 2:
            df_dphi_ = sp.fft.dst(-(ft * k)[1:-1], type=1)
        else:
            df_dphi_ = np.array([])
        if d == 1:
            # sin(phi) vanishes at the endpoints; use the limits in terms of the
            # phi-coefficients `a`: f'(-1) = -sum(k**2 a_k), f'(1) = sum((-1)**k k**2 a_k).
            ak2 = -a * k**2
            ak2m = (-1) ** (k + 1) * ak2
            return np.concatenate(
                [[np.sum(ak2)], df_dphi_ / s[1:-1], [np.sum(ak2m)]]
            )
        d2f_dphi2_ = sp.fft.dct(-ft * k**2, type=1, norm="backward")[1:-1]
        # Endpoint limits from T_k''(1) = k**2 (k**2 - 1) / 3.
        ak = a * k**2 * (1 - k**2) / 3
        akm = (-1) ** k * ak
        d2f_dx2 = np.concatenate(
            [
                [-np.sum(ak)],
                (d2f_dphi2_ + df_dphi_ / np.tan(self._theta[1:-1])) / s[1:-1] ** 2,
                [-np.sum(akm)],
            ]
        )
        if d == 2:
            return d2f_dx2
        return self.diff(d2f_dx2, d=d - 2)

    def integrate(self, f):
        """Return the quadrature sum ``sum(f * w)``.

        This approximates the weighted integral of `f` over [-1, 1] with weight
        `W(x)`; `f` holds values at the abscissae `self.x`.
        """
        return (f * self.w).sum()