# Debugging runtime values

<!--* freshness: { reviewed: '2024-04-11' } *-->

Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page summarizes each tool; follow the "Read more" link at the end of each section for the complete guide.

Table of contents:

* [Interactive inspection with `jax.debug`](print_breakpoint)
* [Functional error checks with `jax.experimental.checkify`](checkify_guide)
* [Throwing Python errors with JAX's debug flags](flags)

## Interactive inspection with `jax.debug`

Complete guide [here](print_breakpoint)

  **Summary:** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-, `jax.pmap`-, and `pjit`-decorated functions,
  and {func}`jax.debug.breakpoint` to pause execution of your compiled function to inspect values in the call stack:

  ```python
  import jax
  import jax.numpy as jnp

  @jax.jit
  def f(x):
    jax.debug.print("🤯 {x} 🤯", x=x)
    y = jnp.sin(x)
    jax.debug.breakpoint()
    jax.debug.print("🤯 {y} 🤯", y=y)
    return y

  f(2.)
# Prints:
# 🤯 2.0 🤯
# Enters breakpoint to inspect values!
# 🤯 0.9092974662780762 🤯
  ```
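The difference between Python's built-in `print` and `jax.debug.print` is easiest to see side by side. The sketch below uses a hypothetical function `double`: the built-in `print` runs only while JAX traces the function, so it shows an abstract tracer and is skipped on later calls that hit the compilation cache, whereas `jax.debug.print` is part of the compiled computation and reports the runtime value on every call.

```python
import jax
import jax.numpy as jnp

@jax.jit
def double(x):  # hypothetical example function
    # Built-in print: runs only at trace time and shows a tracer, not a number.
    print("traced:", x)
    # jax.debug.print: staged into the compiled function, so it runs on
    # every call and shows the concrete runtime value.
    jax.debug.print("runtime: {x}", x=x)
    return 2 * x

double(1.0)  # traces: prints the tracer line, then the runtime value
double(3.0)  # same shape and dtype, so no retrace: only the runtime value is printed
```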

[Read more](print_breakpoint).

## Functional error checks with `jax.experimental.checkify`

Complete guide [here](checkify_guide)

  **Summary:** Checkify lets you add `jit`-able runtime error checking (e.g., out-of-bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:

  ```python
  from jax.experimental import checkify
  import jax
  import jax.numpy as jnp

  def f(x, i):
    checkify.check(i >= 0, "index needs to be non-negative!")
    y = x[i]
    z = jnp.sin(y)
    return z

  jittable_f = checkify.checkify(f)

  err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1)
  print(err.get())
# >> index needs to be non-negative! (check failed at <...>:6 (f))
  ```
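Because checkify returns the error as a value instead of raising it inside the compiled function, it helps to see both outcomes. This sketch reuses the same `f` as above: when every check passes, `err.get()` returns `None` and `err.throw()` does nothing; when a check fails, `err.get()` returns the failure message. Applying `checkify.checkify` before `jax.jit` makes the error part of the compiled function's output.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def f(x, i):
    checkify.check(i >= 0, "index needs to be non-negative!")
    return jnp.sin(x[i])

# checkify first, then jit, so the error is returned from the compiled call.
checked_f = jax.jit(checkify.checkify(f))

err_ok, z_ok = checked_f(jnp.ones((5,)), 2)
print(err_ok.get())   # None: the check passed
err_ok.throw()        # no-op because no check failed

err_bad, z_bad = checked_f(jnp.ones((5,)), -1)
print(err_bad.get())  # the "index needs to be non-negative!" failure message
```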

  You can also use checkify to automatically add common checks:

  ```python
  errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
  checked_f = checkify.checkify(f, errors=errors)

  err, z = checked_f(jnp.ones((5,)), 100)
  err.throw()
# ValueError: out-of-bounds indexing at <...>:7 (f)

  err, z = checked_f(jnp.ones((5,)), -1)
  err.throw()
# ValueError: index needs to be non-negative! (check failed at <...>:6 (f))

  err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
  err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
  ```
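The `errors` argument also lets you enable a single category of automatic checks. As a minimal sketch, assuming a hypothetical function `take` that contains no `checkify.check` calls, the following enables only `checkify.index_checks`: an in-bounds index reports no error, while an out-of-bounds one is flagged.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def take(x, i):  # hypothetical example function with no user checks
    return x[i]

# Only automatic out-of-bounds indexing checks are enabled here.
checked_take = jax.jit(checkify.checkify(take, errors=checkify.index_checks))

err_ok, _ = checked_take(jnp.arange(5.0), 4)   # last valid index
err_oob, _ = checked_take(jnp.arange(5.0), 5)  # one past the end

print(err_ok.get())   # None
print(err_oob.get())  # an out-of-bounds indexing message
```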

[Read more](checkify_guide).

## Throwing Python errors with JAX's debug flags

Complete guide [here](flags)

**Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap`- or `pjit`-compiled code). Enable the `jax_disable_jit` flag to disable JIT compilation, which lets you use traditional Python debugging tools such as `print` and `pdb`.

```python
import jax
jax.config.update("jax_debug_nans", True)

def f(x, y):
  return x / y
jax.jit(f)(0., 0.)  # ==> raises FloatingPointError exception!
```
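Disabling JIT is useful when you want ordinary Python to see concrete values. Besides the `jax_disable_jit` config flag, JAX provides the `jax.disable_jit()` context manager, which limits the effect to a block of code. In the sketch below (the function `f` is hypothetical), the Python `if` on an intermediate value would fail under `jax.jit` because `y` is a tracer, but inside `jax.disable_jit()` the function runs eagerly, so `y` is a concrete array that `print`, `if`, or `pdb` can inspect.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):  # hypothetical example function
    y = jnp.sin(x)
    # Under jit, `y` is a tracer and this Python `if` would raise an error.
    # With JIT disabled, `y` is a concrete array, so the branch just works.
    if y > 0:
        print("y is positive:", y)
    return y

with jax.disable_jit():
    out = f(2.0)  # runs eagerly, op by op
```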

[Read more](flags).

```{toctree}
:caption: Read more
:maxdepth: 1

print_breakpoint
checkify_guide
flags
```

