
TensorClass

class tensordict.TensorClass(*args, **kwargs)

TensorClass is the inheritance-based version of the tensorclass() decorator.

TensorClass lets you write dataclass-like containers that benefit from all the TensorDict machinery (indexing, reshaping, to(device), stack/cat, memmap serialization, …) while remaining pythonic and friendly to static type-checkers. It is the recommended entry point for new code; the tensorclass() decorator is kept for backwards compatibility and dataclass-style ergonomics.

Examples

>>> from typing import Any
>>> import torch
>>> from tensordict import TensorClass
>>> class Foo(TensorClass):
...     tensor: torch.Tensor
...     non_tensor: Any
...     nested: Any = None
>>> foo = Foo(tensor=torch.randn(3), non_tensor="a string!", nested=None, batch_size=[3])
>>> print(foo)
Foo(
    non_tensor=NonTensorData(data=a string!, batch_size=torch.Size([3]), device=None),
    tensor=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
    nested=None,
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

Keyword Arguments:
  • batch_size (torch.Size, optional) – The batch size of the TensorDict. Defaults to None.

  • device (torch.device, optional) – The device on which the TensorDict will be created. Defaults to None.

  • frozen (bool, optional) – If True, the resulting class or instance will be immutable. Defaults to False.

  • autocast (bool, optional) – If True, values are coerced to the field’s type annotation when set (e.g. an int annotation keeps Python ints as ints rather than wrapping them in a tensor). Mutually exclusive with nocast and tensor_only. Defaults to False.

  • nocast (bool, optional) – If True, tensor-compatible values (int, float, np.ndarray, …) are stored as-is instead of being cast to a tensor. Mutually exclusive with autocast and tensor_only. Defaults to False.

  • tensor_only (bool, optional) – If True, every field is expected to hold a tensor (or tensor-compatible value, which will be cast). Lookups skip the non-tensor data path, which can yield significant speed-ups at the cost of losing non-tensor support. Mutually exclusive with autocast and nocast. Defaults to False.

  • shadow (bool, optional) – Disables the validation of field names against TensorDict’s reserved attributes (e.g. allowing a field named device or batch_size). Use with caution, as this can lead to surprising behaviour. Defaults to False.

The bracket form TensorClass[...] is sugar for “build a parametrized subclass with the given flags turned on”. The three forms below are equivalent:

Examples

>>> class Foo(TensorClass["autocast"]):     # bracket form
...     integer: int
>>> class Foo(TensorClass, autocast=True):  # metaclass kwargs form
...     integer: int
>>> @tensorclass(autocast=True)             # legacy decorator form
... class Foo:
...     integer: int

The bracket form is usually the most readable when you stack several flags, and it is the form static type-checkers (mypy/pyright) understand via __class_getitem__(). The kwargs form is convenient if the flag value is computed; the decorator form is best when migrating plain @dataclass code.

Several flags can be combined inside the brackets:

Examples

>>> class Foo(TensorClass["nocast", "frozen"]):
...     integer: int

Per-flag semantics. The flags fall into three families — casting (autocast, nocast, tensor_only, which are mutually exclusive), mutability (frozen) and name shadowing (shadow).

Default behaviour (no flags) — tensor-compatible values are wrapped as tensors and non-tensor values are stored as NonTensorData:

>>> class Foo(TensorClass):
...     x: int
>>> Foo(x=1).x
tensor(1)

autocast — coerce values to the field’s annotated type when they are set. Useful when a field is conceptually a scalar/string/enum but should still travel through the TensorDict machinery as a leaf:

>>> class Foo(TensorClass["autocast"]):
...     x: int
>>> Foo(x=torch.ones(())).x
1

nocast — opt out of tensor casting entirely; tensor-compatible scalars are stored as-is:

>>> class Foo(TensorClass["nocast"]):
...     x: int
>>> Foo(x=1).x
1

tensor_only — the strict, fast path. Every field must be (or convert to) a tensor; the non-tensor storage is bypassed, which makes attribute access cheaper. Reach for this on performance-sensitive containers (e.g. RL trajectories, batched model I/O) where you control every field:

>>> class Foo(TensorClass["tensor_only"]):
...     x: torch.Tensor

frozen — make the class immutable, mirroring @dataclass(frozen=True). Frozen instances play well with torch.compile and functional code paths where in-place mutation would be a foot-gun. Note that frozen propagates to subclasses: a non-frozen subclass cannot inherit from a frozen base, and vice versa.

>>> class Foo(TensorClass["frozen"]):
...     x: torch.Tensor
>>> foo = Foo(x=torch.zeros(3))
>>> foo.x = torch.ones(3)  # raises FrozenInstanceError

shadow — by default, field names that collide with TensorDict’s reserved attributes (batch_size, device, names, data, …) raise an AttributeError at class construction. shadow=True opts out of that check so you can use those names anyway:

>>> class Foo(TensorClass["shadow"]):
...     data: torch.Tensor

Subclassing. A parametrized class is just a class — you can subclass it and stack more flags:

>>> class Base(TensorClass["autocast"]):
...     x: int
>>> class Sub(Base, shadow=True):   # autocast inherited, shadow added
...     y: float

Type-checking. TensorClass[...] is implemented via __class_getitem__(), so mypy and pyright resolve it to the (parametrized) class itself rather than to a generic parameter. Annotated fields propagate as expected and editors offer attribute completion on instances.

Note

TensorClass itself is not decorated as a tensorclass — the dataclass machinery only fires on subclasses. This is intentional: we cannot anticipate whether frozen will be requested, and a non-frozen base cannot have a frozen subclass.


dumps(prefix: Optional[str] = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, robust_key: bool | None = True) → Any

Saves the tensordict to disk.

This function is a proxy to memmap().

from_tensordict(tensordict: TensorDictBase, non_tensordict: Optional[dict] = None, safe: bool = True) → Any

Instantiates a new tensorclass object that wraps the given tensordict.

Parameters:
  • tensordict (TensorDictBase) – Dictionary of tensors.

  • non_tensordict (dict, optional) – Dictionary with non-tensor and nested tensorclass objects. Defaults to None.

  • safe (bool, optional) – Whether to raise an error if tensordict is not a TensorDictBase instance. Defaults to True.

get(key: NestedKey, *args, **kwargs)

Gets the value stored with the input key.

Parameters:
  • key (str, tuple of str) – key to be queried. If a tuple of str is passed, it is equivalent to chained calls of getattr.

  • default (optional) – default value returned if the key is not found in the tensorclass.

Returns:

value stored with the input key

classmethod load(prefix: str | pathlib.Path, *args, **kwargs) → Any

Loads a tensordict from disk.

This class method is a proxy to load_memmap().

load_(prefix: str | pathlib.Path, *args, **kwargs)

Loads a tensordict from disk within the current tensordict.

This method is a proxy to load_memmap_().

classmethod load_memmap(prefix: str | pathlib.Path, device: Optional[device] = None, non_blocking: bool = False, *, out: Optional[TensorDictBase] = None, robust_key: bool | None = True) → Any

Loads a memory-mapped tensordict from disk.

Parameters:
  • prefix (str or Path to folder) – the path to the folder from which the saved tensordict should be loaded.

  • device (torch.device or equivalent, optional) – if provided, the data will be asynchronously cast to that device. Supports “meta” device, in which case the data isn’t loaded but a set of empty “meta” tensors are created. This is useful to get a sense of the total model size and structure without actually opening any file.

  • non_blocking (bool, optional) – if True, synchronize won’t be called after loading tensors on device. Defaults to False.

  • out (TensorDictBase, optional) – optional tensordict where the data should be written.

  • robust_key (bool, optional) – if True (default), expects robust key encoding was used when saving and decodes filenames accordingly. If False, uses legacy behavior. If None, uses the default robust behavior.

Examples

>>> from tensordict import TensorDict
>>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0)
>>> td.memmap("./saved_td")
>>> td_load = TensorDict.load_memmap("./saved_td")
>>> assert (td == td_load).all()

