Cross-class compatibility¶
TensorDict has several typed container wrappers and storage backends that
can be composed together. This page documents which combinations work, which
have caveats, and how to build these “chimera” objects.
Architecture overview¶
TensorCollection
├── TensorDictBase
│ ├── TensorDict (in-memory)
│ ├── TypedTensorDict (typed fields, wraps any TensorDictBase)
│ ├── PersistentTensorDict (HDF5-backed)
│ ├── TensorDictStore (Redis / Dragonfly / KeyDB)
│ └── LazyStackedTensorDict (lazy stack of heterogeneous TDs)
│
└── TensorClass (typed wrapper, HAS-A TensorDictBase)
Two patterns exist for adding typed field declarations:

- `TensorClass` wraps any `TensorDictBase` via `from_tensordict(td)`. It delegates all storage to the wrapped object.
- `TypedTensorDict` wraps any `TensorDictBase` via `from_tensordict(td)`, similar to `TensorClass`. Direct construction creates a `TensorDict` internally. Unlike `TensorClass`, it inherits from `TensorDictBase` directly, supports `**state` spreading natively, and uses standard Python inheritance for schema composition.
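The HAS-A vs IS-A distinction can be pictured with a minimal plain-Python sketch (illustrative only; `Backend`, `HasAWrapper`, and `IsAWrapper` are hypothetical stand-ins, not library classes):

```python
# Plain-Python sketch (not the library's implementation) of the two
# composition patterns: HAS-A delegation vs IS-A inheritance.

class Backend(dict):
    """Stand-in for any TensorDictBase backend."""

class HasAWrapper:
    """TensorClass-style: HAS-A -- holds a backend and delegates storage."""
    def __init__(self, backend):
        self._backend = backend          # storage stays in the wrapped object

    @classmethod
    def from_backend(cls, backend):      # analogous to from_tensordict(td)
        return cls(backend)

    def __getattr__(self, name):
        return self._backend[name]       # attribute reads delegate to backend

class IsAWrapper(Backend):
    """TypedTensorDict-style: IS-A -- inherits the container API directly,
    so ** spreading and standard class inheritance work natively."""

backend = Backend(a=1, b=2)
has_a = HasAWrapper.from_backend(backend)   # wraps, no copy
is_a = IsAWrapper(a=1, b=2)                 # direct construction

spread = dict(**is_a)   # IS-A supports ** spreading out of the box
```

The HAS-A wrapper must delegate every read and write by hand, while the IS-A subclass gets the full container API for free; this mirrors the trade-off between `TensorClass` and `TypedTensorDict` described above.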
TensorClass + backends¶
TensorClass.from_tensordict(td) accepts any TensorDictBase subclass.
The table below summarises which operations work on each combination.
from tensordict import TensorClass
from torch import Tensor

class MyTC(TensorClass):
    a: Tensor
    b: Tensor

tc = MyTC.from_tensordict(some_backend)
| Backend | Build | Read | Write | Index | Clone | Stack | Iter | Update |
|---|---|---|---|---|---|---|---|---|
| TensorDict (in-memory) | yes | yes | yes | yes | yes | yes | yes | yes |
| PersistentTensorDict (HDF5) | yes | yes | yes | yes | yes | yes | yes | yes |
| TensorDictStore (Redis) | yes | yes | yes | yes | yes | yes | yes | yes |
| LazyStackedTensorDict | yes | yes | yes | yes | yes | yes | yes | yes |
| Memory-mapped TensorDict | yes | yes | set_() | yes | yes | yes | yes | update_() |
| TypedTensorDict | yes | yes | yes | yes | yes | yes | yes | yes |
Note
Memory-mapped TensorDicts are locked after memmap_(). Use
set_() and update_() for in-place writes instead of attribute
assignment or update().
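The locking behaviour can be pictured with a small plain-Python sketch (assumed semantics, not the actual TensorDictBase lock implementation): a locked container rejects rebinding writes but still allows writes into existing storage:

```python
# Hypothetical sketch of lock semantics: set() rebinds the stored object
# (forbidden when locked, since memory-mapped storage must keep its buffers),
# while set_() copies values into the existing buffer.

class LockedContainer:
    def __init__(self, data):
        self._data = dict(data)
        self._locked = False

    def lock_(self):
        self._locked = True

    def set(self, key, value):
        # Rebinding write: replaces the stored object.
        if self._locked:
            raise RuntimeError("container is locked; use set_() instead")
        self._data[key] = value

    def set_(self, key, value):
        # In-place write: fills the existing storage without rebinding it.
        self._data[key][:] = value

c = LockedContainer({"a": [0.0, 0.0, 0.0]})
c.lock_()                      # analogous to the lock applied by memmap_()
c.set_("a", [1.0, 1.0, 1.0])   # fine: writes into the existing buffer
# c.set("a", [1.0]) would raise RuntimeError
```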
Building a TensorClass on each backend¶
In-memory TensorDict – the default and simplest case:
>>> import torch
>>> from tensordict import TensorDict, TensorClass
>>> from torch import Tensor
>>>
>>> class MyTC(TensorClass):
...     a: Tensor
...     b: Tensor
>>>
>>> td = TensorDict(a=torch.randn(4, 3), b=torch.randn(4, 5), batch_size=[4])
>>> tc = MyTC.from_tensordict(td)
>>> tc.a.shape
torch.Size([4, 3])
HDF5 (PersistentTensorDict):
>>> from tensordict import PersistentTensorDict
>>>
>>> h5 = PersistentTensorDict.from_dict(td, filename="data.h5")
>>> tc_h5 = MyTC.from_tensordict(h5)
>>> tc_h5.a.shape # reads from HDF5
torch.Size([4, 3])
Redis (TensorDictStore):
>>> from tensordict.store import TensorDictStore
>>>
>>> store = TensorDictStore.from_tensordict(td, host="localhost")
>>> tc_redis = MyTC.from_tensordict(store)
>>> tc_redis.a.shape # fetched from Redis
torch.Size([4, 3])
Lazy stack:
>>> from tensordict import lazy_stack
>>>
>>> tds = [TensorDict(a=torch.randn(3), b=torch.randn(5)) for _ in range(4)]
>>> ls = lazy_stack(tds, dim=0)
>>> tc_lazy = MyTC.from_tensordict(ls)
>>> tc_lazy[0].a.shape
torch.Size([3])
Memory-mapped TensorDict:
>>> td_mmap = td.memmap_("/tmp/my_memmap")
>>> tc_mmap = MyTC.from_tensordict(td_mmap)
>>> tc_mmap.a.shape
torch.Size([4, 3])
>>> # memmap TDs are locked -- use in-place operations:
>>> tc_mmap.set_("a", torch.ones(4, 3))
TypedTensorDict as backend – both TensorClass and TypedTensorDict enforce schemas, but they compose without conflict:
>>> from tensordict import TypedTensorDict
>>>
>>> class MyTTD(TypedTensorDict):
...     a: Tensor
...     b: Tensor
>>>
>>> ttd = MyTTD(a=torch.randn(4, 3), b=torch.randn(4, 5), batch_size=[4])
>>> tc_typed = MyTC.from_tensordict(ttd)
>>> tc_typed.a.shape
torch.Size([4, 3])
TypedTensorDict + backends¶
TypedTensorDict.from_tensordict(td) accepts any TensorDictBase subclass,
just like TensorClass. The backend is stored live (no copy) – mutations
through the TypedTensorDict go directly to the underlying backend.
from tensordict import TypedTensorDict
from torch import Tensor

class State(TypedTensorDict):
    x: Tensor
    y: Tensor

state = State.from_tensordict(some_backend)
| Backend | Build | Read | Write | Index | Clone | Stack | Iter | Update |
|---|---|---|---|---|---|---|---|---|
| TensorDict (in-memory) | yes | yes | yes | yes | yes | yes | yes | yes |
| PersistentTensorDict (HDF5) | yes | yes | yes | yes | yes | yes | yes | yes |
| TensorDictStore (Redis) | yes | yes | yes | yes | yes | yes | yes | yes |
| LazyStackedTensorDict | yes | yes | yes | yes | yes | yes | yes | yes |
| Memory-mapped TensorDict | yes | yes | set_() | yes | yes | yes | yes | update_() |
Note
Memory-mapped TensorDicts are locked after memmap_(). Use
set_() and update_() for in-place writes instead of attribute
assignment or update().
Building a TypedTensorDict on each backend¶
In-memory TensorDict – the default (direct construction creates one internally):
>>> import torch
>>> from tensordict import TensorDict, TypedTensorDict
>>> from torch import Tensor
>>>
>>> class State(TypedTensorDict):
...     x: Tensor
...     y: Tensor
>>>
>>> state = State(x=torch.randn(4, 3), y=torch.randn(4, 5), batch_size=[4])
>>> state.x.shape
torch.Size([4, 3])
Wrapping an existing TensorDict via from_tensordict (zero-copy):
>>> td = TensorDict(x=torch.randn(4, 3), y=torch.randn(4, 5), batch_size=[4])
>>> state = State.from_tensordict(td)
>>> state.x.shape # reads from td
torch.Size([4, 3])
>>> state.x = torch.ones(4, 3) # writes to td
>>> (td["x"] == 1).all()
True
HDF5 (PersistentTensorDict):
>>> from tensordict import PersistentTensorDict
>>>
>>> h5 = PersistentTensorDict.from_h5("data.h5")
>>> state = State.from_tensordict(h5)
>>> state.x.shape # reads from HDF5
torch.Size([4, 3])
Redis (TensorDictStore):
>>> from tensordict.store import TensorDictStore
>>>
>>> store = TensorDictStore.from_tensordict(td, host="localhost")
>>> state = State.from_tensordict(store)
>>> state.x.shape # fetched from Redis
torch.Size([4, 3])
Lazy stack:
>>> from tensordict import lazy_stack
>>>
>>> tds = [TensorDict(x=torch.randn(3), y=torch.randn(5)) for _ in range(4)]
>>> ls = lazy_stack(tds, dim=0)
>>> state = State.from_tensordict(ls)
>>> state[0].x.shape
torch.Size([3])
Memory-mapped TensorDict:
>>> td_mmap = td.memmap_("/tmp/my_memmap")
>>> state = State.from_tensordict(td_mmap)
>>> state.x.shape
torch.Size([4, 3])
>>> # memmap TDs are locked -- use in-place operations:
>>> state.set_("x", torch.ones(4, 3))
Stacking TypedTensorDicts¶
Dense stacking with torch.stack preserves the TypedTensorDict subclass
type:
>>> s1 = State(x=torch.randn(3), y=torch.randn(3), batch_size=[3])
>>> s2 = State(x=torch.randn(3), y=torch.randn(3), batch_size=[3])
>>> stacked = torch.stack([s1, s2], dim=0)
>>> stacked.batch_size
torch.Size([2, 3])
Lazy stacking also works. Indexing a LazyStackedTensorDict of
TypedTensorDict instances preserves the subclass type:
>>> from tensordict._lazy import LazyStackedTensorDict
>>>
>>> ls = LazyStackedTensorDict(s1, s2, stack_dim=0)
>>> isinstance(ls[0], State)
True
Pre-allocating on Redis and filling iteratively¶
A common pattern for shared replay buffers or distributed data stores is to pre-allocate storage on a remote server (Redis / Dragonfly / KeyDB) and fill it one sample at a time, without ever loading the full dataset into RAM.
TensorDictStore.from_schema creates keys with known shapes and dtypes
directly on the server using SETRANGE (zero-filled by the server; no
tensor data passes through Python):
>>> import torch
>>> from tensordict import TensorDict, TypedTensorDict
>>> from tensordict.store import TensorDictStore
>>> from torch import Tensor
>>>
>>> class Replay(TypedTensorDict):
...     obs: Tensor
...     action: Tensor
...     reward: Tensor
>>>
>>> # Pre-allocate 100k entries directly on Redis -- no RAM used
>>> store = TensorDictStore.from_schema(
...     {"obs": ([84, 84, 3], torch.uint8),
...      "action": ([4], torch.float32),
...      "reward": ([], torch.float32)},
...     batch_size=[100_000],
...     host="redis-node",
... )
>>>
>>> # Wrap with typed access
>>> replay = Replay.from_tensordict(store)
>>>
>>> # Fill iteratively -- each write goes directly to Redis
>>> for i, sample in enumerate(data_stream):
...     replay[i] = Replay(
...         obs=sample.obs, action=sample.action, reward=sample.reward,
...         batch_size=[],
...     )
If the store is initially empty (no keys registered yet), use check=False
to skip the key-presence validation and fill keys on the fly:
>>> store = TensorDictStore(batch_size=[100_000], host="redis-node")
>>> replay = Replay.from_tensordict(store, check=False)
>>>
>>> # First indexed write auto-creates each key via SETRANGE
>>> replay[0] = Replay(obs=obs_0, action=act_0, reward=r_0, batch_size=[])
>>> # Subsequent writes fill in the pre-allocated storage
>>> replay[1] = Replay(obs=obs_1, action=act_1, reward=r_1, batch_size=[])
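The server-side pre-allocation idea can be sketched with a plain bytearray standing in for a Redis string key (illustrative only; `setrange` here is a hypothetical helper mimicking Redis SETRANGE, not a library function):

```python
import struct

# Sketch: pre-allocate a zero-filled buffer for N float32 reward slots, then
# fill slot i by writing at byte offset i * itemsize -- the same idea as
# SETRANGE into a pre-sized, zero-filled Redis string key.
N, ITEMSIZE = 100, 4                 # 100 float32 slots
buf = bytearray(N * ITEMSIZE)        # zero-filled "key"; no tensor data sent

def setrange(buf, offset, payload):
    """Write payload bytes at a fixed offset, like Redis SETRANGE."""
    buf[offset:offset + len(payload)] = payload

# Fill a single sample at index i, as the replay loop above does per step.
i, reward = 7, 0.5
setrange(buf, i * ITEMSIZE, struct.pack("<f", reward))

# Reading slot i back recovers the written value; untouched slots stay zero.
value = struct.unpack_from("<f", buf, i * ITEMSIZE)[0]
```

Because each write touches only its own byte range, many producers can fill disjoint indices of the same pre-allocated key concurrently, which is what makes this pattern attractive for shared replay buffers.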
TensorClass vs TypedTensorDict¶
Both enforce typed schemas and can wrap any TensorDictBase backend, but
they differ architecturally:
| Aspect | TensorClass | TypedTensorDict |
|---|---|---|
| Relationship to TensorDictBase | Wraps a TensorDictBase (HAS-A) | Is a TensorDictBase (IS-A) |
| Can wrap non-TensorDict backends | Yes (H5, Redis, lazy stack, etc.) | Yes (H5, Redis, lazy stack, etc.) |
| `**state` spreading | Field-by-field repacking | Natively (`**state`) |
| Non-tensor fields | Supported | Not supported (tensor-only) |
| Backend stays live | Yes (writes go to original backend) | Yes (writes go to original backend) |
| Python inheritance | Not supported | Supported (standard class hierarchy) |
| Composable with each other | Yes (via from_tensordict) | Yes (via from_tensordict) |
Both wrappers keep the backend alive – mutations through the typed wrapper go
directly to the underlying storage. Direct construction (without
from_tensordict) creates an in-memory TensorDict as the backend.
Choose TensorClass when you need non-tensor fields or want to integrate
with existing tensorclass-based APIs. Choose TypedTensorDict when you
want native **state spreading, standard Python inheritance for schema
composition, and full TensorDictBase API compatibility.