jax.lax.broadcast_shapes

jax.lax.broadcast_shapes(*shapes: tuple[int, ...]) → tuple[int, ...]
jax.lax.broadcast_shapes(*shapes: tuple[int | core.Tracer, ...]) → tuple[int | core.Tracer, ...]

Returns the shape that results from NumPy broadcasting of shapes.
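A short usage sketch follows, assuming a recent JAX install and NumPy >= 1.20 (for numpy.broadcast_shapes, used only as a cross-check). The shapes are arbitrary illustrations, and the variable names two, three, scalar and failed are hypothetical. Only the static integer overload is exercised. For the incompatible case, the example catches both TypeError and ValueError because the exact exception type is an assumption that may differ between JAX versions.

```python
import numpy as np
from jax import lax

# Shapes are aligned at their trailing dimensions: missing leading axes are
# treated as size 1, and a size-1 axis stretches to match the other size.
two = lax.broadcast_shapes((2, 1, 3), (4, 3))
print(two)  # (2, 4, 3)

# The function is variadic, so several shapes can be combined in one call.
three = lax.broadcast_shapes((3, 1), (1, 4), (4,))
print(three)  # (3, 4)

# A scalar shape () broadcasts against any shape.
scalar = lax.broadcast_shapes((), (5,))
print(scalar)  # (5,)

# Cross-check against NumPy's own helper for the same inputs.
print(two == np.broadcast_shapes((2, 1, 3), (4, 3)))  # True

# Two non-1 sizes that differ (2 vs 4) cannot be broadcast. The exception
# types caught here are an assumption, since they may vary by JAX version.
try:
    lax.broadcast_shapes((2, 3), (4, 3))
    failed = False
except (TypeError, ValueError) as err:
    failed = True
    print("incompatible:", err)
```

Because the result is a plain tuple of Python ints for static shapes, it can be passed directly to functions such as jax.numpy.zeros or lax.broadcast_in_dim.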