dense_stack_tds
tensordict.dense_stack_tds(td_list: Union[Sequence[TensorDictBase], LazyStackedTensorDict], dim: Optional[int] = None)
Densely stack a list of TensorDictBase objects (or a LazyStackedTensorDict), provided that they all have the same structure.

This function is called with a list of TensorDictBase objects, either passed directly or obtained from a LazyStackedTensorDict. Calling torch.stack(td_list) would return a LazyStackedTensorDict. This function instead expands the first element of the input list to the stacked shape and stacks the whole input list into that expanded element. This works only when all the elements of the input list have the same structure. The returned TensorDictBase has the same type as the elements of the input list.

This function is useful when some of the TensorDictBase objects to be stacked are LazyStackedTensorDict instances, or contain LazyStackedTensorDict entries (possibly nested). In those cases, calling torch.stack(td_list).to_tensordict() is infeasible, so this function provides an alternative way to densely stack the provided list.

Parameters:
td_list (sequence of TensorDictBase, or LazyStackedTensorDict) – the tensordicts to stack.
dim (int, optional) – the dimension along which to stack them. If td_list is a LazyStackedTensorDict, it is retrieved automatically from its stack_dim.
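The expand-then-stack mechanism described above can be sketched with plain PyTorch tensors. This is not the tensordict implementation: nested Python dicts of tensors stand in for TensorDictBase objects, and `dense_stack_dicts` is a hypothetical helper written only for illustration.

```python
import torch


def dense_stack_dicts(items, dim):
    """Hypothetical helper (not part of tensordict): densely stack nested
    dicts of tensors that share one structure, using expand-then-stack."""
    first = items[0]
    for item in items[1:]:
        # Same structure: identical keys at this level (nested levels are
        # checked by the recursive call below).
        if item.keys() != first.keys():
            raise ValueError("all elements must have the same structure")
    out = {}
    for key, value in first.items():
        values = [item[key] for item in items]
        if isinstance(value, dict):
            # Nested entry: apply the same procedure one level down.
            out[key] = dense_stack_dicts(values, dim)
        else:
            # 1) Expand the first element to the final stacked shape and
            #    clone it so the buffer owns its own dense memory.
            shape = list(value.shape)
            shape.insert(dim, len(items))
            buffer = value.unsqueeze(dim).expand(*shape).clone()
            # 2) Stack every element into that buffer, overwriting it.
            torch.stack(values, dim=dim, out=buffer)
            out[key] = buffer
    return out


a = {"obs": torch.zeros(3), "nested": {"x": torch.zeros(2)}}
b = {"obs": torch.ones(3), "nested": {"x": torch.ones(2)}}
stacked = dense_stack_dicts([a, b], dim=0)
print(stacked["obs"].shape)  # torch.Size([2, 3])
```

Because torch.stack requires equal shapes, this sketch only covers the fully dense case; it rejects entries whose shapes differ between elements.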
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict import dense_stack_tds
>>> from tensordict.tensordict import assert_allclose_td
>>> td0 = TensorDict({"a": torch.zeros(3)}, [])
>>> td1 = TensorDict({"a": torch.zeros(4), "b": torch.zeros(2)}, [])
>>> td_lazy = torch.stack([td0, td1], dim=0)
>>> td_container = TensorDict({"lazy": td_lazy}, [])
>>> td_container_clone = td_container.clone()
>>> td_stack = torch.stack([td_container, td_container_clone], dim=0)
>>> td_stack
LazyStackedTensorDict(
    fields={
        lazy: LazyStackedTensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 2, -1]), device=cpu, dtype=torch.float32, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([2, 2]),
            device=None,
            is_shared=False,
            stack_dim=0)},
    exclusive_fields={
    },
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False,
    stack_dim=0)
>>> td_stack = dense_stack_tds(td_stack)  # Automatically use the LazyStackedTensorDict stack_dim
>>> td_stack
TensorDict(
    fields={
        lazy: LazyStackedTensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 2, -1]), device=cpu, dtype=torch.float32, is_shared=False)},
            exclusive_fields={
                1 -> b: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2, 2]),
            device=None,
            is_shared=False,
            stack_dim=1)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
>>> # Note that
>>> # (1) td_stack is now a TensorDict
>>> # (2) this has pushed the stack_dim of "lazy" (0 -> 1)
>>> # (3) this has revealed the exclusive keys.
>>> assert_allclose_td(td_stack, dense_stack_tds([td_container, td_container_clone], dim=0))
>>> # This shows that passing a list or a LazyStackedTensorDict gives the same result