Tree¶
- class torchrl.data.Tree(count: 'int | torch.Tensor' = None, wins: 'int | torch.Tensor' = None, index: 'torch.Tensor | None' = None, hash: 'int | None' = None, node_id: 'int | None' = None, rollout: 'TensorDict | None' = None, node_data: 'TensorDict | None' = None, subtree: 'Tree' = None, _parent: 'weakref.ref | list[weakref.ref] | None' = None, specs: 'Composite | None' = None, *, batch_size, device=None, names=None)[source]¶
- property branching_action: torch.Tensor | TensorDictBase | None¶
Returns the action that branched out to this particular node.
- Returns:
a tensor, tensordict or None if the node has no parent.
See also
This will be equal to prev_action whenever the rollout data contains a single step.
See also
selected_actions: all actions associated with a given node (or observation) in the tree.
- cat(dim: int = 0, *, out=None)¶
Concatenates tensordicts into a single tensordict along the given dimension.
This call is equivalent to calling torch.cat() but is compatible with torch.compile.
- dumps(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, robust_key: bool | None = None) Self ¶
Saves the tensordict to disk.
This function is a proxy to memmap().
- edges() list[tuple[int, int]] [source]¶
Retrieves a list of edges in the tree.
Each edge is represented as a tuple of two node IDs: the parent node ID and the child node ID. The tree is traversed using Breadth-First Search (BFS) to ensure all edges are visited.
- Returns:
A list of tuples, where each tuple contains a parent node ID and a child node ID.
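A minimal usage sketch (assuming a tree built with MCTSForest.get_tree() as in the to_string() example further down):
>>> for parent_id, child_id in tree.edges():
...     print(parent_id, "->", child_id)  # each edge is (parent node ID, child node ID)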
- classmethod fields()¶
Return a tuple describing the fields of this dataclass.
Accepts a dataclass or an instance of one. Tuple elements are of type Field.
- from_any(*, auto_batch_size: bool = False, batch_dims: int | None = None, device: torch.device | None = None, batch_size: torch.Size | None = None)¶
Recursively converts any object to a TensorDict.
Note
from_any is less restrictive than the regular TensorDict constructor. It can cast data structures like dataclasses or tuples to a tensordict using custom heuristics. This approach may incur some extra overhead and involves more opinionated choices in terms of mapping strategies.
Note
This method recursively converts the input object to a TensorDict. If the object is already a TensorDict (or any similar tensor collection object), it will be returned as is.
- Parameters:
obj – The object to be converted.
- Keyword Arguments:
auto_batch_size (bool, optional) – if True, the batch size will be computed automatically. Defaults to False.
batch_dims (int, optional) – If auto_batch_size is True, defines how many dimensions the output tensordict should have. Defaults to None (full batch-size at each level).
device (torch.device, optional) – The device on which the TensorDict will be created.
batch_size (torch.Size, optional) – The batch size of the TensorDict. Exclusive with auto_batch_size.
- Returns:
A TensorDict representation of the input object.
Supported objects:
Dataclasses through from_dataclass() (dataclasses will be converted to TensorDict instances, not tensorclasses).
Namedtuples through from_namedtuple().
Dictionaries through from_dict().
Tuples through from_tuple().
NumPy's structured arrays through from_struct_array().
HDF5 objects through from_h5().
- from_dataclass(*, dest_cls: Type | None = None, auto_batch_size: bool = False, batch_dims: int | None = None, as_tensorclass: bool = False, device: torch.device | None = None, batch_size: torch.Size | None = None)¶
Converts a dataclass into a TensorDict instance.
- Parameters:
dataclass – The dataclass instance to be converted.
- Keyword Arguments:
dest_cls (tensorclass, optional) – A tensorclass type to be used to map the data. If not provided, a new class is created. Without effect if obj is a type or as_tensorclass is False.
auto_batch_size (bool, optional) – If True, automatically determines and applies batch size to the resulting TensorDict. Defaults to False.
batch_dims (int, optional) – If auto_batch_size is True, defines how many dimensions the output tensordict should have. Defaults to None (full batch-size at each level).
as_tensorclass (bool, optional) – If True, delegates the conversion to the free function from_dataclass() and returns a tensor-compatible class (tensorclass()) or instance instead of a TensorDict. Defaults to False.
device (torch.device, optional) – The device on which the TensorDict will be created. Defaults to None.
batch_size (torch.Size, optional) – The batch size of the TensorDict. Defaults to None.
- Returns:
A TensorDict instance derived from the provided dataclass, unless as_tensorclass is True, in which case a tensor-compatible class or instance is returned.
- Raises:
TypeError – If the provided input is not a dataclass instance.
Warning
This method is distinct from the free function from_dataclass and serves a different purpose. While the free function returns a tensor-compatible class or instance, this method returns a TensorDict instance.
Note
This method creates a new TensorDict instance with keys corresponding to the fields of the input dataclass.
Each key in the resulting TensorDict is initialized using the cls.from_any method.
The auto_batch_size option allows for automatic batch size determination and application to the resulting TensorDict.
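A minimal conversion sketch (the MyData dataclass below is hypothetical, introduced only for illustration):
>>> from dataclasses import dataclass
>>> import torch
>>> from tensordict import TensorDict
>>> @dataclass
... class MyData:
...     a: torch.Tensor
...     b: torch.Tensor
>>> td = TensorDict.from_dataclass(MyData(a=torch.zeros(3), b=torch.ones(3)))
>>> assert set(td.keys()) == {"a", "b"}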
- from_h5(*, mode: str = 'r', auto_batch_size: bool = False, batch_dims: int | None = None, batch_size: torch.Size | None = None)¶
Creates a PersistentTensorDict from a h5 file.
- Parameters:
filename (str) – The path to the h5 file.
- Keyword Arguments:
mode (str, optional) – Reading mode. Defaults to "r".
auto_batch_size (bool, optional) – If True, the batch size will be computed automatically. Defaults to False.
batch_dims (int, optional) – If auto_batch_size is True, defines how many dimensions the output tensordict should have. Defaults to None (full batch-size at each level).
batch_size (torch.Size, optional) – The batch size of the TensorDict. Defaults to None.
- Returns:
A PersistentTensorDict representation of the input h5 file.
Examples
>>> td = TensorDict.from_h5("path/to/file.h5")
>>> print(td)
PersistentTensorDict(
    fields={
        key1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        key2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
- from_modules(*, as_module: bool = False, lock: bool = True, use_state_dict: bool = False, lazy_stack: bool = False, expand_identical: bool = False)¶
Retrieves the parameters of several modules for ensemble learning/feature-of-experts applications through vmap.
- Parameters:
modules (sequence of nn.Module) – the modules to get the parameters from. If the modules differ in their structure, a lazy stack is needed (see the lazy_stack argument below).
- Keyword Arguments:
as_module (bool, optional) – if True, a TensorDictParams instance will be returned which can be used to store parameters within a torch.nn.Module. Defaults to False.
lock (bool, optional) – if True, the resulting tensordict will be locked. Defaults to True.
use_state_dict (bool, optional) – if True, the state-dict from the module will be used and unflattened into a TensorDict with the tree structure of the model. Defaults to False.
Note
This is particularly useful when state-dict hooks have to be used.
lazy_stack (bool, optional) – whether parameters should be densely or lazily stacked. Defaults to False (dense stack).
Note
lazy_stack and as_module are exclusive features.
Warning
There is a crucial difference between lazy and non-lazy outputs in that non-lazy output will reinstantiate parameters with the desired batch-size, while lazy_stack will just represent the parameters as lazily stacked. This means that whilst the original parameters can safely be passed to an optimizer when lazy_stack=True, the new parameters need to be passed when it is set to False.
Warning
Whilst it can be tempting to use a lazy stack to keep the original parameter references, remember that a lazy stack performs a stack each time get() is called. This requires memory (N times the size of the parameters, more if a graph is built) and time to be computed. It also means that the optimizer(s) will contain more parameters, and operations like step() or zero_grad() will take longer to be executed. In general, lazy_stack should be reserved for very few use cases.
expand_identical (bool, optional) – if True and the same parameter (same identity) is being stacked to itself, an expanded version of this parameter will be returned instead. This argument is ignored when lazy_stack=True.
Examples
>>> from torch import nn
>>> from tensordict import TensorDict
>>> torch.manual_seed(0)
>>> empty_module = nn.Linear(3, 4, device="meta")
>>> n_models = 2
>>> modules = [nn.Linear(3, 4) for _ in range(n_models)]
>>> params = TensorDict.from_modules(*modules)
>>> print(params)
TensorDict(
    fields={
        bias: Parameter(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        weight: Parameter(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
>>> # example of batch execution
>>> def exec_module(params, x):
...     with params.to_module(empty_module):
...         return empty_module(x)
>>> x = torch.randn(3)
>>> y = torch.vmap(exec_module, (0, None))(params, x)
>>> assert y.shape == (n_models, 4)
>>> # since lazy_stack = False, backprop leaves the original params untouched
>>> y.sum().backward()
>>> assert params["weight"].grad.norm() > 0
>>> assert modules[0].weight.grad is None
With lazy_stack=True, things are slightly different:
>>> params = TensorDict.from_modules(*modules, lazy_stack=True)
>>> print(params)
LazyStackedTensorDict(
    fields={
        bias: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        weight: Tensor(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    exclusive_fields={
    },
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False,
    stack_dim=0)
>>> # example of batch execution
>>> y = torch.vmap(exec_module, (0, None))(params, x)
>>> assert y.shape == (n_models, 4)
>>> y.sum().backward()
>>> assert modules[0].weight.grad is not None
- from_namedtuple(*, auto_batch_size: bool = False, batch_dims: int | None = None, device: torch.device | None = None, batch_size: torch.Size | None = None)¶
Converts a namedtuple to a TensorDict recursively.
- Parameters:
named_tuple – The namedtuple instance to be converted.
- Keyword Arguments:
auto_batch_size (bool, optional) – if True, the batch size will be computed automatically. Defaults to False.
batch_dims (int, optional) – If auto_batch_size is True, defines how many dimensions the output tensordict should have. Defaults to None (full batch-size at each level).
device (torch.device, optional) – The device on which the TensorDict will be created. Defaults to None.
batch_size (torch.Size, optional) – The batch size of the TensorDict. Defaults to None.
- Returns:
A TensorDict representation of the input namedtuple.
Examples
>>> from tensordict import TensorDict
>>> import torch
>>> data = TensorDict({
...     "a_tensor": torch.zeros((3)),
...     "nested": {"a_tensor": torch.zeros((3)), "a_string": "zero!"}}, [3])
>>> nt = data.to_namedtuple()
>>> print(nt)
GenericDict(a_tensor=tensor([0., 0., 0.]), nested=GenericDict(a_tensor=tensor([0., 0., 0.]), a_string='zero!'))
>>> TensorDict.from_namedtuple(nt, auto_batch_size=True)
TensorDict(
    fields={
        a_tensor: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a_string: NonTensorData(data=zero!, batch_size=torch.Size([3]), device=None),
                a_tensor: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
- from_pytree(*, batch_size: torch.Size | None = None, auto_batch_size: bool = False, batch_dims: int | None = None)¶
Converts a pytree to a TensorDict instance.
This method is designed to keep the pytree nested structure as much as possible.
Additional non-tensor keys are added to keep track of each level’s identity, providing a built-in pytree-to-tensordict bijective transform API.
Accepted classes currently include lists, tuples, named tuples and dict.
Note
For dictionaries, non-NestedKey keys are registered separately as NonTensorData instances.
Note
Tensor-castable types (such as int, float or np.ndarray) will be converted to torch.Tensor instances. Note that this transformation is surjective: transforming back the tensordict to a pytree will not recover the original types.
Examples
>>> # Create a pytree with tensor leaves, and one "weird"-looking dict key
>>> class WeirdLookingClass:
...     pass
...
>>> weird_key = WeirdLookingClass()
>>> # Make a pytree with tuple, lists, dict and namedtuple
>>> pytree = (
...     [torch.randint(10, (3,)), torch.zeros(2)],
...     {
...         "tensor": torch.randn(2,),
...         "td": TensorDict({"one": 1}),
...         weird_key: torch.randint(10, (2,)),
...         "list": [1, 2, 3],
...     },
...     {"named_tuple": TensorDict({"two": torch.ones(1) * 2}).to_namedtuple()},
... )
>>> # Build a TensorDict from that pytree
>>> td = TensorDict.from_pytree(pytree)
>>> # Recover the pytree
>>> pytree_recon = td.to_pytree()
>>> # Check that the leaves match
>>> def check(v1, v2):
...     assert (v1 == v2).all()
>>> torch.utils._pytree.tree_map(check, pytree, pytree_recon)
>>> assert weird_key in pytree_recon[1]
- from_remote_init(group: 'ProcessGroup' | None = None, device: torch.device | None = None) Self ¶
Creates a new tensordict instance initialized from remotely sent metadata.
This class method receives the metadata sent by init_remote, creates a new tensordict with matching shape and dtype, and then asynchronously receives the actual tensordict content.
- Parameters:
src (int) – The rank of the source process that sent the metadata.
group ("ProcessGroup", optional) – The process group to use for communication. Defaults to None.
device (torch.device, optional) – The device to use for tensor operations. Defaults to None.
- Returns:
A new tensordict instance initialized with the received metadata and content.
- Return type:
TensorDict
See also
The sending process should have called init_remote() to send the metadata and content.
- from_struct_array(*, auto_batch_size: bool = False, batch_dims: int | None = None, device: torch.device | None = None, batch_size: torch.Size | None = None) Self ¶
Converts a structured numpy array to a TensorDict.
The resulting TensorDict will share the same memory content as the numpy array (it is a zero-copy operation). Changing values of the structured numpy array in-place will affect the content of the TensorDict.
Note
This method performs a zero-copy operation, meaning that the resulting TensorDict will share the same memory content as the input numpy array. Therefore, changing values of the numpy array in-place will affect the content of the TensorDict.
- Parameters:
struct_array (np.ndarray) – The structured numpy array to be converted.
- Keyword Arguments:
auto_batch_size (bool, optional) – If True, the batch size will be computed automatically. Defaults to False.
batch_dims (int, optional) – If auto_batch_size is True, defines how many dimensions the output tensordict should have. Defaults to None (full batch-size at each level).
device (torch.device, optional) – The device on which the TensorDict will be created. Defaults to None.
Note
Changing the device (i.e., specifying any device other than None or "cpu") will transfer the data, resulting in a change to the memory location of the returned data.
batch_size (torch.Size, optional) – The batch size of the TensorDict. Defaults to None.
- Returns:
A TensorDict representation of the input structured numpy array.
Examples
>>> x = np.array(
...     [("Rex", 9, 81.0), ("Fido", 3, 27.0)],
...     dtype=[("name", "U10"), ("age", "i4"), ("weight", "f4")],
... )
>>> td = TensorDict.from_struct_array(x)
>>> x_recon = td.to_struct_array()
>>> assert (x_recon == x).all()
>>> assert x_recon.shape == x.shape
>>> # Try modifying x age field and check effect on td
>>> x["age"] += 1
>>> assert (td["age"] == np.array([10, 4])).all()
- classmethod from_tensordict(tensordict: TensorDictBase, non_tensordict: dict | None = None, safe: bool = True) Self ¶
Tensor class wrapper to instantiate a new tensor class object.
- Parameters:
tensordict (TensorDictBase) – Dictionary of tensor types
non_tensordict (dict) – Dictionary with non-tensor and nested tensor class objects
safe (bool) – Whether to raise an error if the tensordict is not a TensorDictBase instance
- from_tuple(*, auto_batch_size: bool = False, batch_dims: int | None = None, device: torch.device | None = None, batch_size: torch.Size | None = None)¶
Converts a tuple to a TensorDict.
- Parameters:
obj – The tuple instance to be converted.
- Keyword Arguments:
auto_batch_size (bool, optional) – If True, the batch size will be computed automatically. Defaults to False.
batch_dims (int, optional) – If auto_batch_size is True, defines how many dimensions the output tensordict should have. Defaults to None (full batch-size at each level).
device (torch.device, optional) – The device on which the TensorDict will be created. Defaults to None.
batch_size (torch.Size, optional) – The batch size of the TensorDict. Defaults to None.
- Returns:
A TensorDict representation of the input tuple.
Examples
>>> my_tuple = (1, 2, 3)
>>> td = TensorDict.from_tuple(my_tuple)
>>> print(td)
TensorDict(
    fields={
        0: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        2: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
- fromkeys(value: Any = 0)¶
Creates a tensordict from a list of keys and a single value.
- Parameters:
keys (list of NestedKey) – An iterable specifying the keys of the new dictionary.
value (compatible type, optional) – The value for all keys. Defaults to 0.
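A minimal usage sketch, mirroring the call shown in the load_memmap() example below:
>>> from tensordict import TensorDict
>>> td = TensorDict.fromkeys(["a", ("nested", "b")], 0)
>>> assert (td["a"] == 0).all()
>>> assert (td["nested", "b"] == 0).all()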
- property full_action_spec¶
The action spec of the tree.
This is an alias for Tree.specs["input_spec", "full_action_spec"].
- property full_done_spec¶
The done spec of the tree.
This is an alias for Tree.specs["output_spec", "full_done_spec"].
- property full_observation_spec¶
The observation spec of the tree.
This is an alias for Tree.specs["output_spec", "full_observation_spec"].
- property full_reward_spec¶
The reward spec of the tree.
This is an alias for Tree.specs["output_spec", "full_reward_spec"].
- property full_state_spec¶
The state spec of the tree.
This is an alias for Tree.specs["input_spec", "full_state_spec"].
- fully_expanded(env: EnvBase) bool [source]¶
Returns True if the number of children is equal to the environment cardinality.
- get(key: NestedKey, *args, **kwargs)¶
Gets the value stored with the input key.
- Parameters:
key (str, tuple of str) – key to be queried. If tuple of str it is equivalent to chained calls of getattr.
default – default value if the key is not found in the tensorclass.
- Returns:
value stored with the input key
- get_vertex_by_hash(hash: int) Tree [source]¶
Goes through the tree and returns the node corresponding to the given hash.
- get_vertex_by_id(id: int) Tree [source]¶
Goes through the tree and returns the node corresponding to the given id.
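A minimal usage sketch (assuming `tree` is a populated Tree; node IDs are integers assigned when the tree is built, so the exact value below is illustrative):
>>> node = tree.get_vertex_by_id(0)            # look a node up by its ID
>>> same = tree.get_vertex_by_hash(node.hash)  # look the same node up by its hash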
- property is_terminal: bool | torch.Tensor¶
Returns True if the tree has no child nodes.
- lazy_stack(dim: int = 0, *, out=None, **kwargs)¶
Creates a lazy stack of tensordicts.
See lazy_stack() for details.
- load(*args, **kwargs) Self ¶
Loads a tensordict from disk.
This class method is a proxy to load_memmap().
- load_(prefix: str | Path, *args, **kwargs)¶
Loads a tensordict from disk within the current tensordict.
This class method is a proxy to load_memmap_().
- load_memmap(device: torch.device | None = None, non_blocking: bool = False, *, out: TensorDictBase | None = None, robust_key: bool | None = None) Self ¶
Loads a memory-mapped tensordict from disk.
- Parameters:
prefix (str or Path to folder) – the path to the folder where the saved tensordict should be fetched.
device (torch.device or equivalent, optional) – if provided, the data will be asynchronously cast to that device. Supports “meta” device, in which case the data isn’t loaded but a set of empty “meta” tensors are created. This is useful to get a sense of the total model size and structure without actually opening any file.
non_blocking (bool, optional) – if True, synchronize won't be called after loading tensors on device. Defaults to False.
out (TensorDictBase, optional) – optional tensordict where the data should be written.
robust_key (bool, optional) – if True, expects robust key encoding was used when saving and decodes filenames accordingly. If False, uses legacy behavior. If None (default), emits a deprecation warning and falls back to legacy behavior. Will default to True in v0.12.
Examples
>>> from tensordict import TensorDict
>>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0)
>>> td.memmap("./saved_td")
>>> td_load = TensorDict.load_memmap("./saved_td")
>>> assert (td == td_load).all()
This method also allows loading nested tensordicts.
Examples
>>> nested = TensorDict.load_memmap("./saved_td/nested")
>>> assert nested["e"] == 0
A tensordict can also be loaded on “meta” device or, alternatively, as a fake tensor.
Examples
>>> import tempfile
>>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}})
>>> with tempfile.TemporaryDirectory() as path:
...     td.save(path)
...     td_load = TensorDict.load_memmap(path, device="meta")
...     print("meta:", td_load)
...     from torch._subclasses import FakeTensorMode
...     with FakeTensorMode():
...         td_load = TensorDict.load_memmap(path)
...         print("fake:", td_load)
meta: TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=meta,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=meta,
    is_shared=False)
fake: TensorDict(
    fields={
        a: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
- load_state_dict(state_dict: dict[str, Any], strict=True, assign=False, from_flatten=False)¶
Attempts to load a state_dict in-place on the destination tensorclass.
- classmethod make_node(data: TensorDictBase, *, device: torch.device | None = None, batch_size: torch.Size | None = None, specs: Composite | None = None) Tree [source]¶
Creates a new node given some data.
- max_length()[source]¶
Returns the maximum length of all valid paths in the tree.
The length of a path is defined as the number of nodes in the path. If the tree is empty, returns 0.
- Returns:
The maximum length of all valid paths in the tree.
- Return type:
int
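A minimal usage sketch (assuming `tree` is a populated Tree, e.g., obtained from MCTSForest.get_tree()):
>>> depth = tree.max_length()
>>> print(depth)  # number of nodes on the longest valid path; 0 for an empty tree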
- maybe_dense_stack(dim: int = 0, *, out=None, **kwargs)¶
Attempts to make a dense stack of tensordicts, and falls back on a lazy stack when required.
See maybe_dense_stack() for details.
- memmap(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True, robust_key: bool | None = None) Self ¶
Writes all tensors onto a corresponding memory-mapped Tensor in a new tensordict.
- Parameters:
prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.
- Keyword Arguments:
num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.
return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict.
share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes and writing operations (such as inplace update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to False.
existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
robust_key (bool, optional) – if True, uses robust key encoding that safely handles keys with path separators and special characters. If False, uses legacy behavior (keys used as-is). If None (default), emits a deprecation warning and falls back to legacy behavior. Will default to True in v0.12.
The TensorDict is then locked, meaning that any writing operation that isn't in-place will throw an exception (e.g., rename, set or remove an entry). Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is not guaranteed anymore.
- Returns:
A new tensordict with the tensors stored on disk if return_early=False, otherwise a TensorDictFuture instance.
Note
Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.
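A minimal round-trip sketch (the directory path is illustrative):
>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.zeros(3), "nested": {"b": torch.ones(3)}}, batch_size=[3])
>>> td_mm = td.memmap("./saved_td")                 # writes one memory-mapped file per tensor
>>> td_load = TensorDict.load_memmap("./saved_td")
>>> assert (td_mm == td_load).all()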
- memmap_(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True, robust_key: bool | None = None) Self ¶
Writes all tensors onto a corresponding memory-mapped Tensor, in-place.
- Parameters:
prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.
- Keyword Arguments:
num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.
return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict. The resulting tensordict can be queried using future.result().
share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes and writing operations (such as inplace update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to False.
existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
robust_key (bool, optional) – if True, uses robust key encoding that safely handles keys with path separators and special characters. If False, uses legacy behavior (keys used as-is). If None (default), emits a deprecation warning and falls back to legacy behavior. Will default to True in v0.12.
The TensorDict is then locked, meaning that any writing operation that isn't in-place will throw an exception (e.g., rename, set or remove an entry). Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is not guaranteed anymore.
- Returns:
self if return_early=False, otherwise a TensorDictFuture instance.
Note
Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.
- memmap_like(prefix: str | None = None, copy_existing: bool = False, *, existsok: bool = True, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, robust_key: bool | None = None) Self ¶
Creates a contentless Memory-mapped tensordict with the same shapes as the original one.
- Parameters:
prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.
- Keyword Arguments:
num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.
return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict.
share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes and writing operations (such as inplace update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to False.
existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
robust_key (bool, optional) – if True, uses robust key encoding that safely handles keys with path separators and special characters. If False, uses legacy behavior (keys used as-is). If None (default), emits a deprecation warning and falls back to legacy behavior. Will default to True in v0.12.
The TensorDict is then locked, meaning that any writing operation that isn't in-place will throw an exception (e.g., rename, set or remove an entry). Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is not guaranteed anymore.
- Returns:
A new TensorDict instance with data stored as memory-mapped tensors if return_early=False, otherwise a TensorDictFuture instance.
Note
This is the recommended method to write a set of large buffers on disk, as memmap_() will copy the information, which can be slow for large content.
Examples
>>> td = TensorDict({
...     "a": torch.zeros((3, 64, 64), dtype=torch.uint8),
...     "b": torch.zeros(1, dtype=torch.int64),
... }, batch_size=[]).expand(1_000_000)  # expand does not allocate new memory
>>> buffer = td.memmap_like("/path/to/dataset")
- memmap_refresh_()¶
Refreshes the content of the memory-mapped tensordict if it has a saved_path.
This method will raise an exception if no path is associated with it.
- property node_observation: torch.Tensor | TensorDictBase¶
Returns the observation associated with this particular node.
This is the observation (or bag of observations) that defines the node before a branching occurs. If the node contains a rollout attribute, the node observation is typically identical to the observation resulting from the last action undertaken, i.e., node.rollout[..., -1]["next", "observation"].
If more than one observation key is associated with the tree specs, a TensorDict instance is returned instead.
For a more consistent representation, see node_observations.
- property node_observations: torch.Tensor | TensorDictBase¶
Returns the observations associated with this particular node in a TensorDict format.
This is the observation (or bag of observations) that defines the node before a branching occurs. If the node contains a rollout attribute, the node observation is typically identical to the observation resulting from the last action undertaken, i.e., node.rollout[..., -1]["next", "observation"].
If more than one observation key is associated with the tree specs, a TensorDict instance is returned instead.
For a more consistent representation, see node_observations.
- property num_children: int¶
Number of children of this node.
Equates to the number of elements in the self.subtree stack.
- num_vertices(*, count_repeat: bool = False) int [source]¶
Returns the number of unique vertices in the Tree.
- Keyword Arguments:
count_repeat (bool, optional) – Determines whether to count repeated vertices. If False, counts each unique vertex only once. If True, counts vertices multiple times if they appear in different paths. Defaults to False.
- Returns:
The number of unique vertices in the Tree.
- Return type:
int
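A minimal usage sketch (assuming `tree` is a populated Tree):
>>> n_unique = tree.num_vertices()                   # each vertex counted once
>>> n_total = tree.num_vertices(count_repeat=True)   # vertices re-counted per path
>>> assert n_total >= n_unique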
- property parent: Tree | None¶
The parent of the node.
If the node has a parent and this object is still present in the python workspace, it will be returned by this property.
For re-branching trees, this property may return a stack of trees where every index of the stack corresponds to a different parent.
Note
The parent attribute will match in content but not in identity: the tensorclass object is reconstructed using the same tensors (i.e., tensors that point to the same memory locations).
- Returns:
A Tree containing the parent data, or None if the parent data is out of scope or the node is the root.
- plot(backend: str = 'plotly', figure: str = 'tree', info: list[str] = None, make_labels: Callable[[Any, ...], Any] | None = None)[source]¶
Plots a visualization of the tree using the specified backend and figure type.
- Parameters:
backend – The plotting backend to use. Currently only supports 'plotly'.
figure – The type of figure to plot. Can be either 'tree' or 'box'.
info – A list of additional information to include in the plot (not currently used).
make_labels – An optional function to generate custom labels for the plot.
- Raises:
NotImplementedError – If an unsupported backend or figure type is specified.
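A minimal usage sketch (assumes plotly is installed and `tree` is a populated Tree):
>>> tree.plot()              # defaults: backend="plotly", figure="tree"
>>> tree.plot(figure="box")  # alternative box figure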
- property prev_action: torch.Tensor | TensorDictBase | None¶
The action undertaken just before this node’s observation was generated.
- Returns:
a tensor, tensordict or None if the node has no parent.
See also
This will be equal to branching_action whenever the rollout data contains a single step.
See also
selected_actions: all actions associated with a given node (or observation) in the tree.
- rollout_from_path(path: tuple[int]) TensorDictBase | None [source]¶
Retrieves the rollout data along a given path in the tree.
The rollout data is concatenated along the last dimension (dim=-1) for each node in the path. If no rollout data is found along the path, returns None.
- Parameters:
path – A tuple of integers representing the path in the tree.
- Returns:
The concatenated rollout data along the path, or None if no data is found.
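A minimal usage sketch combining this method with valid_paths() (assuming a populated tree):
>>> for path in tree.valid_paths():
...     rollout = tree.rollout_from_path(path)
...     if rollout is not None:
...         print(path, rollout.shape)  # rollout steps are concatenated along dim=-1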
- save(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, robust_key: bool | None = None) Self ¶
Saves the tensordict to disk.
This function is a proxy to memmap().
- property selected_actions: torch.Tensor | TensorDictBase | None¶
Returns a tensor containing all the selected actions branching out from this node.
- set(key: NestedKey, value: Any, inplace: bool = False, non_blocking: bool = False)¶
Sets a new key-value pair.
- Parameters:
key (str, tuple of str) – name of the key to be set. If tuple of str it is equivalent to chained calls of getattr followed by a final setattr.
value (Any) – value to be stored in the tensorclass
inplace (bool, optional) – if True, set will tentatively try to update the value in-place. If False or if the key isn't present, the value will be simply written at its destination.
- Returns:
self
- stack(dim: int = 0, *, out=None)¶
Stacks tensordicts into a single tensordict along the given dimension.
This call is equivalent to calling torch.stack() but is compatible with torch.compile.
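A minimal sketch (assuming the sequence of tensordicts is passed as the first argument, as with torch.stack()):
>>> import torch
>>> from tensordict import TensorDict
>>> td0 = TensorDict({"a": torch.zeros(3)}, batch_size=[3])
>>> td1 = TensorDict({"a": torch.ones(3)}, batch_size=[3])
>>> td = TensorDict.stack([td0, td1], dim=0)
>>> assert td.batch_size == torch.Size([2, 3])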
- state_dict(destination=None, prefix='', keep_vars=False, flatten=False) dict[str, Any] ¶
Returns a state_dict dictionary that can be used to save and load data from a tensorclass.
- to_string(node_format_fn=<function Tree.<lambda>>)[source]¶
Generates a string representation of the tree.
This function can pull out information from each of the nodes in a tree, so it can be useful for debugging. The nodes are listed line-by-line. Each line contains the path to the node, followed by the string representation of that node generated with node_format_fn. Each line is indented according to the number of steps in the path required to get to the corresponding node.
- Parameters:
node_format_fn (Callable, optional) – User-defined function to generate a string for each node of the tree. The signature must be (Tree) -> Any, and the output must be convertible to a string. If this argument is not given, the generated string is the node's Tree.node_data attribute converted to a dict.
Examples
>>> from torchrl.data import MCTSForest
>>> from tensordict import TensorDict
>>> forest = MCTSForest()
>>> td_root = TensorDict({"observation": 0,})
>>> rollouts_data = [
...     # [(action, obs), ...]
...     [(3, 123), (1, 456)],
...     [(2, 359), (2, 3094)],
...     [(3, 123), (9, 392), (6, 989), (20, 809), (21, 847)],
...     [(1, 75)],
...     [(3, 123), (0, 948)],
...     [(2, 359), (2, 3094), (10, 68)],
...     [(2, 359), (2, 3094), (11, 9045)],
... ]
>>> for rollout_data in rollouts_data:
...     td = td_root.clone().unsqueeze(0)
...     for action, obs in rollout_data:
...         td = td.update(TensorDict({
...             "action": [action],
...             "next": TensorDict({"observation": [obs]}, [1]),
...         }, [1]))
...         forest.extend(td)
...         td = td["next"].clone()
...
>>> tree = forest.get_tree(td_root)
>>> print(tree.to_string())
(0,) {'observation': tensor(123)}
    (0, 0) {'observation': tensor(456)}
    (0, 1) {'observation': tensor(847)}
    (0, 2) {'observation': tensor(948)}
(1,) {'observation': tensor(3094)}
    (1, 0) {'observation': tensor(68)}
    (1, 1) {'observation': tensor(9045)}
(2,) {'observation': tensor(75)}
- to_tensordict(*, retain_none: bool | None = None) TensorDict ¶
Convert the tensorclass into a regular TensorDict.
Makes a copy of all entries. Memmap and shared memory tensors are converted to regular tensors.
- Parameters:
retain_none (bool) – if True, the None values will be written in the tensordict. Otherwise they will be discarded. Default: True.
- Returns:
A new TensorDict object containing the same values as the tensorclass.
- unbind(dim: int)¶
Returns a tuple of indexed tensorclass instances unbound along the indicated dimension.
Resulting tensorclass instances will share the storage of the initial tensorclass instance.
- valid_paths()[source]¶
Generates all valid paths in the tree.
A valid path is a sequence of child indices that starts at the root node and ends at a leaf node. Each path is represented as a tuple of integers, where each integer corresponds to the index of a child node.
- Yields:
tuple – A valid path in the tree.
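A minimal usage sketch (assuming a tree built as in the to_string() example above):
>>> for path in tree.valid_paths():
...     print(path)  # one tuple of child indices per root-to-leaf path, e.g. (0, 1)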
- vertices(*, key_type: Literal['id', 'hash', 'path'] = 'hash') dict[int | tuple[int], Tree] [source]¶
Returns a map containing the vertices of the Tree.
- Keyword Arguments:
key_type (Literal["id", "hash", "path"], optional) – Specifies the type of key to use for the vertices.
"id": use the vertex ID as the key.
"hash": use a hash of the vertex as the key.
"path": use the path to the vertex as the key. This may lead to a dictionary with a longer length than when "id" or "hash" are used, as the same node may be part of multiple trajectories.
Defaults to "hash".
- Returns:
A dictionary mapping keys to Tree vertices.
- Return type:
Dict[int | Tuple[int], Tree]
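A minimal usage sketch (assuming `tree` is a populated Tree):
>>> by_hash = tree.vertices()                  # default: keyed by node hash
>>> by_path = tree.vertices(key_type="path")   # keyed by root-to-node paths
>>> print(len(by_hash), len(by_path))          # "path" may yield more entries than "hash"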
- property visits: int | torch.Tensor¶
Returns the number of visits associated with this particular node.
This is an alias for the count attribute.