jax.experimental.pallas.run_scoped

jax.experimental.pallas.run_scoped(f, *types, **kw_types)

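Calls f with freshly allocated references and returns its result. The references are scoped to the call: they exist only while f runs.

The positional arguments (types) and keyword arguments (kw_types) specify the reference type to allocate for the corresponding positional or keyword argument of f. Each backend defines its own set of reference types in addition to jax.experimental.pallas.MemoryRef.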

Parameters:
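  • f (Callable[..., Any]) – the function to call with the allocated references.

  • types (Any) – reference types to allocate, one per positional argument of f.

  • kw_types (Any) – reference types to allocate, one per keyword argument of f.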

Return type:

Any