
TensorClass

class tensordict.TensorClass

TensorClass is the inheritance-based version of the @tensorclass decorator.

TensorClass lets you write dataclasses that are better type-checked and more Pythonic than those built with the @tensorclass decorator.

Examples

>>> from typing import Any
>>> import torch
>>> from tensordict import TensorClass
>>> class Foo(TensorClass):
...     tensor: torch.Tensor
...     non_tensor: Any
...     nested: Any = None
>>> foo = Foo(tensor=torch.randn(3), non_tensor="a string!", nested=None, batch_size=[3])
>>> print(foo)
Foo(
    non_tensor=NonTensorData(data=a string!, batch_size=torch.Size([3]), device=None),
    tensor=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
    nested=None,
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
Keyword Arguments:
  • batch_size (torch.Size, optional) – The batch size of the TensorDict. Defaults to None.

  • device (torch.device, optional) – The device on which the TensorDict will be created. Defaults to None.

  • frozen (bool, optional) – If True, the resulting class or instance will be immutable. Defaults to False.

  • autocast (bool, optional) – If True, enables automatic type casting for the resulting class or instance. Defaults to False.

  • nocast (bool, optional) – If True, disables any type casting for the resulting class or instance. Defaults to False.

  • tensor_only (bool, optional) – If True, all items in the tensorclass are expected to be tensors or tensor-compatible values (non-tensor data is converted to tensors when possible). This can bring significant speed-ups at the cost of flexibility when interacting with non-tensor data. Defaults to False.

  • shadow (bool, optional) – Disables the validation of field names against TensorDict’s reserved attributes. Use with caution, as this may cause unintended consequences. Defaults to False.
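The `shadow` flag is easiest to picture as a switch around a name check. The following is a minimal pure-Python sketch of such a check, not tensordict's implementation: the function `check_field_names`, the `RESERVED_NAMES` sample and the choice of `AttributeError` are all hypothetical.

```python
# Hypothetical sketch of the kind of validation that shadow=True disables.
# RESERVED_NAMES is a small illustrative sample, not tensordict's real list.
RESERVED_NAMES = frozenset({"batch_size", "device", "to", "apply", "keys"})


def check_field_names(field_names, shadow=False):
    """Raise if a field would hide a reserved attribute, unless shadow is True."""
    if shadow:
        return  # validation skipped: any clash is the caller's responsibility
    clashes = sorted(set(field_names) & RESERVED_NAMES)
    if clashes:
        raise AttributeError(f"field names clash with reserved attributes: {clashes}")


check_field_names(["obs", "reward"])        # passes: no clash
check_field_names(["device"], shadow=True)  # passes: validation disabled
try:
    check_field_names(["obs", "device"])    # fails: "device" is in the sample set
except AttributeError as err:
    print(err)  # field names clash with reserved attributes: ['device']
```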

You can pass the boolean keyword arguments ("autocast", "nocast", "frozen", "tensor_only", "shadow") in two ways: using brackets or keyword arguments.

Examples

>>> import torch
>>> from tensordict import TensorClass
>>> class Foo(TensorClass["autocast"]):
...     integer: int
>>> Foo(integer=torch.ones(())).integer
1
>>> class Foo(TensorClass, autocast=True):  # equivalent
...     integer: int
>>> Foo(integer=torch.ones(())).integer
1
>>> class Foo(TensorClass["nocast"]):
...     integer: int
>>> Foo(integer=1).integer
1
>>> class Foo(TensorClass["nocast", "frozen"]):  # multiple keywords can be used
...     integer: int
>>> Foo(integer=1).integer
1
>>> class Foo(TensorClass, nocast=True):  # equivalent
...     integer: int
>>> Foo(integer=1).integer
1
>>> class Foo(TensorClass):
...     integer: int
>>> Foo(integer=1).integer
tensor(1)
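One way the bracket and keyword forms can end up equivalent is sketched below in plain Python, assuming a base class that records options in `__class_getitem__` and applies them in `__init_subclass__`. The class `Base`, its `_defaults`, `_carrier` and `_options` attributes, and the use of `dataclasses.dataclass` as the decorator are hypothetical stand-ins; tensordict's own mechanism may differ.

```python
from dataclasses import dataclass

# Boolean options mirroring the keyword arguments listed above.
_FLAGS = ("autocast", "nocast", "frozen", "tensor_only", "shadow")


class Base:
    """Toy stand-in for TensorClass: each subclass becomes a dataclass."""

    _defaults = {}  # options collected from Base[...] (unannotated on purpose)

    def __class_getitem__(cls, items):
        # Accept Base["frozen"] as well as Base["nocast", "frozen"].
        if isinstance(items, str):
            items = (items,)
        unknown = set(items) - set(_FLAGS)
        if unknown:
            raise ValueError(f"unknown options: {sorted(unknown)}")
        opts = {**cls._defaults, **{name: True for name in items}}
        # An intermediate, undecorated class that only carries default options.
        return type(cls.__name__, (cls,), {"_defaults": opts, "_carrier": True})

    def __init_subclass__(cls, **kwargs):
        # Pull our flags out of the class keywords; pass the rest up the chain.
        flags = {name: kwargs.pop(name) for name in _FLAGS if name in kwargs}
        super().__init_subclass__(**kwargs)
        if cls.__dict__.get("_carrier"):
            return  # class built by Base[...]: nothing to decorate
        cls._options = {**cls._defaults, **flags}
        dataclass(frozen=cls._options.get("frozen", False))(cls)


class A(Base["nocast", "frozen"]):  # bracket form
    x: int


class B(Base, nocast=True, frozen=True):  # keyword form
    x: int


print(A._options == B._options)  # True
```

Leaving the intermediate class undecorated means the frozen choice is only made on the user's class, which sidesteps the kind of parent/child conflict described in the warning that follows.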

Warning

TensorClass itself is not decorated as a tensorclass, but its subclasses are. This is because we cannot anticipate whether the frozen argument will be set; if it is, it may conflict with the parent class (a subclass cannot be frozen if its parent class isn't).
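The same restriction exists in Python's standard `dataclasses` module, which makes it a useful model for the warning above; this snippet uses plain `dataclasses` only and does not touch tensordict. A frozen dataclass cannot inherit from a non-frozen one, whereas an undecorated base class accepts both frozen and mutable subclasses.

```python
from dataclasses import dataclass


@dataclass
class Parent:  # an ordinary, mutable dataclass
    x: int


try:
    @dataclass(frozen=True)
    class Child(Parent):  # frozen subclass of a mutable dataclass parent
        y: int
except TypeError as err:
    error_message = str(err)
else:
    error_message = None
print(error_message)  # cannot inherit frozen dataclass from a non-frozen one


class PlainBase:  # not a dataclass, so it imposes no frozen/mutable choice
    pass


@dataclass(frozen=True)
class FrozenChild(PlainBase):  # allowed
    y: int


@dataclass
class MutableChild(PlainBase):  # also allowed
    y: int
```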
