
TypedTensorDict

TypedTensorDict is a TensorDictBase subclass with typed field declarations and backend composition. It brings TypedDict-style class definitions to TensorDict: you declare fields as class annotations and get typed construction, typed attribute access, inheritance, NotRequired fields, and **state spreading. It can also wrap any TensorDictBase backend (H5, Redis, lazy stacks, etc.) via from_tensordict.

>>> import torch
>>> from tensordict import TypedTensorDict
>>> from torch import Tensor
>>>
>>> class PredictorState(TypedTensorDict):
...     eta: Tensor
...     X: Tensor
...     beta: Tensor
>>>
>>> state = PredictorState(
...     eta=torch.randn(5, 3),
...     X=torch.randn(5, 4),
...     beta=torch.randn(5, 1),
...     batch_size=[5],
... )
>>> state.eta.shape
torch.Size([5, 3])
>>> state["X"].shape
torch.Size([5, 4])

Why TypedTensorDict?

Typed pipelines often build up state one step at a time:

class PredictorState(TypedTensorDict):
    eta: Tensor
    X: Tensor
    beta: Tensor

class ObservedState(PredictorState):
    y: Tensor
    mu: Tensor

def gaussian(state: PredictorState, std: float) -> ObservedState:
    eta = state.eta
    y = eta + torch.randn_like(eta) * std
    return ObservedState(**state, y=y, mu=eta, batch_size=state.batch_size)

Each stage inherits the previous one’s fields and adds new ones. The **state spreading pattern lets transition functions stay short regardless of how many fields the state has. And because TypedTensorDict inherits from TensorDictBase, every operation – .to(device), .clone(), slicing, torch.stack, memmap – works at every stage.

TypedTensorDict vs TensorClass

Both TypedTensorDict and TensorClass provide typed tensor containers. They share the same class-option syntax (["shadow"], ["frozen"], etc.) and both use @dataclass_transform() for IDE support. The key difference is in the underlying model:

| Feature                 | TypedTensorDict                              | TensorClass                                   |
|-------------------------|----------------------------------------------|-----------------------------------------------|
| Inherits from           | TensorDictBase (delegates to _source)        | TensorCollection (wraps a TensorDict internally) |
| Can wrap any backend    | Yes (from_tensordict)                        | Yes (from_tensordict)                         |
| Inheritance             | Standard Python (class Child(Parent): ...)   | Supported via metaclass                       |
| **state spreading       | Works natively (MutableMapping)              | Requires manual field-by-field repacking      |
| state["key"]            | Works natively (TensorDictBase.__getitem__)  | Raises ValueError; use state.key or state.get("key") |
| NotRequired fields      | Supported                                    | Not supported                                 |
| Non-tensor fields       | Not supported (tensor-only)                  | Supported (strings, ints, arbitrary objects)  |
| Custom methods          | Supported (regular class methods)            | Supported (regular class methods)             |
| @tensorclass decorator  | Not needed (uses metaclass via inheritance)  | Required (or class Foo(TensorClass): ...)     |

When to use which:

  • Use TypedTensorDict when you have a typed pipeline with progressive state accumulation, need **state spreading, want standard Python inheritance for schema composition, or need to wrap persistent backends while keeping full TensorDictBase API compatibility.

  • Use TensorClass when you need non-tensor fields (strings, metadata), custom __init__ logic, or your codebase already uses @tensorclass extensively.

Inheritance and field accumulation

Fields accumulate through the MRO. Each subclass adds its own fields while inheriting all parent fields:

>>> from typing import NotRequired
>>>
>>> class PredictorState(TypedTensorDict):
...     eta: Tensor
...     X: Tensor
...     beta: Tensor
>>>
>>> class ObservedState(PredictorState):
...     y: Tensor
...     mu: Tensor
...     noise: NotRequired[Tensor]
>>>
>>> class SurvivalState(ObservedState):
...     event_time: Tensor
...     indicator: Tensor
...     observed_time: Tensor

>>> ObservedState.__required_keys__
frozenset({'eta', 'X', 'beta', 'y', 'mu'})
>>> ObservedState.__optional_keys__
frozenset({'noise'})

Inheritance works as standard Python: isinstance(obs, PredictorState) returns True for an ObservedState instance, and a function typed as f(state: PredictorState) accepts any subclass.

NotRequired fields

Mark fields as optional with NotRequired:

>>> from typing import NotRequired
>>>
>>> class ObservedState(PredictorState):
...     y: Tensor
...     mu: Tensor
...     noise: NotRequired[Tensor]

>>> obs = ObservedState(
...     eta=torch.randn(5, 3), X=torch.randn(5, 4), beta=torch.randn(5, 1),
...     y=torch.randn(5, 3), mu=torch.randn(5, 3),
...     batch_size=[5],
... )
>>> "noise" in obs
False

If a NotRequired field is not provided, it is simply absent from the underlying TensorDict. Accessing it via attribute raises AttributeError.

Spreading (**state)

Because TypedTensorDict is a MutableMapping, the ** operator unpacks it into keyword arguments. This makes state transitions concise:

>>> state = PredictorState(
...     eta=torch.randn(5, 3), X=torch.randn(5, 4), beta=torch.randn(5, 1),
...     batch_size=[5],
... )
>>> obs = ObservedState(
...     **state,
...     y=torch.randn(5, 3),
...     mu=torch.randn(5, 3),
...     batch_size=state.batch_size,
... )
>>> set(obs.keys()) == {"eta", "X", "beta", "y", "mu"}
True

Adding a new field to a pipeline stage is one line in the class definition – no transition function needs updating.
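This is ordinary Python mapping behavior rather than anything tensordict-specific: ** unpacks any mapping that implements keys() and __getitem__. A minimal sketch with a hypothetical Bag class:

```python
from collections.abc import MutableMapping

class Bag(MutableMapping):
    """Minimal MutableMapping; keys()/__getitem__ (supplied by the
    ABC mixins) are all the ** operator needs."""
    def __init__(self, **data):
        self._d = dict(data)
    def __getitem__(self, key):
        return self._d[key]
    def __setitem__(self, key, value):
        self._d[key] = value
    def __delitem__(self, key):
        del self._d[key]
    def __iter__(self):
        return iter(self._d)
    def __len__(self):
        return len(self._d)

def transition(**kwargs):
    return sorted(kwargs)

state = Bag(eta=1, X=2, beta=3)
keys = transition(**state, y=4)  # spread old fields, add a new one
print(keys)  # ['X', 'beta', 'eta', 'y']
```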

Class options

TypedTensorDict supports the same bracket-syntax options as TensorClass:

class MyModel(TypedTensorDict["shadow"]):
    data: Tensor      # "data" shadows TensorDict.data -- allowed

class Immutable(TypedTensorDict["frozen"]):
    x: Tensor         # locked after construction

class Combined(TypedTensorDict["shadow", "frozen"]):
    data: Tensor

  • "shadow" – Allow field names that clash with TensorDictBase attributes. Without this, conflicting names raise AttributeError at class definition time.

  • "frozen" – Lock the TensorDict after construction (read-only).

  • "autocast" – Automatically cast assigned values.

  • "nocast" – Disable type casting on assignment.

  • "tensor_only" – Restrict fields to tensor types only.

Options propagate through inheritance: a subclass of a "frozen" class is also frozen.
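One way bracket-syntax options and their propagation can be implemented in plain Python is via __class_getitem__; this is a hypothetical sketch, not the actual tensordict implementation:

```python
class Configurable:
    _options: frozenset = frozenset()

    def __class_getitem__(cls, opts):
        # Accept both Configurable["frozen"] and Configurable["a", "b"]
        opts = (opts,) if isinstance(opts, str) else opts
        # Return a parameterized base class carrying the options;
        # subclasses inherit _options, so options propagate.
        return type(cls.__name__, (cls,), {"_options": cls._options | set(opts)})

class Frozen(Configurable["frozen"]):
    pass

class Child(Frozen):  # inherits "frozen" from its parent
    pass

print(sorted(Child._options))  # ['frozen']
```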

Backend composition (from_tensordict)

TypedTensorDict can wrap any TensorDictBase backend via from_tensordict(td). The backend is stored live (zero-copy); mutations through the typed wrapper go directly to the underlying storage:

>>> from tensordict import TensorDict
>>>
>>> td = TensorDict(
...     eta=torch.randn(5, 3), X=torch.randn(5, 4), beta=torch.randn(5, 1),
...     batch_size=[5],
... )
>>> state = PredictorState.from_tensordict(td)
>>> state.eta.shape
torch.Size([5, 3])
>>> state.eta = torch.ones(5, 3)  # writes to td
>>> (td["eta"] == 1).all()
True

This works with any backend: PersistentTensorDict (H5), TensorDictStore (Redis), LazyStackedTensorDict, memory-mapped TensorDict, etc. See the Cross-class compatibility page for full details and examples.
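The zero-copy behavior can be pictured with a tiny pure-Python stand-in (View is a hypothetical sketch; the real class delegates far more of the TensorDictBase API):

```python
class View:
    """Typed-wrapper sketch: stores the backend by reference, so
    attribute reads and writes go straight to the underlying store."""
    def __init__(self, backend):
        object.__setattr__(self, "_source", backend)

    def __getattr__(self, name):
        try:
            return self._source[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        self._source[name] = value  # write-through: no copy is made

backend = {"eta": 0.0}
view = View(backend)
view.eta = 1.0            # writes to backend
print(backend["eta"])     # 1.0
```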

Pass check=False to wrap an empty or partially filled backend (e.g. a pre-allocated TensorDictStore). Missing fields then raise at access time rather than at wrap time:

>>> store = TensorDictStore(batch_size=[100_000], host="redis-node")
>>> state = PredictorState.from_tensordict(store, check=False)

See pre-allocating on Redis in the compatibility guide for the full workflow.

TensorDict operations

Every TensorDictBase operation works on TypedTensorDict instances:

>>> state = PredictorState(
...     eta=torch.randn(5, 3), X=torch.randn(5, 4), beta=torch.randn(5, 1),
...     batch_size=[5],
... )
>>> state.to("cpu").device
device(type='cpu')
>>> state.clone()["eta"].shape
torch.Size([5, 3])
>>> state[0:3].batch_size
torch.Size([3])
>>> torch.stack([state, state], dim=0).batch_size
torch.Size([2, 5])

This includes .memmap(), .apply(), torch.cat, torch.stack, .unbind(), .select(), .exclude(), .update(), and all other TensorDictBase methods.

Type checking

TypedTensorDict uses @dataclass_transform() (PEP 681) on its metaclass. This means type checkers (pyright, mypy) understand:

  • Constructor signatures – missing or extra fields are flagged.

  • Attribute access – state.eta is typed as Tensor, and typos like state.etta produce errors.

  • Inheritance – subclass fields include parent fields.

String-key access (state["eta"]) works at runtime but does not get type narrowing without a dedicated type checker plugin. For typed access, prefer dot notation (state.eta).

TypedTensorDict(*args, **kwargs)

A TensorDictBase subclass with typed field declarations and backend composition.