This method also allows loading nested tensordicts.

Examples

>>> nested = TensorDict.load_memmap("./saved_td/nested")
>>> assert nested["e"] == 0

A tensordict can also be loaded on “meta” device or, alternatively, as a fake tensor.

Examples

>>> import tempfile
>>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}})
>>> with tempfile.TemporaryDirectory() as path:
...     td.save(path)
...     td_load = TensorDict.load_memmap(path, device="meta")
...     print("meta:", td_load)
...     from torch._subclasses import FakeTensorMode
...     with FakeTensorMode():
...         td_load = TensorDict.load_memmap(path)
...         print("fake:", td_load)
meta: TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=meta,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=meta,
    is_shared=False)
fake: TensorDict(
    fields={
        a: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

load_state_dict(state_dict: dict[str, Any], strict=True, assign=False, from_flatten=None)

Loads a state_dict into the tensorclass.

Supports both the new format (logical keys with _metadata) and the legacy format (_tensordict/_non_tensordict wrapper keys).

memmap(prefix: Optional[str] = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True, robust_key: bool | None = True) → Any

Writes all tensors onto a corresponding memory-mapped Tensor in a new tensordict.

Parameters:
  • prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.

  • copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.

Keyword Arguments:
  • num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.

  • return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict.

  • share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes, and writing operations (such as in-place updates or set) performed by any worker within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to False.

  • existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.

  • robust_key (bool, optional) – if True (default), uses robust key encoding that safely handles keys with path separators and special characters. If False, uses legacy behavior (keys used as-is). If None, uses the default robust behavior.

The TensorDict is then locked, meaning that any writing operation that isn’t in-place will throw an exception (e.g., renaming, setting or removing an entry). Once the tensordict is unlocked, the memory-mapped attribute is set to False, because cross-process identity is no longer guaranteed.

Returns:

A new tensordict with the tensors stored on disk if return_early=False, otherwise a TensorDictFuture instance.

Note

Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.

memmap_(prefix: Optional[str] = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True, robust_key: bool | None = True) → Any

Writes all tensors onto a corresponding memory-mapped Tensor, in-place.

Parameters:
  • prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.

  • copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.

Keyword Arguments:
  • num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.

  • return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict. The resulting tensordict can be queried using future.result().

  • share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes, and writing operations (such as in-place updates or set) performed by any worker within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to False.

  • existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.

  • robust_key (bool, optional) – if True (default), uses robust key encoding that safely handles keys with path separators and special characters. If False, uses legacy behavior (keys used as-is). If None, uses the default robust behavior.

The TensorDict is then locked, meaning that any writing operation that isn’t in-place will throw an exception (e.g., renaming, setting or removing an entry). Once the tensordict is unlocked, the memory-mapped attribute is set to False, because cross-process identity is no longer guaranteed.

Returns:

self if return_early=False, otherwise a TensorDictFuture instance.

Note

Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.

memmap_like(prefix: Optional[str] = None, copy_existing: bool = False, *, existsok: bool = True, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, robust_key: bool | None = True) → Any

Creates a contentless memory-mapped tensordict with the same shapes as the original one.

Parameters:
  • prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.

  • copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.

Keyword Arguments:
  • num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.

  • return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict.

  • share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes, and writing operations (such as in-place updates or set) performed by any worker within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to False.

  • existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.

  • robust_key (bool, optional) – if True (default), uses robust key encoding that safely handles keys with path separators and special characters. If False, uses legacy behavior (keys used as-is). If None, uses the default robust behavior.

The TensorDict is then locked, meaning that any writing operation that isn’t in-place will throw an exception (e.g., renaming, setting or removing an entry). Once the tensordict is unlocked, the memory-mapped attribute is set to False, because cross-process identity is no longer guaranteed.

Returns:

A new TensorDict instance with data stored as memory-mapped tensors if return_early=False, otherwise a TensorDictFuture instance.

Note

This is the recommended method to write a set of large buffers on disk, as memmap_() will copy the information, which can be slow for large content.

Examples

>>> td = TensorDict({
...     "a": torch.zeros((3, 64, 64), dtype=torch.uint8),
...     "b": torch.zeros(1, dtype=torch.int64),
... }, batch_size=[]).expand(1_000_000)  # expand does not allocate new memory
>>> buffer = td.memmap_like("/path/to/dataset")

memmap_refresh_()

Refreshes the content of the memory-mapped tensordict if it has a saved_path.

This method will raise an exception if no path is associated with it.

save(prefix: Optional[str] = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, robust_key: bool | None = True) → Any

Saves the tensordict to disk.

This function is a proxy to memmap().

set(key: NestedKey, value: Any, inplace: bool = False, non_blocking: bool = False)

Sets a new key-value pair.

Parameters:
  • key (str, tuple of str) – name of the key to be set. If tuple of str it is equivalent to chained calls of getattr followed by a final setattr.

  • value (Any) – value to be stored in the tensorclass

  • inplace (bool, optional) – if True, set will tentatively try to update the value in-place. If False or if the key isn’t present, the value will simply be written at its destination. Defaults to False.

Returns:

self

state_dict(destination=None, prefix='', keep_vars=False, flatten=True) → dict[str, Any]

Returns a state_dict with logical keys, matching TensorDictBase conventions.

Tensor fields appear as data keys. Non-tensor fields (strings, ints, etc.) and the tensorclass type are stored in _metadata. This replaces the legacy _tensordict/_non_tensordict wrapper format.

to_tensordict(*, retain_none: Optional[bool] = None) → TensorDict

Convert the tensorclass into a regular TensorDict.

Makes a copy of all entries. Memmap and shared memory tensors are converted to regular tensors.

Parameters:

retain_none (bool, optional) – if True, None values will be written in the tensordict. Otherwise they will be discarded. Default: True.

Returns:

A new TensorDict object containing the same values as the tensorclass.

unbind(dim: int)

Returns a tuple of indexed tensorclass instances unbound along the indicated dimension.

Resulting tensorclass instances will share the storage of the initial tensorclass instance.
