tensorclass
- class tensordict.tensorclass(cls: Optional[T] = None, /, *, autocast: bool = False, frozen: bool = False, nocast: bool = False, shadow: bool = False, tensor_only: bool = False)
- A decorator to create tensorclass classes. tensorclass classes are specialized dataclasses (as produced by dataclasses.dataclass()) whose instances can execute a set of pre-defined tensor operations out of the box, such as indexing, item assignment, reshaping, casting to a device or storage, and many others.
- Keyword Arguments:
- autocast (bool, optional) – if True, the annotated field types are enforced whenever an attribute is set. This argument is exclusive with nocast (both cannot be True at the same time). Defaults to False.
- frozen (bool, optional) – if True, the content of the tensorclass cannot be modified. This argument is provided for dataclass compatibility; similar behavior can be obtained through the lock argument of the class constructor. Defaults to False.
- nocast (bool, optional) – if True, tensor-compatible types such as int, np.ndarray and the like are not cast to a tensor type. This argument is exclusive with autocast (both cannot be True at the same time). Defaults to False.
- shadow (bool, optional) – disables the validation of field names against TensorDict's reserved attributes. Use with caution: a field whose name matches a reserved attribute can shadow it and cause unintended behavior. Defaults to False.
- tensor_only (bool, optional) – if True, all items in the tensorclass are expected to be tensors (or tensor-compatible, since non-tensor data is converted to tensors when possible). This can bring significant speed-ups at the cost of less flexible handling of non-tensor data. Defaults to False.
 
- tensorclass can be used with or without arguments:

Examples:

>>> import torch
>>> from typing import Any
>>> from tensordict import tensorclass
>>> @tensorclass
... class X:
...     y: int
>>> X(torch.ones(())).y
tensor(1.)
>>> @tensorclass(autocast=False)
... class X:
...     y: int
>>> X(torch.ones(())).y
tensor(1.)
>>> @tensorclass(autocast=True)
... class X:
...     y: int
>>> X(torch.ones(())).y
1
>>> @tensorclass(nocast=True)
... class X:
...     y: Any
>>> X(1).y
1
>>> @tensorclass(nocast=False)
... class X:
...     y: Any
>>> X(1).y
tensor(1)

Examples:

>>> from tensordict import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...     X: torch.Tensor
...     y: torch.Tensor
...     z: str
...     def expand_and_mask(self):
...         X = self.X.unsqueeze(-1).expand_as(self.y)
...         X = X[self.y]
...         return X
...
>>> data = MyData(
...     X=torch.ones(3, 4, 1),
...     y=torch.zeros(3, 4, 2, 2, dtype=torch.bool),
...     z="test",
...     batch_size=[3, 4])
>>> print(data)
MyData(
    X=Tensor(torch.Size([3, 4, 1]), dtype=torch.float32),
    y=Tensor(torch.Size([3, 4, 2, 2]), dtype=torch.bool),
    z="test",
    batch_size=[3, 4],
    device=None,
    is_shared=False)
>>> print(data.expand_and_mask())
tensor([])

- It is also possible to nest tensorclass instances within each other:
Examples:

>>> from tensordict import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class NestingMyData:
...     nested: MyData
...
>>> nesting_data = NestingMyData(nested=data, batch_size=[3, 4])
>>> # although the data is stored as a TensorDict, the type hint helps us
>>> # to appropriately cast the data to the right type
>>> assert isinstance(nesting_data.nested, type(data))