jax.scipy.fft.dct
- jax.scipy.fft.dct(x, type=2, n=None, axis=-1, norm=None)
Computes the discrete cosine transform of the input.
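For reference, scipy.fft.dct documents the unnormalized type-2 DCT of a length-N signal as y[k] = 2 * sum_{n=0}^{N-1} x[n] * cos(pi * k * (2n + 1) / (2N)) for k = 0, ..., N-1. The sketch below assumes the default norm=None follows that same convention and checks it by evaluating the sum directly along the last axis; dct2_direct is a hypothetical helper written for illustration, not part of JAX.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dct


def dct2_direct(x):
    """Hypothetical O(N^2) reference DCT-II along the last axis (norm=None).

    Computes y[..., k] = 2 * sum_n x[..., n] * cos(pi * k * (2n + 1) / (2N))
    by building the full N x N cosine matrix.
    """
    N = x.shape[-1]
    n = jnp.arange(N)
    k = n[:, None]
    # C[k, n] = 2 * cos(pi * k * (2n + 1) / (2N)), shape (N, N)
    C = 2.0 * jnp.cos(jnp.pi * k * (2 * n + 1) / (2 * N))
    # Broadcast x to (..., 1, N) against C (N, N) and sum over n -> (..., N).
    return jnp.sum(x[..., None, :] * C, axis=-1)


x = jax.random.normal(jax.random.key(0), (3, 3))
print(jnp.allclose(dct(x), dct2_direct(x), atol=1e-5))
```

The direct sum is only a reference: it does N^2 work per transform, while an FFT-based implementation needs O(N log N).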
JAX implementation of scipy.fft.dct().
- Parameters:
x (Array) – input array.
type (int) – integer, default=2. The DCT type; currently only type 2 is supported.
n (int | None) – integer, default=x.shape[axis]. The length of the transform. If larger than x.shape[axis], the input is zero-padded; if smaller, the input is truncated.
axis (int) – integer, default=-1. The axis along which the DCT is performed.
norm (str | None) – string. The normalization mode: one of [None, "backward", "ortho"]. The default is None, which is equivalent to "backward".
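The normalization modes differ only by a per-coefficient scale. The sketch below assumes the scaling that scipy.fft.dct documents for norm="ortho": coefficient k = 0 is divided by sqrt(4N) and every other coefficient by sqrt(2N). That scaling makes the transform orthonormal, so each transformed row should keep its Euclidean norm.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dct

x = jax.random.normal(jax.random.key(0), (3, 3))
N = x.shape[-1]

y_default = dct(x)              # norm=None
y_ortho = dct(x, norm="ortho")

# Assumed ortho scaling: sqrt(4N) for k = 0, sqrt(2N) for k > 0.
scale = jnp.where(jnp.arange(N) == 0, jnp.sqrt(4.0 * N), jnp.sqrt(2.0 * N))
print(jnp.allclose(y_ortho, y_default / scale, atol=1e-5))

# An orthonormal transform preserves the Euclidean norm of every row.
print(jnp.allclose(jnp.linalg.norm(y_ortho, axis=-1),
                   jnp.linalg.norm(x, axis=-1), atol=1e-5))
```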
- Returns:
array containing the discrete cosine transform of x.
- Return type:
Array
See also
- jax.scipy.fft.dctn(): multidimensional DCT
- jax.scipy.fft.idct(): inverse DCT
- jax.scipy.fft.idctn(): multidimensional inverse DCT
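A quick sanity check of how these functions relate, assuming each is called with its defaults or with a matching norm: idct should recover x from dct(x), and dctn with its default axes should equal applying dct along each axis in turn, because the multidimensional DCT is separable.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dct, dctn, idct, idctn

x = jax.random.normal(jax.random.key(0), (3, 3))

# Round trip along the last axis with matching normalization.
print(jnp.allclose(idct(dct(x)), x, atol=1e-5))
print(jnp.allclose(idct(dct(x, norm="ortho"), norm="ortho"), x, atol=1e-5))

# The 2-D DCT is separable: transform the rows, then the columns.
print(jnp.allclose(dctn(x), dct(dct(x, axis=1), axis=0), atol=1e-4))
print(jnp.allclose(idctn(dctn(x)), x, atol=1e-5))
```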
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> x = jax.random.normal(jax.random.key(0), (3, 3))
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dct(x))
[[-0.58 -0.33 -1.08]
 [-0.88 -1.01 -1.79]
 [-1.06 -2.43  1.24]]
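With the default axis=-1, each row of the 3 x 3 output is the one-dimensional DCT of the corresponding row of x. The check below confirms that and also reads off the first column: for k = 0 every cosine factor is 1, so under the unnormalized convention assumed here (norm=None, matching scipy.fft.dct), y[:, 0] should equal twice each row sum.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dct

x = jax.random.normal(jax.random.key(0), (3, 3))
y = dct(x)  # default axis=-1, norm=None

# Transforming row by row matches one batched call along the last axis.
rows = jnp.stack([dct(row) for row in x])
print(jnp.allclose(y, rows, atol=1e-6))

# k = 0 term: cos(0) = 1, so y[:, 0] = 2 * sum_n x[:, n].
print(jnp.allclose(y[:, 0], 2 * x.sum(axis=-1), atol=1e-5))
```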
When n is smaller than x.shape[axis]:
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dct(x, n=2))
[[-0.22 -0.9 ]
 [-0.57 -1.68]
 [-2.52 -0.11]]
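Truncation keeps the leading entries of the transformed axis. With n=2 along the last axis, the result should match transforming the explicitly sliced array x[:, :2], so only the first two entries of each row influence the output.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dct

x = jax.random.normal(jax.random.key(0), (3, 3))

# n=2 < x.shape[-1]: equivalent to dropping the trailing entry of each row.
truncated = dct(x, n=2)
sliced = dct(x[:, :2])
print(truncated.shape)
print(jnp.allclose(truncated, sliced, atol=1e-6))
```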
When n is smaller than x.shape[axis] and axis=0:
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dct(x, n=2, axis=0))
[[-2.22  1.43 -0.67]
 [ 0.52 -0.26 -0.04]]
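With axis=0 the transform runs down each column, so truncation keeps the first two rows. Two equivalent formulations are sketched below: slicing the rows explicitly, and transposing so that the column axis becomes the last axis before applying the default axis=-1.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dct

x = jax.random.normal(jax.random.key(0), (3, 3))
y = dct(x, n=2, axis=0)

# Keep the first two rows, then transform down the columns.
print(jnp.allclose(y, dct(x[:2, :], axis=0), atol=1e-6))

# Transpose, transform along the last axis, transpose back.
print(jnp.allclose(y, dct(x.T, n=2).T, atol=1e-6))
```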
When n is larger than x.shape[axis] and axis=1:
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dct(x, n=4, axis=1))
[[-0.58 -0.35 -0.64 -1.11]
 [-0.88 -0.9  -1.46 -1.68]
 [-1.06 -2.25 -1.15  1.93]]
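Zero-padding appends zeros at the end of the transformed axis. For this 3 x 3 input, n=4 adds one trailing zero to each row, so the result should match dct applied to an explicitly padded copy. The sketch uses jnp.pad, whose default constant mode pads with zeros.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dct

x = jax.random.normal(jax.random.key(0), (3, 3))

# n=4 > x.shape[1]: append one zero column, then transform each row.
x_padded = jnp.pad(x, ((0, 0), (0, 1)))
y = dct(x, n=4, axis=1)
print(y.shape)
print(jnp.allclose(y, dct(x_padded, axis=1), atol=1e-6))
```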