
Cross-class compatibility

TensorDict has several typed container wrappers and storage backends that can be composed together. This page documents which combinations work, which have caveats, and how to build these “chimera” objects.

Architecture overview

TensorCollection
├── TensorDictBase
│   ├── TensorDict                 (in-memory)
│   ├── TypedTensorDict            (typed fields, wraps any TensorDictBase)
│   ├── PersistentTensorDict       (HDF5-backed)
│   ├── TensorDictStore            (Redis / Dragonfly / KeyDB)
│   └── LazyStackedTensorDict      (lazy stack of heterogeneous TDs)
│
└── TensorClass                    (typed wrapper, HAS-A TensorDictBase)

Two patterns exist for adding typed field declarations:

  • TensorClass wraps any TensorDictBase via from_tensordict(td). It delegates all storage to the wrapped object.

  • TypedTensorDict wraps any TensorDictBase via from_tensordict(td), similar to TensorClass. Direct construction creates a TensorDict internally. Unlike TensorClass, it inherits from TensorDictBase directly, supports **state spreading natively, and uses standard Python inheritance for schema composition.
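The practical difference between the two patterns shows up when spreading fields into a function call. Below is a minimal, library-free sketch of the HAS-A vs. IS-A distinction using plain dicts as stand-ins for the backend; the class and function names are illustrative, not part of the tensordict API:

```python
from collections.abc import MutableMapping


def step(x=None, y=None):
    # Hypothetical consumer that accepts fields as keyword arguments.
    return (x, y)


class HasAWrapper:
    """HAS-A pattern (TensorClass-like): holds a backend, delegates storage."""

    def __init__(self, backend):
        self._backend = backend

    def to_dict(self):
        # Field-by-field repacking is required before ** spreading.
        return dict(self._backend)


class IsAMapping(MutableMapping):
    """IS-A pattern (TypedTensorDict-like): is itself a mapping."""

    def __init__(self, backend):
        self._source = backend

    def __getitem__(self, key):
        return self._source[key]

    def __setitem__(self, key, value):
        self._source[key] = value

    def __delitem__(self, key):
        del self._source[key]

    def __iter__(self):
        return iter(self._source)

    def __len__(self):
        return len(self._source)


backend = {"x": 1, "y": 2}
wrapped = HasAWrapper(backend)
typed = IsAMapping(backend)

assert step(**wrapped.to_dict()) == (1, 2)  # repack first, then spread
assert step(**typed) == (1, 2)              # spreads natively as a mapping
```

Because `IsAMapping` satisfies the mapping protocol directly, `**` spreading needs no intermediate copy, which is the property the TypedTensorDict pattern is described as providing natively.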

TensorClass + backends

TensorClass.from_tensordict(td) accepts any TensorDictBase subclass. The table below summarises which operations work on each combination.

from tensordict import TensorClass
from torch import Tensor

class MyTC(TensorClass):
    a: Tensor
    b: Tensor

tc = MyTC.from_tensordict(some_backend)

| Backend                   | Build | Read | Write  | Index | Clone | Stack | Iter | Update    |
|---------------------------|-------|------|--------|-------|-------|-------|------|-----------|
| TensorDict                | yes   | yes  | yes    | yes   | yes   | yes   | yes  | yes       |
| PersistentTensorDict (H5) | yes   | yes  | yes    | yes   | yes   | yes   | yes  | yes       |
| TensorDictStore (Redis)   | yes   | yes  | yes    | yes   | yes   | yes   | yes  | yes       |
| LazyStackedTensorDict     | yes   | yes  | yes    | yes   | yes   | yes   | yes  | yes       |
| TensorDict (memmap)       | yes   | yes  | set_() | yes   | yes   | yes   | yes  | update_() |
| TypedTensorDict           | yes   | yes  | yes    | yes   | yes   | yes   | yes  | yes       |

Note

Memory-mapped TensorDicts are locked after memmap_(). Use set_() and update_() for in-place writes instead of attribute assignment or update().

Building a TensorClass on each backend

In-memory TensorDict – the default and simplest case:

>>> import torch
>>> from tensordict import TensorDict, TensorClass
>>> from torch import Tensor
>>>
>>> class MyTC(TensorClass):
...     a: Tensor
...     b: Tensor
>>>
>>> td = TensorDict(a=torch.randn(4, 3), b=torch.randn(4, 5), batch_size=[4])
>>> tc = MyTC.from_tensordict(td)
>>> tc.a.shape
torch.Size([4, 3])

HDF5 (PersistentTensorDict):

>>> from tensordict import PersistentTensorDict
>>>
>>> h5 = PersistentTensorDict.from_dict(td, filename="data.h5")
>>> tc_h5 = MyTC.from_tensordict(h5)
>>> tc_h5.a.shape  # reads from HDF5
torch.Size([4, 3])

Redis (TensorDictStore):

>>> from tensordict.store import TensorDictStore
>>>
>>> store = TensorDictStore.from_tensordict(td, host="localhost")
>>> tc_redis = MyTC.from_tensordict(store)
>>> tc_redis.a.shape  # fetched from Redis
torch.Size([4, 3])

Lazy stack:

>>> from tensordict import lazy_stack
>>>
>>> tds = [TensorDict(a=torch.randn(3), b=torch.randn(5)) for _ in range(4)]
>>> ls = lazy_stack(tds, dim=0)
>>> tc_lazy = MyTC.from_tensordict(ls)
>>> tc_lazy[0].a.shape
torch.Size([3])

Memory-mapped TensorDict:

>>> td_mmap = td.memmap_("/tmp/my_memmap")
>>> tc_mmap = MyTC.from_tensordict(td_mmap)
>>> tc_mmap.a.shape
torch.Size([4, 3])
>>> # memmap TDs are locked -- use in-place operations:
>>> tc_mmap.set_("a", torch.ones(4, 3))

TypedTensorDict as backend – both TensorClass and TypedTensorDict enforce schemas, but they compose without conflict:

>>> from tensordict import TypedTensorDict
>>>
>>> class MyTTD(TypedTensorDict):
...     a: Tensor
...     b: Tensor
>>>
>>> ttd = MyTTD(a=torch.randn(4, 3), b=torch.randn(4, 5), batch_size=[4])
>>> tc_typed = MyTC.from_tensordict(ttd)
>>> tc_typed.a.shape
torch.Size([4, 3])

TypedTensorDict + backends

TypedTensorDict.from_tensordict(td) accepts any TensorDictBase subclass, just like TensorClass. The backend is stored live (no copy) – mutations through the TypedTensorDict go directly to the underlying backend.

from tensordict import TypedTensorDict
from torch import Tensor

class State(TypedTensorDict):
    x: Tensor
    y: Tensor

state = State.from_tensordict(some_backend)

| Backend                   | Build | Read | Write  | Index | Clone | Stack | Iter | Update    |
|---------------------------|-------|------|--------|-------|-------|-------|------|-----------|
| TensorDict                | yes   | yes  | yes    | yes   | yes   | yes   | yes  | yes       |
| PersistentTensorDict (H5) | yes   | yes  | yes    | yes   | yes   | yes   | yes  | yes       |
| TensorDictStore (Redis)   | yes   | yes  | yes    | yes   | yes   | yes   | yes  | yes       |
| LazyStackedTensorDict     | yes   | yes  | yes    | yes   | yes   | yes   | yes  | yes       |
| TensorDict (memmap)       | yes   | yes  | set_() | yes   | yes   | yes   | yes  | update_() |

Note

Memory-mapped TensorDicts are locked after memmap_(). Use set_() and update_() for in-place writes instead of attribute assignment or update().

Building a TypedTensorDict on each backend

In-memory TensorDict – the default (direct construction creates one internally):

>>> import torch
>>> from tensordict import TensorDict, TypedTensorDict
>>> from torch import Tensor
>>>
>>> class State(TypedTensorDict):
...     x: Tensor
...     y: Tensor
>>>
>>> state = State(x=torch.randn(4, 3), y=torch.randn(4, 5), batch_size=[4])
>>> state.x.shape
torch.Size([4, 3])

Wrapping an existing TensorDict via from_tensordict (zero-copy):

>>> td = TensorDict(x=torch.randn(4, 3), y=torch.randn(4, 5), batch_size=[4])
>>> state = State.from_tensordict(td)
>>> state.x.shape  # reads from td
torch.Size([4, 3])
>>> state.x = torch.ones(4, 3)  # writes to td
>>> (td["x"] == 1).all()
True

HDF5 (PersistentTensorDict):

>>> from tensordict import PersistentTensorDict
>>>
>>> h5 = PersistentTensorDict.from_h5("data.h5")
>>> state = State.from_tensordict(h5)
>>> state.x.shape  # reads from HDF5
torch.Size([4, 3])

Redis (TensorDictStore):

>>> from tensordict.store import TensorDictStore
>>>
>>> store = TensorDictStore.from_tensordict(td, host="localhost")
>>> state = State.from_tensordict(store)
>>> state.x.shape  # fetched from Redis
torch.Size([4, 3])

Lazy stack:

>>> from tensordict import lazy_stack
>>>
>>> tds = [TensorDict(x=torch.randn(3), y=torch.randn(5)) for _ in range(4)]
>>> ls = lazy_stack(tds, dim=0)
>>> state = State.from_tensordict(ls)
>>> state[0].x.shape
torch.Size([3])

Memory-mapped TensorDict:

>>> td_mmap = td.memmap_("/tmp/my_memmap")
>>> state = State.from_tensordict(td_mmap)
>>> state.x.shape
torch.Size([4, 3])
>>> # memmap TDs are locked -- use in-place operations:
>>> state.set_("x", torch.ones(4, 3))

Stacking TypedTensorDicts

Dense stacking with torch.stack preserves the TypedTensorDict subclass type:

>>> s1 = State(x=torch.randn(3), y=torch.randn(3), batch_size=[3])
>>> s2 = State(x=torch.randn(3), y=torch.randn(3), batch_size=[3])
>>> stacked = torch.stack([s1, s2], dim=0)
>>> stacked.batch_size
torch.Size([2, 3])

Lazy stacking also works. Indexing a LazyStackedTensorDict of TypedTensorDict instances preserves the subclass type:

>>> from tensordict import LazyStackedTensorDict
>>>
>>> ls = LazyStackedTensorDict(s1, s2, stack_dim=0)
>>> isinstance(ls[0], State)
True

Pre-allocating on Redis and filling iteratively

A common pattern for shared replay buffers or distributed data stores is to pre-allocate storage on a remote server (Redis / Dragonfly / KeyDB) and fill it one sample at a time, without ever loading the full dataset into RAM.

TensorDictStore.from_schema creates keys with known shapes and dtypes directly on the server using SETRANGE (zero-filled by the server; no tensor data passes through Python):

>>> import torch
>>> from tensordict import TensorDict, TypedTensorDict
>>> from tensordict.store import TensorDictStore
>>> from torch import Tensor
>>>
>>> class Replay(TypedTensorDict):
...     obs: Tensor
...     action: Tensor
...     reward: Tensor
>>>
>>> # Pre-allocate 100k entries directly on Redis -- no RAM used
>>> store = TensorDictStore.from_schema(
...     {"obs": ([84, 84, 3], torch.uint8),
...      "action": ([4], torch.float32),
...      "reward": ([], torch.float32)},
...     batch_size=[100_000],
...     host="redis-node",
... )
>>>
>>> # Wrap with typed access
>>> replay = Replay.from_tensordict(store)
>>>
>>> # Fill iteratively -- each write goes directly to Redis
>>> for i, sample in enumerate(data_stream):
...     replay[i] = Replay(
...         obs=sample.obs, action=sample.action, reward=sample.reward,
...         batch_size=[],
...     )
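To see why SETRANGE-based pre-allocation never loads the dataset into Python, the byte arithmetic can be sketched as follows. This illustrates the assumed batch-first, row-major layout of each key's buffer; it is not the store's actual wire format:

```python
import math


def sample_range(index, shape, itemsize):
    """Byte range occupied by sample `index` in a flat, row-major buffer.

    Assumes each key's server-side value is one contiguous buffer laid out
    batch-first, so writing sample i with SETRANGE touches only its own
    byte range and the server zero-fills everything else up front.
    """
    nbytes = math.prod(shape) * itemsize
    start = index * nbytes
    return start, start + nbytes


# An 84x84x3 uint8 observation occupies 84 * 84 * 3 = 21168 bytes per sample,
# so sample 2 lives at bytes [42336, 63504) of the "obs" buffer.
start, end = sample_range(2, (84, 84, 3), itemsize=1)
assert (start, end) == (2 * 21168, 3 * 21168)
```

Under this layout, filling entry `i` of a pre-allocated store is a constant-size write, independent of the total batch size.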

If the store is initially empty (no keys registered yet), use check=False to skip the key-presence validation and fill keys on the fly:

>>> store = TensorDictStore(batch_size=[100_000], host="redis-node")
>>> replay = Replay.from_tensordict(store, check=False)
>>>
>>> # First indexed write auto-creates each key via SETRANGE
>>> replay[0] = Replay(obs=obs_0, action=act_0, reward=r_0, batch_size=[])
>>> # Subsequent writes fill in the pre-allocated storage
>>> replay[1] = Replay(obs=obs_1, action=act_1, reward=r_1, batch_size=[])

TensorClass vs TypedTensorDict

Both enforce typed schemas and can wrap any TensorDictBase backend, but they differ architecturally:

| Aspect                          | TensorClass                                | TypedTensorDict                              |
|---------------------------------|--------------------------------------------|----------------------------------------------|
| Relationship to TensorDictBase  | Wraps a TensorDictBase (HAS-A via TensorCollection) | Is a TensorDictBase (IS-A, delegates to _source) |
| Can wrap non-TensorDict backends | Yes (H5, Redis, lazy stack, etc.)         | Yes (H5, Redis, lazy stack, etc.)            |
| **state spreading               | Requires field-by-field repacking          | Supported natively (MutableMapping)          |
| Non-tensor fields               | Supported                                  | Not supported (tensor-only)                  |
| Backend stays live              | Yes (writes go to original backend)        | Yes (writes go to original backend)          |
| Python inheritance              | Not supported                              | Supported (standard class hierarchy)         |
| Composable with each other      | Yes (TC.from_tensordict(ttd) works)        | Yes (TTD.from_tensordict(tc._tensordict) works) |

Both wrappers keep the backend live – mutations through the typed wrapper go directly to the underlying storage. Direct construction (without from_tensordict) creates an in-memory TensorDict as the backend.

Choose TensorClass when you need non-tensor fields or want to integrate with existing tensorclass-based APIs. Choose TypedTensorDict when you want native **state spreading, standard Python inheritance for schema composition, and full TensorDictBase API compatibility.
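Schema composition via inheritance follows ordinary Python rules: a subclass's effective schema is the union of the annotations collected along its MRO, with the closest class winning on conflicts. A minimal, library-free sketch of that merge (the class names and plain `int`/`float`/`str` field types are illustrative stand-ins, not part of the tensordict API):

```python
import typing


class Base:
    x: int
    y: float


class Extended(Base):
    z: str  # adds a field; x and y are inherited from Base


def schema(cls):
    # typing.get_type_hints walks the MRO and merges annotations,
    # which is exactly the "standard class hierarchy" composition rule.
    return typing.get_type_hints(cls)


assert set(schema(Extended)) == {"x", "y", "z"}
```

This is why inheritance-based schema composition needs no special machinery in the IS-A pattern: the field declarations are class annotations, and Python already merges those for subclasses.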
