set_printoptions
- class tensordict.set_printoptions(**kwargs)
Controls which attributes appear in TensorDict’s __repr__ output.

It can be used as a global setter (via set()), as a context manager, or as a decorator, following the same pattern as set_lazy_legacy.

- Keyword Arguments:
  show_batch_size (bool, optional) – Show batch_size in the TensorDict repr. Defaults to True.
  show_device (bool, optional) – Show device in the TensorDict repr. Defaults to True.
  show_is_shared (bool, optional) – Show is_shared in the TensorDict repr. Defaults to True.
  show_shape (bool, optional) – Show shape in per-tensor field descriptors. Defaults to True.
  show_field_device (bool, optional) – Show device in per-tensor field descriptors. Defaults to True.
  show_dtype (bool, optional) – Show dtype in per-tensor field descriptors. Defaults to True.
  show_field_is_shared (bool, optional) – Show is_shared in per-tensor field descriptors. Defaults to True.
  show_grad (bool, optional) – Show requires_grad in per-tensor field descriptors. Defaults to False.
  show_is_contiguous (bool, optional) – Show is_contiguous in per-tensor field descriptors. Defaults to False.
  show_is_view (bool, optional) – Show is_view in per-tensor field descriptors. Defaults to False.
  show_storage_size (bool, optional) – Show storage_size (in bytes) in per-tensor field descriptors. Defaults to False.
  plain (bool, optional) – When True, include a short summary of the actual tensor values in the field descriptors. Defaults to False.
  sort_keys (str or callable, optional) – Controls the order of keys in the repr: "alphabetical" (default) sorts keys lexicographically; "insertion" preserves the order in which keys were added; a callable is passed as the key argument to sorted().
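The three sort_keys modes can be illustrated with a small, stdlib-only sketch. order_keys is a hypothetical helper written for this page, not part of tensordict; it only mirrors the documented rules for a flat collection of string keys and makes no attempt to handle nested keys.

```python
from typing import Any, Callable, Iterable, List, Union

# Hypothetical type alias for the documented sort_keys argument.
SortKeys = Union[str, Callable[[str], Any]]


def order_keys(keys: Iterable[str], sort_keys: SortKeys = "alphabetical") -> List[str]:
    """Return keys in the order described for sort_keys (illustration only)."""
    ordered = list(keys)  # keys as given, i.e. in insertion order
    if callable(sort_keys):
        # A callable is forwarded as the key argument of sorted().
        return sorted(ordered, key=sort_keys)
    if sort_keys == "alphabetical":
        return sorted(ordered)  # lexicographic order of the key strings
    if sort_keys == "insertion":
        return ordered  # order in which the keys were added
    raise ValueError(f"unknown sort_keys value: {sort_keys!r}")


fields = {"reward": 0, "obs": 1, "action": 2}  # dicts keep insertion order
print(order_keys(fields))  # ['action', 'obs', 'reward']
print(order_keys(fields, "insertion"))  # ['reward', 'obs', 'action']
print(order_keys(fields, sort_keys=len))  # ['obs', 'reward', 'action']
```

Because sorted() is stable, keys that compare equal under a custom callable (here "reward" and "action", both of length 6) keep their insertion order.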
Examples
>>> import torch
>>> from tensordict import TensorDict, set_printoptions
>>> td = TensorDict({"x": torch.randn(3, 4)})
>>> # Global setter
>>> set_printoptions(show_device=False, show_is_shared=False).set()
>>> print(td)
>>> # Context manager: the options apply only inside the with-block
>>> with set_printoptions(show_dtype=False):
...     print(td)
>>> # Decorator: the options apply only while my_func runs
>>> @set_printoptions(show_is_shared=False)
... def my_func(td):
...     print(td)
>>> my_func(td)