jax.experimental.pallas.debug_print
- jax.experimental.pallas.debug_print(fmt, *args)
Prints values from inside a Pallas kernel.
- Parameters:
fmt (str) – A format string to be included in the output. The restrictions on the format string depend on the backend:
  - On GPU, when using Triton, fmt must not contain any placeholders ({...}), since the format string is always printed before any of the values.
  - On GPU, when using the experimental Mosaic GPU backend, fmt must contain a placeholder for each value to be printed. Format specs and conversions are not supported. All values must be scalars.
  - On TPU, if fmt contains placeholders, all values must be 32-bit integers. If there are no placeholders, the values are printed after the format string. All values must be scalars.
*args (jax.typing.ArrayLike) – The values to print.
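The per-backend rules for fmt are easy to mix up, so the plain-Python sketch below restates them as a checker. It is a hypothetical helper, not part of Pallas: the names check_debug_print_args and Value, and the backend labels, are illustrative assumptions, and the code only models the restrictions listed above. It does not reproduce the validation Pallas performs. Placeholders are detected with the standard library's string.Formatter().parse. Inside a real kernel the call itself would simply be debug_print(fmt, *args), for example with one {} per scalar on Mosaic GPU.

```python
import string
from dataclasses import dataclass
from typing import Sequence

# Hypothetical backend labels used only by this sketch (not Pallas identifiers).
TRITON = "gpu-triton"
MOSAIC_GPU = "gpu-mosaic"
TPU = "tpu"

# Assumption: "32-bit integers" is read as either signed or unsigned 32-bit.
INT32_DTYPES = ("int32", "uint32")


@dataclass(frozen=True)
class Value:
    """Hypothetical stand-in for an argument: only shape and dtype matter here."""
    shape: tuple
    dtype: str


def check_debug_print_args(backend: str, fmt: str, values: Sequence[Value]) -> None:
    """Raise ValueError if fmt/values break the documented rules for backend."""
    # Each replacement field ({}, {0}, {x:02d}, {!r}, ...) has a non-None name.
    fields = [
        (spec, conv)
        for _, name, spec, conv in string.Formatter().parse(fmt)
        if name is not None
    ]
    is_scalar = [v.shape == () for v in values]

    if backend == TRITON:
        # Triton prints fmt before the values, so placeholders are not allowed.
        if fields:
            raise ValueError("Triton: fmt must not contain placeholders")
    elif backend == MOSAIC_GPU:
        if len(fields) != len(values):
            raise ValueError("Mosaic GPU: need one placeholder per value")
        if any(spec or conv for spec, conv in fields):
            raise ValueError("Mosaic GPU: format specs/conversions unsupported")
        if not all(is_scalar):
            raise ValueError("Mosaic GPU: all values must be scalars")
    elif backend == TPU:
        if not all(is_scalar):
            raise ValueError("TPU: all values must be scalars")
        # Only when fmt has placeholders must every value be a 32-bit integer.
        if fields and any(v.dtype not in INT32_DTYPES for v in values):
            raise ValueError("TPU: with placeholders, values must be 32-bit ints")
    else:
        raise ValueError(f"unknown backend {backend!r}")
```

With these rules, "x = {}" with a float32 scalar is rejected on TPU but accepted on Mosaic GPU, while "x =" with the same value is accepted on TPU and on Triton.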