jax.experimental.pallas.run_scoped
- jax.experimental.pallas.run_scoped(f, *types, **kw_types)
Calls f with newly allocated references, which are valid only for the duration of the call, and returns the result of f.
The positional and keyword arguments describe which reference type to allocate for each corresponding positional and keyword argument of f. Each backend has its own set of reference types in addition to jax.experimental.pallas.MemoryRef.
- Parameters:
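f (Callable[..., Any]) – the function to call with the allocated references.
types (Any) – reference types to allocate, one for each positional argument of f.
kw_types (Any) – reference types to allocate, one for each keyword argument of f.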
- Return type:
Any