SliceSampler
- class torchrl.data.replay_buffers.SliceSampler(*, num_slices: int | None = None, slice_len: int | None = None, end_key: tensordict._nestedkey.NestedKey | None = None, traj_key: tensordict._nestedkey.NestedKey | None = None, ends: torch.Tensor | None = None, trajectories: torch.Tensor | None = None, cache_values: bool = False, truncated_key: tensordict._nestedkey.NestedKey | None = ('next', 'truncated'), strict_length: bool = True, compile: bool | dict = False, span: bool | int | tuple[bool | int, bool | int] = False, use_gpu: torch.device | bool = False)[source]
Samples slices of data along the first dimension, given start and stop signals.
This class samples sub-trajectories with replacement. For a version without replacement, see SliceSamplerWithoutReplacement.

Note

SliceSampler can be slow to retrieve the trajectory indices. To accelerate its execution, prefer using end_key over traj_key, and consider the following keyword arguments: compile, cache_values and use_gpu.
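For instance, a minimal sketch of a buffer wired with these acceleration options could look like the snippet below (the storage size, slice count and batch size are illustrative, and use_gpu=True assumes an accelerator is available):

    >>> from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage
    >>> from torchrl.data.replay_buffers.samplers import SliceSampler
    >>> sampler = SliceSampler(
    ...     num_slices=8,
    ...     cache_values=True,  # cache trajectory start/end signals (static data only)
    ...     compile=True,       # compile the bottleneck of sample()
    ...     use_gpu=True,       # locate trajectory starts on an accelerator
    ... )
    >>> rb = TensorDictReplayBuffer(
    ...     storage=LazyTensorStorage(100_000), sampler=sampler, batch_size=256
    ... )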
- Keyword Arguments:
  - num_slices (int) – the number of slices to be sampled. The batch-size must be greater than or equal to the num_slices argument. Exclusive with slice_len (see the sketch after this list).
  - slice_len (int) – the length of the slices to be sampled. The batch-size must be greater than or equal to the slice_len argument and divisible by it. Exclusive with num_slices.
  - end_key (NestedKey, optional) – the key indicating the end of a trajectory (or episode). Defaults to ("next", "done").
  - traj_key (NestedKey, optional) – the key indicating the trajectories. Defaults to "episode" (commonly used across datasets in TorchRL).
  - ends (torch.Tensor, optional) – a 1d boolean tensor containing the end-of-run signals. To be used whenever end_key or traj_key is expensive to get, or when this signal is readily available. Must be used with cache_values=True and cannot be used in conjunction with end_key or traj_key. If provided, it is assumed that the storage is at capacity and that, if the last element of the ends tensor is False, the same trajectory spans across the end and the beginning.
  - trajectories (torch.Tensor, optional) – a 1d integer tensor containing the run ids. To be used whenever end_key or traj_key is expensive to get, or when this signal is readily available. Must be used with cache_values=True and cannot be used in conjunction with end_key or traj_key. If provided, it is assumed that the storage is at capacity and that, if the last element of the trajectory tensor is identical to the first, the same trajectory spans across the end and the beginning.
  - cache_values (bool, optional) – to be used with static datasets. Will cache the start and end signals of the trajectories. This can safely be used even if the trajectory indices change during calls to extend, as this operation will erase the cache.

    Warning

    cache_values=True will not work if the sampler is used with a storage that is extended by another buffer. For instance:

        >>> buffer0 = ReplayBuffer(storage=storage,
        ...     sampler=SliceSampler(num_slices=8, cache_values=True),
        ...     writer=ImmutableWriter())
        >>> buffer1 = ReplayBuffer(storage=storage,
        ...     sampler=other_sampler)
        >>> # Wrong! Does not erase the cache from the sampler of buffer0
        >>> buffer1.extend(data)

    Warning

    cache_values=True will not work as expected if the buffer is shared between processes and one process is responsible for writing and one process for sampling, as erasing the cache can only be done locally.

  - truncated_key (NestedKey, optional) – If not None, this argument indicates where a truncated signal should be written in the output data. This is used to indicate to value estimators where the provided trajectory breaks. Defaults to ("next", "truncated"). This feature only works with TensorDictReplayBuffer instances (otherwise the truncated key is returned in the info dictionary returned by the sample() method).
  - strict_length (bool, optional) – if False, trajectories of length shorter than slice_len (or batch_size // num_slices) will be allowed to appear in the batch. If True, trajectories shorter than required will be filtered out. Be mindful that this can result in an effective batch_size shorter than the one asked for! Trajectories can be split using split_trajectories(). Defaults to True.
  - compile (bool or dict of kwargs, optional) – if True, the bottleneck of the sample() method will be compiled with torch.compile(). Keyword arguments can also be passed to torch.compile with this arg. Defaults to False.
  - span (bool, int, Tuple[bool | int, bool | int], optional) – if provided, the sampled trajectory will span across the left and/or the right. This means that possibly fewer elements will be provided than what was required. A boolean value means that at least one element will be sampled per trajectory. An integer i means that at least slice_len - i samples will be gathered for each sampled trajectory. Using tuples allows fine-grained control over the span on the left (beginning of the stored trajectory) and on the right (end of the stored trajectory).
  - use_gpu (bool or torch.device) – if True (or if a device is passed), an accelerator will be used to retrieve the indices of the trajectory starts. This can significantly accelerate the sampling when the buffer content is large. Defaults to False.
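As a rough sketch of how num_slices and slice_len relate to the batch size (the numbers below are illustrative): with a batch size of 320, asking for 10 slices yields slices of 32 elements, and asking for slices of length 32 yields 10 slices.

    >>> from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage
    >>> from torchrl.data.replay_buffers.samplers import SliceSampler
    >>> # 320 = 10 slices x 32 steps in both configurations below
    >>> sampler_a = SliceSampler(num_slices=10)  # slice length inferred as 320 // 10 = 32
    >>> sampler_b = SliceSampler(slice_len=32)   # number of slices inferred as 320 // 32 = 10
    >>> rb = TensorDictReplayBuffer(
    ...     storage=LazyTensorStorage(1_000), sampler=sampler_a, batch_size=320
    ... )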
Note
To recover the trajectory splits in the storage, SliceSampler will first attempt to find the traj_key entry in the storage. If it cannot be found, the end_key will be used to reconstruct the episodes.
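As a hedged sketch of the end_key fallback (shapes and key names follow the examples further down, but the exact data layout is illustrative): when the stored data carries no "episode" entry, the sampler reconstructs trajectories from the ("next", "done") flags.

    >>> import torch
    >>> from tensordict import TensorDict
    >>> from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage
    >>> from torchrl.data.replay_buffers.samplers import SliceSampler
    >>> done = torch.zeros(100, 1, dtype=torch.bool)
    >>> done[49] = done[99] = True  # two 50-step trajectories
    >>> data = TensorDict(
    ...     {"obs": torch.arange(100), ("next", "done"): done},
    ...     batch_size=[100],
    ... )
    >>> rb = TensorDictReplayBuffer(
    ...     storage=LazyTensorStorage(100),
    ...     sampler=SliceSampler(num_slices=4, end_key=("next", "done")),
    ...     batch_size=32,
    ... )
    >>> rb.extend(data)
    >>> sample = rb.sample()  # each of the 4 slices stays within a single trajectory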
Note

When using strict_length=False, it is recommended to use split_trajectories() to split the sampled trajectories. However, if two samples from the same episode are placed next to each other, this may produce incorrect results. To avoid this issue, consider one of these solutions:

- using a TensorDictReplayBuffer instance with the slice sampler

    >>> import torch
    >>> from tensordict import TensorDict
    >>> from torchrl.collectors.utils import split_trajectories
    >>> from torchrl.data import TensorDictReplayBuffer, ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement
    >>>
    >>> rb = TensorDictReplayBuffer(storage=LazyTensorStorage(max_size=1000),
    ...     sampler=SliceSampler(
    ...         slice_len=5, traj_key="episode", strict_length=False,
    ...     ))
    >>> ep_1 = TensorDict(
    ...     {"obs": torch.arange(100),
    ...     "episode": torch.zeros(100),},
    ...     batch_size=[100]
    ... )
    >>> ep_2 = TensorDict(
    ...     {"obs": torch.arange(4),
    ...     "episode": torch.ones(4),},
    ...     batch_size=[4]
    ... )
    >>> rb.extend(ep_1)
    >>> rb.extend(ep_2)
    >>>
    >>> s = rb.sample(50)
    >>> print(s)
    TensorDict(
        fields={
            episode: Tensor(shape=torch.Size([46]), device=cpu, dtype=torch.float32, is_shared=False),
            index: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.int64, is_shared=False),
            next: TensorDict(
                fields={
                    done: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    terminated: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    truncated: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                batch_size=torch.Size([46]),
                device=cpu,
                is_shared=False),
            obs: Tensor(shape=torch.Size([46]), device=cpu, dtype=torch.int64, is_shared=False)},
        batch_size=torch.Size([46]),
        device=cpu,
        is_shared=False)
    >>> t = split_trajectories(s, done_key="truncated")
    >>> print(t["obs"])
    tensor([[73, 74, 75, 76, 77],
            [ 0,  1,  2,  3,  0],
            [ 0,  1,  2,  3,  0],
            [41, 42, 43, 44, 45],
            [ 0,  1,  2,  3,  0],
            [67, 68, 69, 70, 71],
            [27, 28, 29, 30, 31],
            [80, 81, 82, 83, 84],
            [17, 18, 19, 20, 21],
            [ 0,  1,  2,  3,  0]])
    >>> print(t["episode"])
    tensor([[0., 0., 0., 0., 0.],
            [1., 1., 1., 1., 0.],
            [1., 1., 1., 1., 0.],
            [0., 0., 0., 0., 0.],
            [1., 1., 1., 1., 0.],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [1., 1., 1., 1., 0.]])
- using a SliceSamplerWithoutReplacement

    >>> import torch
    >>> from tensordict import TensorDict
    >>> from torchrl.collectors.utils import split_trajectories
    >>> from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement
    >>>
    >>> rb = ReplayBuffer(storage=LazyTensorStorage(max_size=1000),
    ...     sampler=SliceSamplerWithoutReplacement(
    ...         slice_len=5, traj_key="episode", strict_length=False
    ...     ))
    >>> ep_1 = TensorDict(
    ...     {"obs": torch.arange(100),
    ...     "episode": torch.zeros(100),},
    ...     batch_size=[100]
    ... )
    >>> ep_2 = TensorDict(
    ...     {"obs": torch.arange(4),
    ...     "episode": torch.ones(4),},
    ...     batch_size=[4]
    ... )
    >>> rb.extend(ep_1)
    >>> rb.extend(ep_2)
    >>>
    >>> s = rb.sample(50)
    >>> t = split_trajectories(s, trajectory_key="episode")
    >>> print(t["obs"])
    tensor([[75, 76, 77, 78, 79],
            [ 0,  1,  2,  3,  0]])
    >>> print(t["episode"])
    tensor([[0., 0., 0., 0., 0.],
            [1., 1., 1., 1., 0.]])
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
>>> from torchrl.data.replay_buffers.samplers import SliceSampler
>>> torch.manual_seed(0)
>>> rb = TensorDictReplayBuffer(
...     storage=LazyMemmapStorage(1_000_000),
...     sampler=SliceSampler(cache_values=True, num_slices=10),
...     batch_size=320,
... )
>>> episode = torch.zeros(1000, dtype=torch.int)
>>> episode[:300] = 1
>>> episode[300:550] = 2
>>> episode[550:700] = 3
>>> episode[700:] = 4
>>> data = TensorDict(
...     {
...         "episode": episode,
...         "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
...         "act": torch.randn((20,)).expand(1000, 20),
...         "other": torch.randn((20, 50)).expand(1000, 20, 50),
...     }, [1000]
... )
>>> rb.extend(data)
>>> sample = rb.sample()
>>> print("sample:", sample)
>>> print("episodes", sample.get("episode").unique())
episodes tensor([1, 2, 3, 4], dtype=torch.int32)
SliceSampler is default-compatible with most of TorchRL’s datasets:

Examples
>>> import torch
>>>
>>> from torchrl.data.datasets import RobosetExperienceReplay
>>> from torchrl.data import SliceSampler
>>>
>>> torch.manual_seed(0)
>>> num_slices = 10
>>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
>>> data = RobosetExperienceReplay(dataid, batch_size=320, sampler=SliceSampler(num_slices=num_slices))
>>> for batch in data:
...     batch = batch.reshape(num_slices, -1)
...     break
>>> print("check that each batch only has one episode:", batch["episode"].unique(dim=1))
check that each batch only has one episode: tensor([[19],
        [14],
        [ 8],
        [10],
        [13],
        [ 4],
        [ 2],
        [ 3],
        [22],
        [ 8]])