TypedTensorDict¶
TypedTensorDict is a TensorDictBase subclass
with typed field declarations and backend composition. It brings TypedDict-style
class definitions to TensorDict: you declare fields as class annotations and get
typed construction, typed attribute access, inheritance, NotRequired fields,
**state spreading, and the ability to wrap any TensorDictBase backend
(H5, Redis, lazy stacks, etc.) via from_tensordict.
>>> import torch
>>> from tensordict import TypedTensorDict
>>> from torch import Tensor
>>>
>>> class PredictorState(TypedTensorDict):
...     eta: Tensor
...     X: Tensor
...     beta: Tensor
>>>
>>> state = PredictorState(
...     eta=torch.randn(5, 3),
...     X=torch.randn(5, 4),
...     beta=torch.randn(5, 1),
...     batch_size=[5],
... )
>>> state.eta.shape
torch.Size([5, 3])
>>> state["X"].shape
torch.Size([5, 4])
Why TypedTensorDict?¶
Typed pipelines often build up state one step at a time:
class PredictorState(TypedTensorDict):
    eta: Tensor
    X: Tensor
    beta: Tensor

class ObservedState(PredictorState):
    # Inherits eta, X and beta; adds two new fields.
    y: Tensor
    mu: Tensor

def gaussian(state: PredictorState, std: float) -> ObservedState:
    eta = state.eta
    y = eta + torch.randn_like(eta) * std
    # **state forwards every existing field; only the new ones are spelled out.
    return ObservedState(**state, y=y, mu=eta, batch_size=state.batch_size)
Each stage inherits the previous one’s fields and adds new ones. The
**state spreading pattern lets transition functions stay short regardless
of how many fields the state has. And because TypedTensorDict inherits
from TensorDictBase, every operation – .to(device), .clone(),
slicing, torch.stack, memmap – works at every stage.
TypedTensorDict vs TensorClass¶
Both TypedTensorDict and TensorClass provide typed tensor containers.
They share the same class-option syntax (["shadow"], ["frozen"], etc.)
and both use @dataclass_transform() for IDE support. The key difference is
in the underlying model:
| Feature | TypedTensorDict | TensorClass |
|---|---|---|
| Inherits from | TensorDictBase | … |
| Can wrap any backend | Yes (from_tensordict) | Yes (…) |
| Inheritance | Standard Python (…) | Supported via metaclass |
| **state spreading | Works natively (MutableMapping) | Requires manual field-by-field repacking |
| … | Works natively (…) | Raises … |
| NotRequired fields | Supported | Not supported |
| Non-tensor fields | Not supported (tensor-only) | Supported (strings, ints, arbitrary objects) |
| Custom methods | Supported (regular class methods) | Supported (regular class methods) |
| … | Not needed (uses metaclass via inheritance) | Required (or …) |
When to use which:
- Use TypedTensorDict when you have a typed pipeline with progressive state accumulation, need **state spreading, want standard Python inheritance for schema composition, or need to wrap persistent backends while keeping full TensorDictBase API compatibility.
- Use TensorClass when you need non-tensor fields (strings, metadata), custom __init__ logic, or your codebase already uses @tensorclass extensively.
Inheritance and field accumulation¶
Fields accumulate through the MRO. Each subclass adds its own fields while inheriting all parent fields:
>>> from typing import NotRequired
>>>
>>> class PredictorState(TypedTensorDict):
...     eta: Tensor
...     X: Tensor
...     beta: Tensor
>>>
>>> class ObservedState(PredictorState):
...     y: Tensor
...     mu: Tensor
...     noise: NotRequired[Tensor]
>>>
>>> class SurvivalState(ObservedState):
...     event_time: Tensor
...     indicator: Tensor
...     observed_time: Tensor
>>> ObservedState.__required_keys__
frozenset({'eta', 'X', 'beta', 'y', 'mu'})
>>> ObservedState.__optional_keys__
frozenset({'noise'})
Inheritance follows standard Python semantics: isinstance(obs, PredictorState)
returns True for an ObservedState instance, and a function typed as
f(state: PredictorState) accepts any subclass.
NotRequired fields¶
Mark fields as optional with NotRequired:
>>> from typing import NotRequired
>>>
>>> class ObservedState(PredictorState):
...     y: Tensor
...     mu: Tensor
...     noise: NotRequired[Tensor]
>>> obs = ObservedState(
...     eta=torch.randn(5, 3), X=torch.randn(5, 4), beta=torch.randn(5, 1),
...     y=torch.randn(5, 3), mu=torch.randn(5, 3),
...     batch_size=[5],
... )
>>> "noise" in obs
False
If a NotRequired field is not provided, it is simply absent from the
underlying TensorDict. Accessing it via attribute raises
AttributeError.
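Code that consumes an optional field can guard against its absence with a membership test or with getattr and a default. The sketch below illustrates both guards on a hypothetical stand-in class, `FieldView`, that forwards attribute access to a plain dict and raises AttributeError for absent keys, which is the behavior described above. It is an assumption-laden illustration, not TypedTensorDict's implementation.

```python
class FieldView:
    """Hypothetical stand-in: attribute access forwards to a dict of fields."""

    def __init__(self, **fields):
        self._fields = dict(fields)

    def __contains__(self, key):
        return key in self._fields

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        try:
            return self.__dict__["_fields"][name]
        except KeyError:
            raise AttributeError(name) from None


obs = FieldView(y=1.0, mu=0.5)  # "noise" was never provided

has_noise = "noise" in obs           # membership test
noise = getattr(obs, "noise", None)  # default instead of AttributeError
```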
Spreading (**state)¶
Because TypedTensorDict is a MutableMapping, the ** operator
unpacks it into keyword arguments. This makes state transitions concise:
>>> state = PredictorState(
...     eta=torch.randn(5, 3), X=torch.randn(5, 4), beta=torch.randn(5, 1),
...     batch_size=[5],
... )
>>> obs = ObservedState(
...     **state,
...     y=torch.randn(5, 3),
...     mu=torch.randn(5, 3),
...     batch_size=state.batch_size,
... )
>>> set(obs.keys()) == {"eta", "X", "beta", "y", "mu"}
True
Adding a new field to a pipeline stage is one line in the class definition – no transition function needs updating.
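The reason ** works is a general Python rule: any object that provides keys() and __getitem__ can be unpacked into keyword arguments, and MutableMapping supplies keys() as a mixin method. The sketch below shows this with a hypothetical dict-backed `State` class and a hypothetical `observe` function. Both are stand-ins for illustration and do not use tensordict.

```python
from collections.abc import MutableMapping


class State(MutableMapping):
    """Hypothetical dict-backed mapping standing in for a typed state."""

    def __init__(self, **fields):
        self._fields = dict(fields)

    def __getitem__(self, key):
        return self._fields[key]

    def __setitem__(self, key, value):
        self._fields[key] = value

    def __delitem__(self, key):
        del self._fields[key]

    def __iter__(self):
        return iter(self._fields)

    def __len__(self):
        return len(self._fields)


def observe(**fields):
    # Receives the spread fields plus any extra keyword arguments.
    return State(**fields)


state = State(eta=1.0, X=2.0, beta=3.0)
obs = observe(**state, y=4.0, mu=1.0)
print(sorted(obs))  # ['X', 'beta', 'eta', 'mu', 'y']
```

Note that Python raises TypeError if a spread key is also passed explicitly, so a transition function can only add keys that the incoming state does not already contain.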
Class options¶
TypedTensorDict supports the same bracket-syntax options as TensorClass:
class MyModel(TypedTensorDict["shadow"]):
    data: Tensor  # "data" shadows TensorDict.data -- allowed

class Immutable(TypedTensorDict["frozen"]):
    x: Tensor  # locked after construction

class Combined(TypedTensorDict["shadow", "frozen"]):
    data: Tensor
- "shadow" – Allow field names that clash with TensorDictBase attributes. Without this, conflicting names raise AttributeError at class definition time.
- "frozen" – Lock the TensorDict after construction (read-only).
- "autocast" – Automatically cast assigned values.
- "nocast" – Disable type casting on assignment.
- "tensor_only" – Restrict fields to tensor types only.
Options propagate through inheritance: a subclass of a "frozen" class is
also frozen.
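One plausible way to get bracket-syntax options that propagate through inheritance is `__class_getitem__` returning a subclass that carries the option set as a class attribute. The sketch below illustrates that general technique only: `Configurable` and `_options` are hypothetical names, and nothing here is claimed to match how tensordict implements its options.

```python
class Configurable:
    """Hypothetical base class; not tensordict's implementation."""

    _options: frozenset = frozenset()

    def __class_getitem__(cls, options):
        # Configurable["a"] passes a str; Configurable["a", "b"] passes a tuple.
        if isinstance(options, str):
            options = (options,)
        merged = cls._options | frozenset(options)
        # Return a new subclass carrying the merged option set.
        return type(cls.__name__, (cls,), {"_options": merged})


class Immutable(Configurable["frozen"]):
    pass


class Child(Immutable):  # no brackets: options come from the parent
    pass


print(sorted(Child._options))  # ['frozen']
```

Because the option set lives on the class, ordinary attribute lookup is enough for subclasses to see it; no extra propagation step is needed in this sketch.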
Backend composition (from_tensordict)¶
TypedTensorDict can wrap any TensorDictBase backend via
from_tensordict(td). The backend is stored live (zero-copy); mutations
through the typed wrapper go directly to the underlying storage:
>>> from tensordict import TensorDict
>>>
>>> td = TensorDict(
...     eta=torch.randn(5, 3), X=torch.randn(5, 4), beta=torch.randn(5, 1),
...     batch_size=[5],
... )
>>> state = PredictorState.from_tensordict(td)
>>> state.eta.shape
torch.Size([5, 3])
>>> state.eta = torch.ones(5, 3) # writes to td
>>> (td["eta"] == 1).all()
True
This works with any backend: PersistentTensorDict (H5),
TensorDictStore (Redis), LazyStackedTensorDict, memory-mapped
TensorDict, etc. See the Cross-class compatibility page for full
details and examples.
Pass check=False to wrap an empty or partially-filled backend (e.g. a
pre-allocated TensorDictStore). Missing fields will raise at access
time rather than at wrap time:
>>> store = TensorDictStore(batch_size=[100_000], host="redis-node")
>>> state = PredictorState.from_tensordict(store, check=False)
See pre-allocating on Redis in the compatibility guide for the full workflow.
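The two behaviors described in this section, live (zero-copy) wrapping and deferred validation with check=False, can be sketched with a plain dict as the backend. `TypedView`, `from_mapping` and `required` below are hypothetical names chosen for illustration. The real from_tensordict operates on TensorDictBase backends and may differ in detail.

```python
class TypedView:
    """Hypothetical typed wrapper over a live mapping (no copy is made)."""

    required = ("eta", "X", "beta")

    def __init__(self, backend):
        self._backend = backend

    @classmethod
    def from_mapping(cls, backend, check=True):
        if check:
            missing = [k for k in cls.required if k not in backend]
            if missing:
                raise KeyError(f"missing fields: {missing}")
        return cls(backend)  # keeps a reference to the same object

    def __getitem__(self, key):
        return self._backend[key]  # raises KeyError here if still missing

    def __setitem__(self, key, value):
        self._backend[key] = value  # write goes straight to the backend


backend = {}  # empty store, standing in for a pre-allocated backend
view = TypedView.from_mapping(backend, check=False)
view["eta"] = [1.0, 1.0]  # fill through the wrapper
print(backend)  # {'eta': [1.0, 1.0]}
```

With check=True the same empty backend would be rejected at wrap time, which is the trade-off the check flag controls: fail early on a complete schema, or defer errors to first access while the backend is being filled.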
TensorDict operations¶
Every TensorDictBase operation works on TypedTensorDict instances:
>>> state = PredictorState(
...     eta=torch.randn(5, 3), X=torch.randn(5, 4), beta=torch.randn(5, 1),
...     batch_size=[5],
... )
>>> state.to("cpu").device
device(type='cpu')
>>> state.clone()["eta"].shape
torch.Size([5, 3])
>>> state[0:3].batch_size
torch.Size([3])
>>> torch.stack([state, state], dim=0).batch_size
torch.Size([2, 5])
This includes .memmap(), .apply(), torch.cat, torch.stack,
.unbind(), .select(), .exclude(), .update(), and all other
TensorDictBase methods.
Type checking¶
TypedTensorDict uses @dataclass_transform() (PEP 681) on its metaclass.
This means type checkers (pyright, mypy) understand:
- Constructor signatures – missing or extra fields are flagged.
- Attribute access – state.eta is typed as Tensor, and typos like state.etta produce errors.
- Inheritance – subclass fields include parent fields.
String-key access (state["eta"]) works at runtime but does not get type
narrowing without a dedicated type checker plugin. For typed access, prefer
dot notation (state.eta).