jax.tree_util.keystr
- jax.tree_util.keystr(keys)
Helper to pretty-print a tuple of keys.
- Parameters:
keys (KeyPath) – A tuple of KeyEntry objects, or of any objects that can be converted to strings.
- Returns:
A string formed by joining the string representations of all the keys, with no separator.
Examples
>>> import jax
>>> keys = (0, 1, 'a', 'b')
>>> jax.tree_util.keystr(keys)
'01ab'
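The joining behaviour described under Returns can be illustrated in plain Python. This is a minimal sketch, not JAX's actual source: `keystr_sketch` and `DictKeyStandIn` are hypothetical names chosen for illustration. The stand-in class only assumes a key entry whose string form looks like an indexing expression.

```python
# Hypothetical sketch of the documented behaviour of jax.tree_util.keystr:
# convert each key to a string and concatenate the results, no separator.
# This is NOT JAX's implementation; the names below are illustrative only.
from typing import Any, Sequence


def keystr_sketch(keys: Sequence[Any]) -> str:
    """Join str() of every key in order, with no separator."""
    return "".join(str(k) for k in keys)


class DictKeyStandIn:
    """Hypothetical stand-in for a key entry that prints like ['a']."""

    def __init__(self, key: Any) -> None:
        self.key = key

    def __str__(self) -> str:
        # repr() adds quotes around string keys, e.g. 'a' -> "['a']"
        return f"[{self.key!r}]"


print(keystr_sketch((0, 1, "a", "b")))  # 01ab
print(keystr_sketch((DictKeyStandIn("a"), DictKeyStandIn("b"))))  # ['a']['b']
```

In real code the keys usually come from `jax.tree_util.tree_flatten_with_path`, whose key entries (for example `DictKey` and `SequenceKey`) render in this bracketed form, so the joined string reads like an indexing path into the pytree.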