jax.numpy.fix

jax.numpy.fix(x, out=None)

Round input to the nearest integer towards zero.

JAX implementation of numpy.fix().
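Rounding towards zero means taking the floor of non-negative entries and the ceiling of negative ones. The following is a minimal sketch of that rule. The helper `fix_sketch` is hypothetical and is not claimed to be how JAX implements `jnp.fix`. The sketch only checks that it agrees with `jnp.fix`, with `jnp.trunc` (which also rounds towards zero), and with NumPy's `numpy.fix` on a few sample values.

```python
import jax.numpy as jnp
import numpy as np

def fix_sketch(x):
    """Hypothetical helper: round towards zero by taking floor for
    non-negative entries and ceil for negative entries."""
    x = jnp.asarray(x)
    return jnp.where(x >= 0, jnp.floor(x), jnp.ceil(x))

x = jnp.array([-2.7, -0.5, 0.0, 0.5, 2.7])

print(fix_sketch(x))           # the sketch above
print(jnp.fix(x))              # JAX's implementation
print(jnp.trunc(x))            # trunc also rounds towards zero
print(np.fix(np.asarray(x)))   # NumPy reference behaviour
```

All four lines should print the values -2, -0, 0, 0 and 2.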

Parameters:
  • x (ArrayLike) – input array.

  • out (None) – unused by JAX.

Returns:

An array with the same shape and dtype as x, containing the rounded values.

Return type:

Array
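The snippet below illustrates the shape and dtype guarantee with a small `float16` input. The input values are arbitrary. It also shows that negative inputs which round to zero come back as negative zero (printed as `-0.`, as in the example below). `jnp.signbit` can detect that sign.

```python
import jax.numpy as jnp

# A (2, 3) float16 input; jnp.fix should keep both the shape and the dtype.
x = jnp.array([[1.5, -1.5, 0.25],
               [-0.25, 3.9, -3.9]], dtype=jnp.float16)
y = jnp.fix(x)

print(y.shape, y.dtype)
print(y)

# -0.25 rounds towards zero to -0.0; signbit is True for that entry.
print(jnp.signbit(y))
```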

See also

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> key = jax.random.key(0)
>>> x = jax.random.uniform(key, (3, 3), minval=-5, maxval=5)
>>> with jnp.printoptions(precision=2, suppress=True):
...     print(x)
[[-1.45  1.04 -0.72]
 [-2.69  1.74 -0.6 ]
 [-2.49 -2.23  2.68]]
>>> jnp.fix(x)
Array([[-1.,  1., -0.],
       [-2.,  1., -0.],
       [-2., -2.,  2.]], dtype=float32)
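Like other `jax.numpy` functions, `jnp.fix` can be wrapped in `jax.jit`. Rounding is piecewise constant, so its derivative is zero wherever it is defined. In the JAX versions we are aware of, `jax.grad` reports a zero gradient rather than raising an error. The inputs below are taken from the first row of the example above. Treat the gradient behaviour as an observation rather than a documented guarantee.

```python
import jax
import jax.numpy as jnp

# Compile jnp.fix with jit; the result should match the eager call.
fix_jit = jax.jit(jnp.fix)
x = jnp.array([-1.45, 1.04, -0.72])
print(fix_jit(x))

# The derivative of a rounding function is zero almost everywhere.
print(jax.grad(jnp.fix)(1.5))
```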