Tree¶
- class torchrl.data.Tree(count: 'int | torch.Tensor' = None, wins: 'int | torch.Tensor' = None, index: 'torch.Tensor | None' = None, hash: 'int | None' = None, node_id: 'int | None' = None, rollout: 'TensorDict | None' = None, node_data: 'TensorDict | None' = None, subtree: "'Tree'" = None, _parent: 'weakref.ref | List[weakref.ref] | None' = None, specs: 'Composite | None' = None, *, batch_size, device=None, names=None)[source]¶
- property branching_action: torch.Tensor | TensorDictBase | None¶
Returns the action that branched out to this particular node.
- Returns:
a tensor, tensordict or None if the node has no parent.
See also
This will be equal to prev_action whenever the rollout data contains a single step.
See also
selected_actions: all actions associated with a given node (or observation) in the tree.
- dumps(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T¶
Saves the tensordict to disk.
This function is a proxy to memmap().
- edges() List[Tuple[int, int]][source]¶
Retrieves a list of edges in the tree.
Each edge is represented as a tuple of two node IDs: the parent node ID and the child node ID. The tree is traversed using Breadth-First Search (BFS) to ensure all edges are visited.
- Returns:
A list of tuples, where each tuple contains a parent node ID and a child node ID.
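The BFS traversal behind edges() can be illustrated with a minimal sketch. This is not torchrl's internal implementation: the `Node` class and `edges` function below are hypothetical stand-ins that show how (parent_id, child_id) pairs are enumerated breadth-first.

```python
from collections import deque

# Hypothetical minimal node type: a node ID plus a list of children.
class Node:
    def __init__(self, node_id, children=()):
        self.node_id = node_id
        self.children = list(children)

def edges(root):
    """Return (parent_id, child_id) tuples discovered in BFS order."""
    result = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for child in node.children:
            result.append((node.node_id, child.node_id))
            queue.append(child)
    return result

# A small tree: 0 -> {1, 2}, 1 -> {3}
root = Node(0, [Node(1, [Node(3)]), Node(2)])
print(edges(root))  # [(0, 1), (0, 2), (1, 3)]
```

BFS guarantees that edges closer to the root are listed before deeper ones, which is convenient when feeding the edge list to a plotting backend.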
- classmethod fields()¶
Return a tuple describing the fields of this dataclass.
Accepts a dataclass or an instance of one. Tuple elements are of type Field.
- classmethod from_tensordict(tensordict, non_tensordict=None, safe=True)¶
Tensor class wrapper to instantiate a new tensor class object.
- Parameters:
tensordict (TensorDict) – Dictionary of tensor types
non_tensordict (dict) – Dictionary with non-tensor and nested tensor class objects
- property full_action_spec¶
The action spec of the tree.
This is an alias for Tree.specs[‘input_spec’, ‘full_action_spec’].
- property full_done_spec¶
The done spec of the tree.
This is an alias for Tree.specs[‘output_spec’, ‘full_done_spec’].
- property full_observation_spec¶
The observation spec of the tree.
This is an alias for Tree.specs[‘output_spec’, ‘full_observation_spec’].
- property full_reward_spec¶
The reward spec of the tree.
This is an alias for Tree.specs[‘output_spec’, ‘full_reward_spec’].
- property full_state_spec¶
The state spec of the tree.
This is an alias for Tree.specs[‘input_spec’, ‘full_state_spec’].
- fully_expanded(env: EnvBase) bool[source]¶
Returns True if the number of children is equal to the environment cardinality.
- get(key: NestedKey, *args, **kwargs)¶
Gets the value stored with the input key.
- Parameters:
key (str, tuple of str) – key to be queried. If tuple of str it is equivalent to chained calls of getattr.
default – default value if the key is not found in the tensorclass.
- Returns:
value stored with the input key
- get_vertex_by_hash(hash: int) Tree[source]¶
Goes through the tree and returns the node corresponding to the given hash.
- get_vertex_by_id(id: int) Tree[source]¶
Goes through the tree and returns the node corresponding to the given id.
- property is_terminal: bool | torch.Tensor¶
Returns True if the tree has no child nodes.
- classmethod load(prefix: str | Path, *args, **kwargs) T¶
Loads a tensordict from disk.
This class method is a proxy to load_memmap().
- load_(prefix: str | Path, *args, **kwargs)¶
Loads a tensordict from disk within the current tensordict.
This class method is a proxy to load_memmap_().
- classmethod load_memmap(prefix: str | Path, device: torch.device | None = None, non_blocking: bool = False, *, out: TensorDictBase | None = None) T¶
Loads a memory-mapped tensordict from disk.
- Parameters:
prefix (str or Path to folder) – the path to the folder where the saved tensordict should be fetched.
device (torch.device or equivalent, optional) – if provided, the data will be asynchronously cast to that device. Supports “meta” device, in which case the data isn’t loaded but a set of empty “meta” tensors are created. This is useful to get a sense of the total model size and structure without actually opening any file.
non_blocking (bool, optional) – if True, synchronize won’t be called after loading tensors on device. Defaults to False.
out (TensorDictBase, optional) – optional tensordict where the data should be written.
Examples
>>> from tensordict import TensorDict
>>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0)
>>> td.memmap("./saved_td")
>>> td_load = TensorDict.load_memmap("./saved_td")
>>> assert (td == td_load).all()
This method also allows loading nested tensordicts.
Examples
>>> nested = TensorDict.load_memmap("./saved_td/nested")
>>> assert nested["e"] == 0
A tensordict can also be loaded on “meta” device or, alternatively, as a fake tensor.
Examples
>>> import tempfile
>>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}})
>>> with tempfile.TemporaryDirectory() as path:
...     td.save(path)
...     td_load = TensorDict.load_memmap(path, device="meta")
...     print("meta:", td_load)
...     from torch._subclasses import FakeTensorMode
...     with FakeTensorMode():
...         td_load = TensorDict.load_memmap(path)
...         print("fake:", td_load)
meta: TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=meta,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=meta,
    is_shared=False)
fake: TensorDict(
    fields={
        a: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
- load_state_dict(state_dict: dict[str, Any], strict=True, assign=False, from_flatten=False)¶
Attempts to load a state_dict in-place on the destination tensorclass.
- classmethod make_node(data: TensorDictBase, *, device: torch.device | None = None, batch_size: torch.Size | None = None, specs: Composite | None = None) Tree[source]¶
Creates a new node given some data.
- max_length()[source]¶
Returns the maximum length of all valid paths in the tree.
The length of a path is defined as the number of nodes in the path. If the tree is empty, returns 0.
- Returns:
The maximum length of all valid paths in the tree.
- Return type:
int
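The recursive idea behind max_length() can be sketched as follows. This is an illustrative stand-in, not torchrl's implementation: the tree is a hypothetical adjacency dict, a leaf contributes a path of length 1, and an internal node contributes 1 plus its deepest subtree.

```python
# Hypothetical adjacency representation: node -> list of children.
def max_length(children_of, node):
    """Length (in nodes) of the longest root-to-leaf path from `node`."""
    kids = children_of.get(node, [])
    if not kids:
        return 1  # a leaf is a path of one node
    return 1 + max(max_length(children_of, k) for k in kids)

tree = {0: [1, 2], 1: [3]}  # longest path: 0 -> 1 -> 3
print(max_length(tree, 0))  # 3
```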
- memmap(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) T¶
Writes all tensors onto a corresponding memory-mapped Tensor in a new tensordict.
- Parameters:
prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If
True, any existing Tensor will be copied to the new location.
- Keyword Arguments:
num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.
return_early (bool, optional) – if True and num_threads > 0, the method will return a future of the tensordict.
share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes and writing operations (such as inplace update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to False.
existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
The TensorDict is then locked, meaning that any writing operation that isn’t in-place will throw an exception (e.g., rename, set or remove an entry). Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is not guaranteed anymore.
- Returns:
A new tensordict with the tensors stored on disk if return_early=False, otherwise a TensorDictFuture instance.
Note
Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.
- memmap_(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) T¶
Writes all tensors onto a corresponding memory-mapped Tensor, in-place.
- Parameters:
prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If
True, any existing Tensor will be copied to the new location.
- Keyword Arguments:
num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.
return_early (bool, optional) – if True and num_threads > 0, the method will return a future of the tensordict. The resulting tensordict can be queried using future.result().
share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes and writing operations (such as inplace update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to False.
existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
The TensorDict is then locked, meaning that any writing operation that isn’t in-place will throw an exception (e.g., rename, set or remove an entry). Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is not guaranteed anymore.
- Returns:
self if return_early=False, otherwise a TensorDictFuture instance.
Note
Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.
- memmap_like(prefix: str | None = None, copy_existing: bool = False, *, existsok: bool = True, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T¶
Creates a contentless Memory-mapped tensordict with the same shapes as the original one.
- Parameters:
prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If
True, any existing Tensor will be copied to the new location.
- Keyword Arguments:
num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.
return_early (bool, optional) – if True and num_threads > 0, the method will return a future of the tensordict.
share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes and writing operations (such as inplace update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to False.
existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
The TensorDict is then locked, meaning that any writing operation that isn’t in-place will throw an exception (e.g., rename, set or remove an entry). Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is not guaranteed anymore.
- Returns:
A new TensorDict instance with data stored as memory-mapped tensors if return_early=False, otherwise a TensorDictFuture instance.
Note
This is the recommended method to write a set of large buffers on disk, as memmap_() will copy the information, which can be slow for large content.
Examples
>>> td = TensorDict({
...     "a": torch.zeros((3, 64, 64), dtype=torch.uint8),
...     "b": torch.zeros(1, dtype=torch.int64),
... }, batch_size=[]).expand(1_000_000)  # expand does not allocate new memory
>>> buffer = td.memmap_like("/path/to/dataset")
- memmap_refresh_()¶
Refreshes the content of the memory-mapped tensordict if it has a saved_path.
This method will raise an exception if no path is associated with it.
- property node_observation: torch.Tensor | TensorDictBase¶
Returns the observation associated with this particular node.
This is the observation (or bag of observations) that defines the node before a branching occurs. If the node contains a rollout() attribute, the node observation is typically identical to the observation resulting from the last action undertaken, i.e., node.rollout[..., -1]["next", "observation"].
If more than one observation key is associated with the tree specs, a TensorDict instance is returned instead.
For a more consistent representation, see node_observations.
- property node_observations: torch.Tensor | TensorDictBase¶
Returns the observations associated with this particular node in a TensorDict format.
This is the observation (or bag of observations) that defines the node before a branching occurs. If the node contains a rollout() attribute, the node observation is typically identical to the observation resulting from the last action undertaken, i.e., node.rollout[..., -1]["next", "observation"].
If more than one observation key is associated with the tree specs, a TensorDict instance is returned instead.
For a plainer representation, see node_observation.
- property num_children: int¶
Number of children of this node.
Equates to the number of elements in the self.subtree stack.
- num_vertices(*, count_repeat: bool = False) int[source]¶
Returns the number of unique vertices in the Tree.
- Keyword Arguments:
count_repeat (bool, optional) –
Determines whether to count repeated vertices.
If False, counts each unique vertex only once.
If True, counts vertices multiple times if they appear in different paths.
Defaults to False.
- Returns:
The number of unique vertices in the Tree.
- Return type:
int
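The effect of count_repeat can be sketched with plain Python. This is an illustrative stand-in, not torchrl's implementation: paths are hypothetical tuples of node identifiers (here strings), and a node shared between paths is counted once or once per appearance.

```python
# Hypothetical stand-in: each path is a tuple of node identifiers.
def num_vertices(paths, count_repeat=False):
    """Count vertices across paths, once each or once per appearance."""
    all_nodes = [node for path in paths for node in path]
    return len(all_nodes) if count_repeat else len(set(all_nodes))

# Two paths sharing the root "a" and the node "b".
paths = [("a", "b", "c"), ("a", "b", "d")]
print(num_vertices(paths))                     # 4 unique vertices
print(num_vertices(paths, count_repeat=True))  # 6 appearances
```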
- property parent: Tree | None¶
The parent of the node.
If the node has a parent and this object is still present in the python workspace, it will be returned by this property.
For re-branching trees, this property may return a stack of trees where every index of the stack corresponds to a different parent.
Note
the parent attribute will match in content but not in identity: the tensorclass object is reconstructed using the same tensors (i.e., tensors that point to the same memory locations).
- Returns:
A Tree containing the parent data, or None if the parent data is out of scope or the node is the root.
- plot(backend: str = 'plotly', figure: str = 'tree', info: List[str] = None, make_labels: Callable[[Any, ...], Any] | None = None)[source]¶
Plots a visualization of the tree using the specified backend and figure type.
- Parameters:
backend – The plotting backend to use. Currently only supports ‘plotly’.
figure – The type of figure to plot. Can be either ‘tree’ or ‘box’.
info – A list of additional information to include in the plot (not currently used).
make_labels – An optional function to generate custom labels for the plot.
- Raises:
NotImplementedError – If an unsupported backend or figure type is specified.
- property prev_action: torch.Tensor | TensorDictBase | None¶
The action undertaken just before this node’s observation was generated.
- Returns:
a tensor, tensordict or None if the node has no parent.
See also
This will be equal to branching_action whenever the rollout data contains a single step.
See also
selected_actions: all actions associated with a given node (or observation) in the tree.
- rollout_from_path(path: Tuple[int]) TensorDictBase | None[source]¶
Retrieves the rollout data along a given path in the tree.
The rollout data is concatenated along the last dimension (dim=-1) for each node in the path. If no rollout data is found along the path, returns None.
- Parameters:
path – A tuple of integers representing the path in the tree.
- Returns:
The concatenated rollout data along the path, or None if no data is found.
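The path-following logic can be sketched in plain Python. This is an illustrative stand-in, not torchrl's implementation: in torchrl the per-node data are TensorDicts concatenated along dim=-1, while here hypothetical dicts and lists stand in for nodes and rollouts.

```python
# Hypothetical node layout: {"rollout": [...], "children": [node, ...]}.
def rollout_from_path(root, path):
    """Follow `path` (child indices) from the root, collecting rollouts."""
    pieces = []
    node = root
    if node.get("rollout"):
        pieces.extend(node["rollout"])
    for idx in path:
        node = node["children"][idx]
        if node.get("rollout"):
            pieces.extend(node["rollout"])
    return pieces or None  # None when no rollout data was found

root = {
    "rollout": ["s0"],
    "children": [
        {"rollout": ["s1", "s2"], "children": []},
        {"rollout": ["s3"], "children": []},
    ],
}
print(rollout_from_path(root, (0,)))  # ['s0', 's1', 's2']
```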
- save(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T¶
Saves the tensordict to disk.
This function is a proxy to
memmap().
- property selected_actions: torch.Tensor | TensorDictBase | None¶
Returns a tensor containing all the selected actions branching out from this node.
- set(key: NestedKey, value: Any, inplace: bool = False, non_blocking: bool = False)¶
Sets a new key-value pair.
- Parameters:
key (str, tuple of str) – name of the key to be set. If tuple of str it is equivalent to chained calls of getattr followed by a final setattr.
value (Any) – value to be stored in the tensorclass
inplace (bool, optional) – if True, set will tentatively try to update the value in-place. If False or if the key isn’t present, the value will be simply written at its destination.
- Returns:
self
- state_dict(destination=None, prefix='', keep_vars=False, flatten=False) dict[str, Any]¶
Returns a state_dict dictionary that can be used to save and load data from a tensorclass.
- to_tensordict(*, retain_none: bool | None = None) TensorDict¶
Convert the tensorclass into a regular TensorDict.
Makes a copy of all entries. Memmap and shared memory tensors are converted to regular tensors.
- Parameters:
retain_none (bool) –
if True, the None values will be written in the tensordict; otherwise they will be discarded. Default: True.
Note
From v0.8, the default value will be switched to False.
- Returns:
A new TensorDict object containing the same values as the tensorclass.
- unbind(dim: int)¶
Returns a tuple of indexed tensorclass instances unbound along the indicated dimension.
Resulting tensorclass instances will share the storage of the initial tensorclass instance.
- valid_paths()[source]¶
Generates all valid paths in the tree.
A valid path is a sequence of child indices that starts at the root node and ends at a leaf node. Each path is represented as a tuple of integers, where each integer corresponds to the index of a child node.
- Yields:
tuple – A valid path in the tree.
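The description above can be sketched as a small recursive generator. This is an illustrative stand-in, not torchrl's implementation: nested lists stand in for the subtree stack, and each yielded tuple is a sequence of child indices from the root to a leaf.

```python
# Hypothetical stand-in: a node is a list of its children (empty = leaf).
def valid_paths(children, prefix=()):
    """Yield every root-to-leaf path as a tuple of child indices."""
    if not children:
        yield prefix  # reached a leaf
        return
    for i, child in enumerate(children):
        yield from valid_paths(child, prefix + (i,))

# root -> [child0 -> [leaf, leaf], child1 (leaf)]
tree = [[[], []], []]
print(list(valid_paths(tree)))  # [(0, 0), (0, 1), (1,)]
```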
- vertices(*, key_type: Literal['id', 'hash', 'path'] = 'hash') Dict[int | Tuple[int], Tree][source]¶
Returns a map containing the vertices of the Tree.
- Keyword Arguments:
key_type (Literal["id", "hash", "path"], optional) –
Specifies the type of key to use for the vertices.
”id”: Use the vertex ID as the key.
”hash”: Use a hash of the vertex as the key.
- ”path”: Use the path to the vertex as the key. This may lead to a dictionary with a longer length than when "id" or "hash" are used, as the same node may be part of multiple trajectories. Defaults to "hash".
- Returns:
A dictionary mapping keys to Tree vertices.
- Return type:
Dict[int | Tuple[int], Tree]
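Why "path" keys can yield a longer dictionary than "id" or "hash" keys can be shown with a sketch. This is an illustrative stand-in, not torchrl's implementation: plain dicts stand in for Tree vertices, and a node shared by two trajectories appears under two path keys but under a single id or hash key.

```python
# Hypothetical stand-in: (path, node) pairs, where a node is a dict
# carrying a "node_id" and a "hash" field.
def vertices(paths_to_nodes, key_type="hash"):
    """Index the vertices by path, id, or hash."""
    if key_type == "path":
        return dict(paths_to_nodes)  # one entry per path
    key = "node_id" if key_type == "id" else "hash"
    return {node[key]: node for _, node in paths_to_nodes}

shared = {"node_id": 1, "hash": 0xAB}
# The same node reached through two different trajectories.
paths_to_nodes = [((0,), shared), ((1, 0), shared)]
print(len(vertices(paths_to_nodes, key_type="path")))  # 2
print(len(vertices(paths_to_nodes, key_type="hash")))  # 1
```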
- property visits: int | torch.Tensor¶
Returns the number of visits associated with this particular node.
This is an alias for the count attribute.