TensorSpec System
TensorSpec classes describe the shape, dtype, device, and domain of the tensors used in TorchRL.
Class | Description
Binary | A binary discrete tensor spec.
Bounded | A bounded tensor spec.
Categorical | A discrete tensor spec.
Composite | A composition of TensorSpecs.
MultiCategorical | A concatenation of discrete tensor specs.
MultiOneHot | A concatenation of one-hot discrete tensor specs.
NonTensor | A spec for non-tensor data.
OneHot | A unidimensional, one-hot discrete tensor spec.
Stacked | A lazy representation of a stack of tensor specs.
StackedComposite | A lazy representation of a stack of composite specs.
TensorSpec | Parent class of the tensor meta-data containers.
Unbounded | An unbounded tensor spec.
UnboundedContinuous | A specialized version of Unbounded with continuous space.
UnboundedDiscrete | A specialized version of Unbounded with discrete space.
Supported PyTorch Operations
TensorSpec classes support several PyTorch-like operations for manipulating their shape and structure. Each operation returns a new spec instance with the modified shape while preserving the dtype, device, and domain of the original spec.
PyTorch function overloads (via __torch_function__):
These can be called using the standard PyTorch functional API:
import torch
from torchrl.data import Bounded, Composite
# torch.stack - stack multiple specs along a new dimension
spec1 = Bounded(low=0, high=1, shape=(3, 4))
spec2 = Bounded(low=0, high=1, shape=(3, 4))
stacked = torch.stack([spec1, spec2], dim=0) # shape: (2, 3, 4)
# torch.squeeze / torch.unsqueeze - remove or add singleton dimensions
spec = Bounded(low=0, high=1, shape=(1, 3, 4))
squeezed = torch.squeeze(spec, dim=0) # shape: (3, 4)
unsqueezed = torch.unsqueeze(squeezed, dim=0) # shape: (1, 3, 4)
# torch.index_select - select indices along a dimension
spec = Bounded(low=0, high=1, shape=(5, 4))
selected = torch.index_select(spec, dim=0, index=torch.tensor([0, 2, 4])) # shape: (3, 4)
Instance methods:
TensorSpec also provides instance methods that mirror common tensor operations:
expand() - broadcast the spec to a larger shape
squeeze() - remove singleton dimensions
unsqueeze() - add a singleton dimension
reshape() - reshape the spec to a new shape
flatten() - flatten dimensions
unflatten() - unflatten a dimension into multiple dimensions
unbind() - split the spec along a dimension
from torchrl.data import Bounded
spec = Bounded(low=0, high=1, shape=(2, 3, 4))
# Reshape operations
reshaped = spec.reshape(6, 4) # shape: (6, 4)
flattened = spec.flatten(0, 1) # shape: (6, 4)
expanded = spec.expand(5, 2, 3, 4) # shape: (5, 2, 3, 4)
# Split operations
unbound = spec.unbind(dim=0) # tuple of 2 specs, each with shape (3, 4)
Note
Some operations are restricted for discrete specs such as OneHot, MultiOneHot,
and Binary, whose last dimension represents the domain and therefore cannot be modified.
For example, calling torch.index_select along the last dimension of a OneHot spec raises
a ValueError.