jax.experimental.pallas.BlockSpec
- class jax.experimental.pallas.BlockSpec(block_shape=None, index_map=None, *, memory_space=None, indexing_mode=Blocked)
Specifies how an array should be sliced for each invocation of a kernel.
For more details, see the guide "BlockSpec, a.k.a. how to chunk up inputs".
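The following minimal sketch shows one way a `BlockSpec` can drive a `pallas_call`. It assumes a JAX version whose `BlockSpec` accepts `block_shape` and `index_map` as listed above. The sketch passes them as keyword arguments so it does not depend on their positional order, which has differed between releases. The kernel `add_one_kernel` and the wrapper `add_one` are hypothetical names. With the default blocked indexing mode, `index_map` returns block indices rather than element offsets, so grid step `i` operates on rows `2*i` and `2*i + 1`. `interpret=True` runs the kernel on CPU without a TPU or GPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_one_kernel(x_ref, o_ref):
    # Each invocation sees only the (2, 4) block selected by the BlockSpec.
    o_ref[...] = x_ref[...] + 1.0


def add_one(x):
    # Hypothetical wrapper: tile an (N, 4) array into (2, 4) blocks.
    # index_map maps grid index i to block index (i, 0), i.e. rows 2*i .. 2*i+1.
    block = pl.BlockSpec(block_shape=(2, 4), index_map=lambda i: (i, 0))
    return pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(x.shape[0] // 2,),  # one kernel invocation per row block
        in_specs=[block],
        out_specs=block,
        interpret=True,  # run on CPU for illustration
    )(x)


x = jnp.arange(32, dtype=jnp.float32).reshape(8, 4)
y = add_one(x)
```

Because the same spec is used for the input and the output, each grid step writes back exactly the block it read.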
- Parameters:
block_shape (Sequence[int | None] | None)
index_map (Callable[..., Any] | None)
memory_space (Any | None)
indexing_mode (IndexingMode)
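The `int | None` element type of `block_shape` allows a dimension to be given as `None`. The following sketch assumes this squeezes that dimension: it has an effective block size of 1, the kernel's ref drops it, and the corresponding `index_map` entry selects a single row. Newer releases may spell this differently. The names `double_kernel` and `double_rows` are hypothetical, and `interpret=True` is used so the example runs on CPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def double_kernel(x_ref, o_ref):
    # With block_shape=(None, 8), the refs are expected to have shape (8,).
    o_ref[...] = x_ref[...] * 2.0


def double_rows(x):
    # Hypothetical wrapper: one grid step per row of an (R, 8) array.
    # The None dimension has block size 1, so index i selects row i directly.
    row = pl.BlockSpec(block_shape=(None, x.shape[1]), index_map=lambda i: (i, 0))
    return pl.pallas_call(
        double_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(x.shape[0],),
        in_specs=[row],
        out_specs=row,
        interpret=True,
    )(x)


x = jnp.arange(32, dtype=jnp.float32).reshape(4, 8)
y = double_rows(x)
```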
- __init__(block_shape=None, index_map=None, *, memory_space=None, indexing_mode=Blocked)
- Parameters:
block_shape (Sequence[int | None] | None)
index_map (Callable[..., Any] | None)
memory_space (Any | None)
indexing_mode (IndexingMode)
- Return type:
None
Methods
__init__([block_shape, index_map, ...])
to_block_mapping(origin, array_aval, *, ...)
Attributes
block_shape
index_map
indexing_mode
memory_space