Tree

class torchrl.data.Tree(count: 'int | torch.Tensor' = None, wins: 'int | torch.Tensor' = None, index: 'torch.Tensor | None' = None, hash: 'int | None' = None, node_id: 'int | None' = None, rollout: 'TensorDict | None' = None, node_data: 'TensorDict | None' = None, subtree: 'Tree' = None, _parent: 'weakref.ref | list[weakref.ref] | None' = None, specs: 'Composite | None' = None, *, batch_size, device=None, names=None)[source]
property branching_action: torch.Tensor | TensorDictBase | None

Returns the action that branched out to this particular node.

Returns:

a tensor, tensordict or None if the node has no parent.

See also

prev_action. This will be equal to prev_action whenever the rollout data contains a single step.

See also

selected_actions. All actions associated with a given node (or observation) in the tree.

cat(dim: int = 0, *, out=None)

Concatenates tensordicts into a single tensordict along the given dimension.

This call is equivalent to calling torch.cat() but is compatible with torch.compile.

property device: device

Retrieves the device of the tensor class.

dumps(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, robust_key: bool | None = None) Self

Saves the tensordict to disk.

This function is a proxy to memmap().

edges() list[tuple[int, int]][source]

Retrieves a list of edges in the tree.

Each edge is represented as a tuple of two node IDs: the parent node ID and the child node ID. The tree is traversed using Breadth-First Search (BFS) to ensure all edges are visited.

Returns:

A list of tuples, where each tuple contains a parent node ID and a child node ID.
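Example (a hedged sketch: it assumes a non-trivial tree such as the one built from an MCTSForest in the to_string() example further below):

>>> tree = forest.get_tree(td_root)
>>> for parent_id, child_id in tree.edges():
...     print(f"{parent_id} -> {child_id}")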

classmethod fields()

Return a tuple describing the fields of this dataclass.

Accepts a dataclass or an instance of one. Tuple elements are of type Field.

from_any(*, auto_batch_size: bool = False, batch_dims: int | None = None, device: torch.device | None = None, batch_size: torch.Size | None = None)

Recursively converts any object to a TensorDict.

Note

from_any is less restrictive than the regular TensorDict constructor. It can cast data structures like dataclasses or tuples to a tensordict using custom heuristics. This approach may incur some extra overhead and involves more opinionated choices in terms of mapping strategies.

Note

This method recursively converts the input object to a TensorDict. If the object is already a TensorDict (or any similar tensor collection object), it will be returned as is.

Parameters:

obj – The object to be converted.

Keyword Arguments:
  • auto_batch_size (bool, optional) – if True, the batch size will be computed automatically. Defaults to False.

  • batch_dims (int, optional) – If auto_batch_size is True, defines how many dimensions the output tensordict should have. Defaults to None (full batch-size at each level).

  • device (torch.device, optional) – The device on which the TensorDict will be created.

  • batch_size (torch.Size, optional) – The batch size of the TensorDict. Exclusive with auto_batch_size.

Returns:

A TensorDict representation of the input object.

Supported objects include dataclasses (see from_dataclass()), namedtuples (see from_namedtuple()), tuples (see from_tuple()), structured numpy arrays (see from_struct_array()) and h5 files (see from_h5()).
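
Example (a minimal sketch; the Point dataclass is purely illustrative):

>>> from dataclasses import dataclass
>>> import torch
>>> from tensordict import TensorDict
>>> @dataclass
... class Point:
...     x: torch.Tensor
...     y: torch.Tensor
...
>>> td = TensorDict.from_any(Point(x=torch.zeros(3), y=torch.ones(3)))
>>> assert set(td.keys()) == {"x", "y"}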

from_dataclass(*, dest_cls: Type | None = None, auto_batch_size: bool = False, batch_dims: int | None = None, as_tensorclass: bool = False, device: torch.device | None = None, batch_size: torch.Size | None = None)

Converts a dataclass into a TensorDict instance.

Parameters:

dataclass – The dataclass instance to be converted.

Keyword Arguments:
  • dest_cls (tensorclass, optional) – A tensorclass type to be used to map the data. If not provided, a new class is created. Has no effect if obj is a type or if as_tensorclass is False.

  • auto_batch_size (bool, optional) – If True, automatically determines and applies batch size to the resulting TensorDict. Defaults to False.

  • batch_dims (int, optional) – If auto_batch_size is True, defines how many dimensions the output tensordict should have. Defaults to None (full batch-size at each level).

  • as_tensorclass (bool, optional) – If True, delegates the conversion to the free function from_dataclass() and returns a tensor-compatible class (tensorclass()) or instance instead of a TensorDict. Defaults to False.

  • device (torch.device, optional) – The device on which the TensorDict will be created. Defaults to None.

  • batch_size (torch.Size, optional) – The batch size of the TensorDict. Defaults to None.

Returns:

A TensorDict instance derived from the provided dataclass, unless as_tensorclass is True, in which case a tensor-compatible class or instance is returned.

Raises:

TypeError – If the provided input is not a dataclass instance.

Warning

This method is distinct from the free function from_dataclass and serves a different purpose. While the free function returns a tensor-compatible class or instance, this method returns a TensorDict instance.

Note

  • This method creates a new TensorDict instance with keys corresponding to the fields of the input dataclass.

  • Each key in the resulting TensorDict is initialized using the cls.from_any method.

  • The auto_batch_size option allows for automatic batch size determination and application to the resulting TensorDict.
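
Example (a minimal sketch; the Data dataclass is purely illustrative):

>>> from dataclasses import dataclass
>>> import torch
>>> from tensordict import TensorDict
>>> @dataclass
... class Data:
...     obs: torch.Tensor
...     reward: torch.Tensor
...
>>> data = Data(obs=torch.zeros(3, 4), reward=torch.zeros(3, 1))
>>> td = TensorDict.from_dataclass(data, auto_batch_size=True)
>>> assert td.batch_size == torch.Size([3])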

from_h5(*, mode: str = 'r', auto_batch_size: bool = False, batch_dims: int | None = None, batch_size: torch.Size | None = None)

Creates a PersistentTensorDict from an h5 file.

Parameters:

filename (str) – The path to the h5 file.

Keyword Arguments:
  • mode (str, optional) – Reading mode. Defaults to "r".

  • auto_batch_size (bool, optional) – If True, the batch size will be computed automatically. Defaults to False.

  • batch_dims (int, optional) – If auto_batch_size is True, defines how many dimensions the output tensordict should have. Defaults to None (full batch-size at each level).

  • batch_size (torch.Size, optional) – The batch size of the TensorDict. Defaults to None.

Returns:

A PersistentTensorDict representation of the input h5 file.

Examples

>>> td = TensorDict.from_h5("path/to/file.h5")
>>> print(td)
PersistentTensorDict(
    fields={
        key1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        key2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
from_modules(*, as_module: bool = False, lock: bool = True, use_state_dict: bool = False, lazy_stack: bool = False, expand_identical: bool = False)

Retrieves the parameters of several modules, for ensemble learning or mixture-of-experts applications through vmap.

Parameters:

modules (sequence of nn.Module) – the modules to get the parameters from. If the modules differ in their structure, a lazy stack is needed (see the lazy_stack argument below).

Keyword Arguments:
  • as_module (bool, optional) – if True, a TensorDictParams instance will be returned which can be used to store parameters within a torch.nn.Module. Defaults to False.

  • lock (bool, optional) – if True, the resulting tensordict will be locked. Defaults to True.

  • use_state_dict (bool, optional) –

    if True, the state-dict from the module will be used and unflattened into a TensorDict with the tree structure of the model. Defaults to False.

    Note

    This is particularly useful when state-dict hooks have to be used.

  • lazy_stack (bool, optional) –

    whether parameters should be densely or lazily stacked. Defaults to False (dense stack).

    Note

    lazy_stack and as_module are exclusive features.

    Warning

    There is a crucial difference between lazy and non-lazy outputs in that non-lazy output will reinstantiate parameters with the desired batch-size, while lazy_stack will just represent the parameters as lazily stacked. This means that whilst the original parameters can safely be passed to an optimizer when lazy_stack=True, the new parameters need to be passed when it is set to False.

    Warning

    Whilst it can be tempting to use a lazy stack to keep the original parameter references, remember that a lazy stack performs a stack each time get() is called. This will require memory (N times the size of the parameters, more if a graph is built) and time to be computed. It also means that the optimizer(s) will contain more parameters, and operations like step() or zero_grad() will take longer to be executed. In general, lazy_stack should be reserved to very few use cases.

  • expand_identical (bool, optional) – if True and the same parameter (same identity) is being stacked to itself, an expanded version of this parameter will be returned instead. This argument is ignored when lazy_stack=True.

Examples

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> torch.manual_seed(0)
>>> empty_module = nn.Linear(3, 4, device="meta")
>>> n_models = 2
>>> modules = [nn.Linear(3, 4) for _ in range(n_models)]
>>> params = TensorDict.from_modules(*modules)
>>> print(params)
TensorDict(
    fields={
        bias: Parameter(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        weight: Parameter(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
>>> # example of batch execution
>>> def exec_module(params, x):
...     with params.to_module(empty_module):
...         return empty_module(x)
>>> x = torch.randn(3)
>>> y = torch.vmap(exec_module, (0, None))(params, x)
>>> assert y.shape == (n_models, 4)
>>> # since lazy_stack = False, backprop leaves the original params untouched
>>> y.sum().backward()
>>> assert params["weight"].grad.norm() > 0
>>> assert modules[0].weight.grad is None

With lazy_stack=True, things are slightly different:

>>> params = TensorDict.from_modules(*modules, lazy_stack=True)
>>> print(params)
LazyStackedTensorDict(
    fields={
        bias: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        weight: Tensor(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    exclusive_fields={
    },
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False,
    stack_dim=0)
>>> # example of batch execution
>>> y = torch.vmap(exec_module, (0, None))(params, x)
>>> assert y.shape == (n_models, 4)
>>> y.sum().backward()
>>> assert modules[0].weight.grad is not None
from_namedtuple(*, auto_batch_size: bool = False, batch_dims: int | None = None, device: torch.device | None = None, batch_size: torch.Size | None = None)

Converts a namedtuple to a TensorDict recursively.

Parameters:

named_tuple – The namedtuple instance to be converted.

Keyword Arguments:
  • auto_batch_size (bool, optional) – if True, the batch size will be computed automatically. Defaults to False.

  • batch_dims (int, optional) – If auto_batch_size is True, defines how many dimensions the output tensordict should have. Defaults to None (full batch-size at each level).

  • device (torch.device, optional) – The device on which the TensorDict will be created. Defaults to None.

  • batch_size (torch.Size, optional) – The batch size of the TensorDict. Defaults to None.

Returns:

A TensorDict representation of the input namedtuple.

Examples

>>> from tensordict import TensorDict
>>> import torch
>>> data = TensorDict({
...     "a_tensor": torch.zeros((3)),
...     "nested": {"a_tensor": torch.zeros((3)), "a_string": "zero!"}}, [3])
>>> nt = data.to_namedtuple()
>>> print(nt)
GenericDict(a_tensor=tensor([0., 0., 0.]), nested=GenericDict(a_tensor=tensor([0., 0., 0.]), a_string='zero!'))
>>> TensorDict.from_namedtuple(nt, auto_batch_size=True)
TensorDict(
    fields={
        a_tensor: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a_string: NonTensorData(data=zero!, batch_size=torch.Size([3]), device=None),
                a_tensor: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
from_pytree(*, batch_size: torch.Size | None = None, auto_batch_size: bool = False, batch_dims: int | None = None)

Converts a pytree to a TensorDict instance.

This method is designed to keep the pytree nested structure as much as possible.

Additional non-tensor keys are added to keep track of each level’s identity, providing a built-in pytree-to-tensordict bijective transform API.

Accepted classes currently include lists, tuples, named tuples and dict.

Note

For dictionaries, non-NestedKey keys are registered separately as NonTensorData instances.

Note

Tensor-castable types (such as int, float or np.ndarray) will be converted to torch.Tensor instances. Note that this conversion is lossy: transforming the tensordict back to a pytree will not recover the original types.

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> # Create a pytree with tensor leaves, and one "weird"-looking dict key
>>> class WeirdLookingClass:
...     pass
...
>>> weird_key = WeirdLookingClass()
>>> # Make a pytree with tuple, lists, dict and namedtuple
>>> pytree = (
...     [torch.randint(10, (3,)), torch.zeros(2)],
...     {
...         "tensor": torch.randn(
...             2,
...         ),
...         "td": TensorDict({"one": 1}),
...         weird_key: torch.randint(10, (2,)),
...         "list": [1, 2, 3],
...     },
...     {"named_tuple": TensorDict({"two": torch.ones(1) * 2}).to_namedtuple()},
... )
>>> # Build a TensorDict from that pytree
>>> td = TensorDict.from_pytree(pytree)
>>> # Recover the pytree
>>> pytree_recon = td.to_pytree()
>>> # Check that the leaves match
>>> def check(v1, v2):
...     assert (v1 == v2).all()
...
>>> torch.utils._pytree.tree_map(check, pytree, pytree_recon)
>>> assert weird_key in pytree_recon[1]
from_remote_init(group: 'ProcessGroup' | None = None, device: torch.device | None = None) Self

Creates a new tensordict instance initialized from remotely sent metadata.

This class method receives the metadata sent by init_remote, creates a new tensordict with matching shape and dtype, and then asynchronously receives the actual tensordict content.

Parameters:
  • src (int) – The rank of the source process that sent the metadata.

  • group ("ProcessGroup", optional) – The process group to use for communication. Defaults to None.

  • device (torch.device, optional) – The device to use for tensor operations. Defaults to None.

Returns:

A new tensordict instance initialized with the received metadata and content.

Return type:

TensorDict

See also

The sending process should have called init_remote() to send the metadata and content.

from_struct_array(*, auto_batch_size: bool = False, batch_dims: int | None = None, device: torch.device | None = None, batch_size: torch.Size | None = None) Self

Converts a structured numpy array to a TensorDict.

The resulting TensorDict will share the same memory content as the numpy array (it is a zero-copy operation). Changing values of the structured numpy array in-place will affect the content of the TensorDict.

Parameters:

struct_array (np.ndarray) – The structured numpy array to be converted.

Keyword Arguments:
  • auto_batch_size (bool, optional) – If True, the batch size will be computed automatically. Defaults to False.

  • batch_dims (int, optional) – If auto_batch_size is True, defines how many dimensions the output tensordict should have. Defaults to None (full batch-size at each level).

  • device (torch.device, optional) –

    The device on which the TensorDict will be created. Defaults to None.

    Note

    Changing the device (i.e., specifying any device other than None or "cpu") will transfer the data, resulting in a change to the memory location of the returned data.

  • batch_size (torch.Size, optional) – The batch size of the TensorDict. Defaults to None.

Returns:

A TensorDict representation of the input structured numpy array.

Examples

>>> import numpy as np
>>> from tensordict import TensorDict
>>> x = np.array(
...     [("Rex", 9, 81.0), ("Fido", 3, 27.0)],
...     dtype=[("name", "U10"), ("age", "i4"), ("weight", "f4")],
... )
>>> td = TensorDict.from_struct_array(x)
>>> x_recon = td.to_struct_array()
>>> assert (x_recon == x).all()
>>> assert x_recon.shape == x.shape
>>> # Try modifying x age field and check effect on td
>>> x["age"] += 1
>>> assert (td["age"] == np.array([10, 4])).all()
classmethod from_tensordict(tensordict: TensorDictBase, non_tensordict: dict | None = None, safe: bool = True) Self

Tensor class wrapper to instantiate a new tensor class object.

Parameters:
  • tensordict (TensorDictBase) – Dictionary of tensor types

  • non_tensordict (dict) – Dictionary with non-tensor and nested tensor class objects

  • safe (bool) – Whether to raise an error if the tensordict is not a TensorDictBase instance

from_tuple(*, auto_batch_size: bool = False, batch_dims: int | None = None, device: torch.device | None = None, batch_size: torch.Size | None = None)

Converts a tuple to a TensorDict.

Parameters:

obj – The tuple instance to be converted.

Keyword Arguments:
  • auto_batch_size (bool, optional) – If True, the batch size will be computed automatically. Defaults to False.

  • batch_dims (int, optional) – If auto_batch_size is True, defines how many dimensions the output tensordict should have. Defaults to None (full batch-size at each level).

  • device (torch.device, optional) – The device on which the TensorDict will be created. Defaults to None.

  • batch_size (torch.Size, optional) – The batch size of the TensorDict. Defaults to None.

Returns:

A TensorDict representation of the input tuple.

Examples

>>> from tensordict import TensorDict
>>> my_tuple = (1, 2, 3)
>>> td = TensorDict.from_tuple(my_tuple)
>>> print(td)
TensorDict(
    fields={
        0: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        2: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
fromkeys(value: Any = 0)

Creates a tensordict from a list of keys and a single value.

Parameters:
  • keys (list of NestedKey) – An iterable specifying the keys of the new dictionary.

  • value (compatible type, optional) – The value for all keys. Defaults to 0.
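
Example (a minimal sketch):

>>> from tensordict import TensorDict
>>> td = TensorDict.fromkeys(["a", "b", ("nested", "c")], 0)
>>> assert (td["a"] == 0).all()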

property full_action_spec

The action spec of the tree.

This is an alias for Tree.specs["input_spec", "full_action_spec"].

property full_done_spec

The done spec of the tree.

This is an alias for Tree.specs["output_spec", "full_done_spec"].

property full_observation_spec

The observation spec of the tree.

This is an alias for Tree.specs["output_spec", "full_observation_spec"].

property full_reward_spec

The reward spec of the tree.

This is an alias for Tree.specs["output_spec", "full_reward_spec"].

property full_state_spec

The state spec of the tree.

This is an alias for Tree.specs["input_spec", "full_state_spec"].

fully_expanded(env: EnvBase) bool[source]

Returns True if the number of children is equal to the environment cardinality.

get(key: NestedKey, *args, **kwargs)

Gets the value stored with the input key.

Parameters:
  • key (str, tuple of str) – key to be queried. If tuple of str it is equivalent to chained calls of getattr.

  • default – default value if the key is not found in the tensorclass.

Returns:

value stored with the input key
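
Example (a minimal sketch, assuming node is a Tree instance):

>>> count = node.get("count")
>>> rollout = node.get("rollout", None)  # falls back to None if the key is absent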

get_vertex_by_hash(hash: int) Tree[source]

Goes through the tree and returns the node corresponding to the given hash.

get_vertex_by_id(id: int) Tree[source]

Goes through the tree and returns the node corresponding to the given id.
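
Example (a hedged sketch, assuming tree was built as in the to_string() example further below):

>>> parent_id, child_id = tree.edges()[0]
>>> child = tree.get_vertex_by_id(child_id)
>>> assert child.node_id == child_id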

property is_terminal: bool | torch.Tensor

Returns True if the tree has no child nodes.

lazy_stack(dim: int = 0, *, out=None, **kwargs)

Creates a lazy stack of tensordicts.

See lazy_stack() for details.

load(*args, **kwargs) Self

Loads a tensordict from disk.

This class method is a proxy to load_memmap().

load_(prefix: str | Path, *args, **kwargs)

Loads a tensordict from disk within the current tensordict.

This class method is a proxy to load_memmap_().

load_memmap(device: torch.device | None = None, non_blocking: bool = False, *, out: TensorDictBase | None = None, robust_key: bool | None = None) Self

Loads a memory-mapped tensordict from disk.

Parameters:
  • prefix (str or Path to folder) – the path to the folder where the saved tensordict should be fetched.

  • device (torch.device or equivalent, optional) – if provided, the data will be asynchronously cast to that device. Supports “meta” device, in which case the data isn’t loaded but a set of empty “meta” tensors are created. This is useful to get a sense of the total model size and structure without actually opening any file.

  • non_blocking (bool, optional) – if True, synchronize won’t be called after loading tensors on device. Defaults to False.

  • out (TensorDictBase, optional) – optional tensordict where the data should be written.

  • robust_key (bool, optional) – if True, expects robust key encoding was used when saving and decodes filenames accordingly. If False, uses legacy behavior. If None (default), emits a deprecation warning and falls back to legacy behavior. Will default to True in v0.12.

Examples

>>> from tensordict import TensorDict
>>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0)
>>> td.memmap("./saved_td")
>>> td_load = TensorDict.load_memmap("./saved_td")
>>> assert (td == td_load).all()

This method also allows loading nested tensordicts.

Examples

>>> nested = TensorDict.load_memmap("./saved_td/nested")
>>> assert nested["e"] == 0

A tensordict can also be loaded on “meta” device or, alternatively, as a fake tensor.

Examples

>>> import tempfile
>>> import torch
>>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}})
>>> with tempfile.TemporaryDirectory() as path:
...     td.save(path)
...     td_load = TensorDict.load_memmap(path, device="meta")
...     print("meta:", td_load)
...     from torch._subclasses import FakeTensorMode
...     with FakeTensorMode():
...         td_load = TensorDict.load_memmap(path)
...         print("fake:", td_load)
meta: TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=meta,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=meta,
    is_shared=False)
fake: TensorDict(
    fields={
        a: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
load_state_dict(state_dict: dict[str, Any], strict=True, assign=False, from_flatten=False)

Attempts to load a state_dict in-place on the destination tensorclass.

classmethod make_node(data: TensorDictBase, *, device: torch.device | None = None, batch_size: torch.Size | None = None, specs: Composite | None = None) Tree[source]

Creates a new node given some data.
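
Example (a minimal sketch):

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import Tree
>>> node = Tree.make_node(TensorDict({"observation": torch.zeros(3)}))
>>> assert isinstance(node, Tree)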

max_length()[source]

Returns the maximum length of all valid paths in the tree.

The length of a path is defined as the number of nodes in the path. If the tree is empty, returns 0.

Returns:

The maximum length of all valid paths in the tree.

Return type:

int

maybe_dense_stack(dim: int = 0, *, out=None, **kwargs)

Attempts to make a dense stack of tensordicts, and falls back on a lazy stack when required.

See maybe_dense_stack() for details.

memmap(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True, robust_key: bool | None = None) Self

Writes all tensors onto a corresponding memory-mapped Tensor in a new tensordict.

Parameters:
  • prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.

  • copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.

Keyword Arguments:
  • num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.

  • return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict.

  • share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes and writing operation (such as inplace update or set) on any of the workers within a single node will update the value on all other workers. If the number of non_tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to False.

  • existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.

  • robust_key (bool, optional) – if True, uses robust key encoding that safely handles keys with path separators and special characters. If False, uses legacy behavior (keys used as-is). If None (default), emits a deprecation warning and falls back to legacy behavior. Will default to True in v0.12.

The TensorDict is then locked, meaning that any writing operation that isn't in-place will throw an exception (e.g., renaming, setting or removing an entry). Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is no longer guaranteed.

Returns:

A new tensordict with the tensors stored on disk if return_early=False, otherwise a TensorDictFuture instance.

Note

Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.

memmap_(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True, robust_key: bool | None = None) Self

Writes all tensors onto a corresponding memory-mapped Tensor, in-place.

Parameters:
  • prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.

  • copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.

Keyword Arguments:
  • num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.

  • return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict. The resulting tensordict can be queried using future.result().

  • share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes and writing operation (such as inplace update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to False.

  • existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.

  • robust_key (bool, optional) – if True, uses robust key encoding that safely handles keys with path separators and special characters. If False, uses legacy behavior (keys used as-is). If None (default), emits a deprecation warning and falls back to legacy behavior. Will default to True in v0.12.

The TensorDict is then locked, meaning that any writing operation that isn't in-place will throw an exception (e.g., renaming, setting or removing an entry). Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is no longer guaranteed.

Returns:

self if return_early=False, otherwise a TensorDictFuture instance.

Note

Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.

memmap_like(prefix: str | None = None, copy_existing: bool = False, *, existsok: bool = True, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, robust_key: bool | None = None) Self

Creates a contentless memory-mapped tensordict with the same shapes as the original one.

Parameters:
  • prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.

  • copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.

Keyword Arguments:
  • num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.

  • return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict.

  • share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes and writing operation (such as inplace update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to False.

  • existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.

  • robust_key (bool, optional) – if True, uses robust key encoding that safely handles keys with path separators and special characters. If False, uses legacy behavior (keys used as-is). If None (default), emits a deprecation warning and falls back to legacy behavior. Will default to True in v0.12.

The TensorDict is then locked, meaning that any writing operation that isn't in-place will throw an exception (e.g., renaming, setting or removing an entry). Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is no longer guaranteed.

Returns:

A new TensorDict instance with data stored as memory-mapped tensors if return_early=False, otherwise a TensorDictFuture instance.

Note

This is the recommended method to write a set of large buffers on disk, as memmap_() will copy the information, which can be slow for large content.

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({
...     "a": torch.zeros((3, 64, 64), dtype=torch.uint8),
...     "b": torch.zeros(1, dtype=torch.int64),
... }, batch_size=[]).expand(1_000_000)  # expand does not allocate new memory
>>> buffer = td.memmap_like("/path/to/dataset")
memmap_refresh_()

Refreshes the content of the memory-mapped tensordict if it has a saved_path.

This method will raise an exception if no path is associated with it.

property node_observation: torch.Tensor | TensorDictBase

Returns the observation associated with this particular node.

This is the observation (or bag of observations) that defines the node before a branching occurs. If the node contains a rollout() attribute, the node observation is typically identical to the observation resulting from the last action undertaken, i.e., node.rollout[..., -1]["next", "observation"].

If more than one observation key is associated with the tree specs, a TensorDict instance is returned instead.

For a more consistent representation, see node_observations.

property node_observations: torch.Tensor | TensorDictBase

Returns the observations associated with this particular node in a TensorDict format.

This is the observation (or bag of observations) that defines the node before a branching occurs. If the node contains a rollout() attribute, the node observation is typically identical to the observation resulting from the last action undertaken, i.e., node.rollout[..., -1]["next", "observation"].

If more than one observation key is associated with the tree specs, a TensorDict instance is returned instead.

For a version that may return a single tensor instead of a TensorDict, see node_observation.

property num_children: int

Number of children of this node.

Equates to the number of elements in the self.subtree stack.

num_vertices(*, count_repeat: bool = False) int[source]

Returns the number of unique vertices in the Tree.

Keyword Arguments:

count_repeat (bool, optional) –

Determines whether to count repeated vertices.

  • If False, counts each unique vertex only once.

  • If True, counts vertices multiple times if they appear in different paths.

Defaults to False.

Returns:

The number of unique vertices in the Tree.

Return type:

int
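
Example (a hedged sketch, assuming tree was built as in the to_string() example further below):

>>> unique = tree.num_vertices()
>>> repeated = tree.num_vertices(count_repeat=True)
>>> assert unique <= repeated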

property parent: Tree | None

The parent of the node.

If the node has a parent and this object is still present in the python workspace, it will be returned by this property.

For re-branching trees, this property may return a stack of trees where every index of the stack corresponds to a different parent.

Note

the parent attribute will match in content but not in identity: the tensorclass object is reconstructed using the same tensors (i.e., tensors that point to the same memory locations).

Returns:

A Tree containing the parent data or None if the parent data is out of scope or the node is the root.

plot(backend: str = 'plotly', figure: str = 'tree', info: list[str] = None, make_labels: Callable[[Any, ...], Any] | None = None)[source]

Plots a visualization of the tree using the specified backend and figure type.

Parameters:
  • backend – The plotting backend to use. Currently only supports ‘plotly’.

  • figure – The type of figure to plot. Can be either ‘tree’ or ‘box’.

  • info – A list of additional information to include in the plot (not currently used).

  • make_labels – An optional function to generate custom labels for the plot.

Raises:

NotImplementedError – If an unsupported backend or figure type is specified.
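
Example (a hedged sketch; it assumes plotly is installed and tree was built as in the to_string() example further below):

>>> tree.plot(backend="plotly", figure="tree")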

property prev_action: torch.Tensor | TensorDictBase | None

The action undertaken just before this node’s observation was generated.

Returns:

a tensor, tensordict or None if the node has no parent.

See also

branching_action. This will be equal to branching_action whenever the rollout data contains a single step.

See also

selected_actions. All actions associated with a given node (or observation) in the tree.

rollout_from_path(path: tuple[int]) TensorDictBase | None[source]

Retrieves the rollout data along a given path in the tree.

The rollout data is concatenated along the last dimension (dim=-1) for each node in the path. If no rollout data is found along the path, returns None.

Parameters:

path – A tuple of integers representing the path in the tree.

Returns:

The concatenated rollout data along the path, or None if no data is found.
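
Example (a hedged sketch, assuming tree was built as in the to_string() example further below):

>>> path = next(iter(tree.valid_paths()))
>>> rollout = tree.rollout_from_path(path)
>>> if rollout is not None:
...     print(rollout["action"])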

save(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, robust_key: bool | None = None) Self

Saves the tensordict to disk.

This function is a proxy to memmap().

property selected_actions: torch.Tensor | TensorDictBase | None

Returns a tensor containing all the selected actions branching out from this node.

set(key: NestedKey, value: Any, inplace: bool = False, non_blocking: bool = False)

Sets a new key-value pair.

Parameters:
  • key (str, tuple of str) – name of the key to be set. If tuple of str it is equivalent to chained calls of getattr followed by a final setattr.

  • value (Any) – value to be stored in the tensorclass

  • inplace (bool, optional) – if True, set will tentatively try to update the value in-place. If False or if the key isn’t present, the value will be simply written at its destination.

Returns:

self

stack(dim: int = 0, *, out=None)

Stacks tensordicts into a single tensordict along the given dimension.

This call is equivalent to calling torch.stack() but is compatible with torch.compile.
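
Example (a minimal sketch; as with torch.stack(), the tensordicts to stack are passed as the first positional argument):

>>> import torch
>>> from tensordict import TensorDict
>>> td1 = TensorDict({"a": torch.zeros(3)}, [3])
>>> td2 = TensorDict({"a": torch.ones(3)}, [3])
>>> stacked = TensorDict.stack([td1, td2], dim=0)
>>> assert stacked.batch_size == torch.Size([2, 3])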

state_dict(destination=None, prefix='', keep_vars=False, flatten=False) dict[str, Any]

Returns a state_dict dictionary that can be used to save and load data from a tensorclass.
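
Example (a minimal sketch, assuming node is a Tree instance; see also load_state_dict() above):

>>> sd = node.state_dict()
>>> node.load_state_dict(sd)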

to_string(node_format_fn=<function Tree.<lambda>>)[source]

Generates a string representation of the tree.

This function can pull out information from each of the nodes in a tree, so it can be useful for debugging. The nodes are listed line-by-line. Each line contains the path to the node, followed by the string representation of that node generated with node_format_fn. Each line is indented according to the number of steps in the path required to reach the corresponding node.

Parameters:

node_format_fn (Callable, optional) – User-defined function to generate a string for each node of the tree. The signature must be (Tree) -> Any, and the output must be convertible to a string. If this argument is not given, the generated string is the node’s Tree.node_data attribute converted to a dict.

Examples

>>> from torchrl.data import MCTSForest
>>> from tensordict import TensorDict
>>> forest = MCTSForest()
>>> td_root = TensorDict({"observation": 0,})
>>> rollouts_data = [
...     # [(action, obs), ...]
...     [(3, 123), (1, 456)],
...     [(2, 359), (2, 3094)],
...     [(3, 123), (9, 392), (6, 989), (20, 809), (21, 847)],
...     [(1, 75)],
...     [(3, 123), (0, 948)],
...     [(2, 359), (2, 3094), (10, 68)],
...     [(2, 359), (2, 3094), (11, 9045)],
... ]
>>> for rollout_data in rollouts_data:
...     td = td_root.clone().unsqueeze(0)
...     for action, obs in rollout_data:
...         td = td.update(TensorDict({
...             "action": [action],
...             "next": TensorDict({"observation": [obs]}, [1]),
...         }, [1]))
...         forest.extend(td)
...         td = td["next"].clone()
...
>>> tree = forest.get_tree(td_root)
>>> print(tree.to_string())
(0,) {'observation': tensor(123)}
(0, 0) {'observation': tensor(456)}
(0, 1) {'observation': tensor(847)}
(0, 2) {'observation': tensor(948)}
(1,) {'observation': tensor(3094)}
(1, 0) {'observation': tensor(68)}
(1, 1) {'observation': tensor(9045)}
(2,) {'observation': tensor(75)}
to_tensordict(*, retain_none: bool | None = None) TensorDict

Convert the tensorclass into a regular TensorDict.

Makes a copy of all entries. Memmap and shared memory tensors are converted to regular tensors.

Parameters:

retain_none (bool) – if True, the None values will be written in the tensordict. Otherwise they will be discarded. Default: True.

Returns:

A new TensorDict object containing the same values as the tensorclass.

unbind(dim: int)

Returns a tuple of indexed tensorclass instances unbound along the indicated dimension.

Resulting tensorclass instances will share the storage of the initial tensorclass instance.
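
Example (a minimal sketch):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.arange(6).view(2, 3)}, [2, 3])
>>> td0, td1 = td.unbind(0)
>>> assert (td0["a"] == td["a"][0]).all()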

valid_paths()[source]

Generates all valid paths in the tree.

A valid path is a sequence of child indices that starts at the root node and ends at a leaf node. Each path is represented as a tuple of integers, where each integer corresponds to the index of a child node.

Yields:

tuple – A valid path in the tree.
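
Example (a hedged sketch, assuming tree was built as in the to_string() example above):

>>> for path in tree.valid_paths():
...     print(path)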

vertices(*, key_type: Literal['id', 'hash', 'path'] = 'hash') dict[int | tuple[int], Tree][source]

Returns a map containing the vertices of the Tree.

Keyword Arguments:

key_type (Literal["id", "hash", "path"], optional) –

Specifies the type of key to use for the vertices.

  • "id": Use the vertex ID as the key.

  • "hash": Use a hash of the vertex as the key.

  • "path": Use the path to the vertex as the key. This may lead to a dictionary with more entries than when "id" or "hash" are used, as the same node may be part of multiple trajectories.

Defaults to "hash".

Returns:

A dictionary mapping keys to Tree vertices.

Return type:

Dict[int | Tuple[int], Tree]
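
Example (a hedged sketch, assuming tree was built as in the to_string() example above):

>>> by_path = tree.vertices(key_type="path")
>>> for path, node in by_path.items():
...     print(path, node.node_id)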

property visits: int | torch.Tensor

Returns the number of visits associated with this particular node.

This is an alias for the count attribute.
