jax.scipy.fft.dct

jax.scipy.fft.dct(x, type=2, n=None, axis=-1, norm=None)

Computes the discrete cosine transform (DCT) of the input.

JAX implementation of scipy.fft.dct().
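Because this function mirrors scipy.fft.dct(), one way to check the correspondence is to run both on the same data. The sketch below assumes NumPy and SciPy are installed (SciPy is already a JAX dependency). The helper max_abs_diff_vs_scipy and the chosen argument combinations are illustrative and are not part of either library.

```python
import numpy as np
import scipy.fft
import jax
import jax.numpy as jnp
import jax.scipy.fft


def max_abs_diff_vs_scipy(x, **kwargs):
    """Hypothetical helper: largest elementwise gap between the JAX and SciPy DCTs."""
    y_jax = np.asarray(jax.scipy.fft.dct(x, **kwargs))
    # SciPy computes the same transform on a NumPy copy of the input.
    y_scipy = scipy.fft.dct(np.asarray(x), **kwargs)
    return float(np.max(np.abs(y_jax - y_scipy)))


x = jax.random.normal(jax.random.key(0), (3, 5))
# Default settings, orthonormal scaling, truncation along axis 0, padding along axis 1.
for kwargs in ({}, {"norm": "ortho"}, {"n": 2, "axis": 0}, {"n": 8, "axis": 1}):
    print(kwargs, max_abs_diff_vs_scipy(x, **kwargs))
```

Small differences at the level of float32 rounding error are expected rather than exact equality.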

Parameters:
  • x (Array) – input array.

  • type (int) – integer, default=2. The DCT type. Currently only type 2 is supported.

  • n (int | None) – integer, default=x.shape[axis]. The length of the transform. If n is larger than x.shape[axis], the input is zero-padded; if it is smaller, the input is truncated.

  • axis (int) – integer, default=-1. The axis along which the DCT is performed.

  • norm (str | None) – string or None. The normalization mode: one of [None, "backward", "ortho"]. The default is None, which is equivalent to "backward".
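For reference, SciPy defines the unnormalized type-2 DCT as y[k] = 2 * sum_{n=0}^{N-1} x[n] * cos(pi * k * (2n + 1) / (2N)), and norm="ortho" additionally divides y[0] by sqrt(4N) and every other y[k] by sqrt(2N). The sketch below evaluates that sum directly as a cosine-matrix product along the last axis and compares it with jax.scipy.fft.dct. dct2_reference is a hypothetical helper written for illustration: it costs O(N^2) rather than the O(N log N) of an FFT-based routine and only handles the default axis.

```python
import jax
import jax.numpy as jnp
import jax.scipy.fft


def dct2_reference(x, norm=None):
    """Hypothetical naive type-2 DCT along the last axis (O(N^2) sketch)."""
    N = x.shape[-1]
    n = jnp.arange(N)            # input sample index
    k = jnp.arange(N)[:, None]   # output frequency index, as a column
    # basis[k, n] = cos(pi * k * (2n + 1) / (2N))
    basis = jnp.cos(jnp.pi * k * (2 * n + 1) / (2 * N))
    # y[..., k] = 2 * sum_n x[..., n] * basis[k, n]
    y = 2 * x @ basis.T
    if norm == "ortho":
        # Orthonormal scaling: divide by sqrt(4N) for k = 0 and sqrt(2N) for k > 0.
        scale = jnp.where(jnp.arange(N) == 0, jnp.sqrt(4.0 * N), jnp.sqrt(2.0 * N))
        y = y / scale
    return y


x = jax.random.normal(jax.random.key(0), (3, 5))
for norm in (None, "ortho"):
    diff = jnp.max(jnp.abs(dct2_reference(x, norm) - jax.scipy.fft.dct(x, norm=norm)))
    print(norm, float(diff))

# With norm="ortho" the transform is orthonormal, so the sum of squares is preserved.
y = jax.scipy.fft.dct(x, norm="ortho")
print(float(jnp.sum(x**2)), float(jnp.sum(y**2)))
```

The last line illustrates why "ortho" can be convenient: the scaled transform is orthonormal, so the two sums of squares should agree up to rounding.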

Returns:

array containing the discrete cosine transform of x. It has the same shape as x, except that its length along axis is n when n is given.

Return type:

Array

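See also

  • jax.scipy.fft.dctn: multidimensional DCT

  • jax.scipy.fft.idct: inverse DCT

  • jax.scipy.fft.idctn: multidimensional inverse DCT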

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> import jax.scipy.fft
>>> x = jax.random.normal(jax.random.key(0), (3, 3))
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dct(x))
[[-0.58 -0.33 -1.08]
 [-0.88 -1.01 -1.79]
 [-1.06 -2.43  1.24]]

When n is smaller than x.shape[axis], the input is truncated along axis (here the default, -1) before the transform:

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dct(x, n=2))
[[-0.22 -0.9 ]
 [-0.57 -1.68]
 [-2.52 -0.11]]
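Truncation happens before the transform, so the n=2 call above should agree with transforming only the first two columns. A minimal check on the same random input, offered as an illustration rather than a statement about how the function is implemented internally:

```python
import jax
import jax.numpy as jnp
import jax.scipy.fft

x = jax.random.normal(jax.random.key(0), (3, 3))

# n=2 keeps only the first two entries along the last axis before transforming.
truncated = jax.scipy.fft.dct(x, n=2)
sliced = jax.scipy.fft.dct(x[:, :2])
print(truncated.shape)  # (3, 2)
print(jnp.allclose(truncated, sliced, atol=1e-6))
```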

When n is smaller than x.shape[axis] and axis=0, each column is truncated to its first n entries before the transform:

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dct(x, n=2, axis=0))
[[-2.22  1.43 -0.67]
 [ 0.52 -0.26 -0.04]]
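With axis=0 the one-dimensional transform runs down each column instead of along each row. As an illustrative check, the result should agree with transposing, transforming along the last axis, and transposing back:

```python
import jax
import jax.numpy as jnp
import jax.scipy.fft

x = jax.random.normal(jax.random.key(0), (3, 3))

along_axis0 = jax.scipy.fft.dct(x, n=2, axis=0)
# Same transform, expressed as a last-axis DCT on the transposed input.
via_transpose = jax.scipy.fft.dct(x.T, n=2).T
print(along_axis0.shape)  # (2, 3)
print(jnp.allclose(along_axis0, via_transpose, atol=1e-5))
```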

When n is larger than x.shape[axis] and axis=1, each row is zero-padded to length n before the transform:

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dct(x, n=4, axis=1))
[[-0.58 -0.35 -0.64 -1.11]
 [-0.88 -0.9  -1.46 -1.68]
 [-1.06 -2.25 -1.15  1.93]]
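Passing a larger n should be equivalent to zero-padding the input along axis yourself. The sketch below pads manually with jnp.pad (the padding widths are specific to this (3, 3) input) and compares the two. It also checks the k = 0 term, 2 * sum_n x[n], which appended zeros cannot change; this is why the first column of the n=4 output matches the first column of the first example.

```python
import jax
import jax.numpy as jnp
import jax.scipy.fft

x = jax.random.normal(jax.random.key(0), (3, 3))

padded_dct = jax.scipy.fft.dct(x, n=4, axis=1)
# Append one zero column so that axis 1 has length 4, then transform.
manual = jax.scipy.fft.dct(jnp.pad(x, ((0, 0), (0, 1))), axis=1)
print(padded_dct.shape)  # (3, 4)
print(jnp.allclose(padded_dct, manual, atol=1e-5))

# The k = 0 coefficient is 2 * sum_n x[n], so zero-padding leaves it unchanged.
unpadded = jax.scipy.fft.dct(x, axis=1)
print(jnp.allclose(padded_dct[:, 0], unpadded[:, 0], atol=1e-5))
```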