jax.export.export

jax.export.export(fun_jit, *, platforms=None, disabled_checks=())

Exports a JAX function for persistent serialization.

Parameters:
  • fun_jit (stages.Wrapped) – the function to export. Should be the result of jax.jit.

  • platforms (Sequence[str] | None) – Optional sequence containing a subset of 'tpu', 'cpu', 'cuda', 'rocm'. If more than one platform is specified, the exported code takes an additional argument specifying the platform. If None, the default JAX backend is used. The calling convention for multiple platforms is explained at https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention.

  • disabled_checks (Sequence[DisabledSafetyCheck]) – the safety checks to disable. See the documentation of jax.export.DisabledSafetyCheck.
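
As a minimal sketch of the platforms parameter, the following exports one function for two platforms and inspects the result. The function name double_plus_one is hypothetical and chosen only for illustration. The sketch assumes that lowering for a platform that is not installed locally works for simple functions like this one; actually running the exported code still requires one of the listed platforms.

```python
import jax
import numpy as np
from jax import export


@jax.jit
def double_plus_one(x):  # hypothetical example function
    return 2.0 * x + 1.0


# Describe the input by shape and dtype only; no concrete array is needed.
spec = jax.ShapeDtypeStruct((4,), np.float32)

# Request lowering for both CPU and CUDA in a single export.
exported = export.export(double_plus_one, platforms=["cpu", "cuda"])(spec)

# The Exported object records which platforms it was lowered for.
print(exported.platforms)
```

Omitting platforms (or passing None) exports for the default JAX backend only, as described above.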

Returns:

a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct, or values with .shape and .dtype attributes, and returns an Exported.

Return type:

Callable[..., Exported]
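
The two-step calling pattern described under Returns can be sketched as follows: export first wraps the jitted function, and the returned callable is then applied to shape-and-dtype specifications rather than concrete arrays. The function weighted_sum and the variable names are hypothetical, introduced only for this illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export


@jax.jit
def weighted_sum(x, w):  # hypothetical example function
    return jnp.sum(x * w)


# Step 1: wrap the jitted function. No tracing happens yet.
exporter = export.export(weighted_sum)

# Step 2: call the wrapper with abstract argument specifications.
x_spec = jax.ShapeDtypeStruct((3,), np.float32)
w_spec = jax.ShapeDtypeStruct((3,), np.float32)
exported = exporter(x_spec, w_spec)

# The Exported object can then be invoked with concrete arrays
# that match the exported shapes and dtypes.
x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
w = np.array([0.5, 0.5, 0.5], dtype=np.float32)
result = exported.call(x, w)
print(exported.in_avals, result)
```

Using jax.ShapeDtypeStruct avoids allocating example inputs, which can be convenient when the real arrays are large or not available at export time.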

Usage:

>>> import numpy as np
>>> import jax.numpy as jnp
>>> from jax import export
>>> exported: export.Exported = export.export(jnp.sin)(
...     np.arange(4, dtype=np.float32))
>>>
>>> # You can inspect the Exported object
>>> exported.in_avals
(ShapedArray(float32[4]),)
>>> blob: bytearray = exported.serialize()
>>>
>>> # The serialized bytes are safe to use in a separate process
>>> rehydrated: export.Exported = export.deserialize(blob)
>>> rehydrated.fun_name
'sin'
>>> rehydrated.call(np.array([.1, .2, .3, .4], dtype=np.float32))
Array([0.09983342, 0.19866933, 0.29552022, 0.38941833], dtype=float32)