SliceSampler
- class torchrl.data.replay_buffers.SliceSampler(*args, **kwargs)
Samples slices of data along the first dimension, given start and stop signals.
This class samples sub-trajectories with replacement. For a version without replacement, see SliceSamplerWithoutReplacement. Equivalently, SliceSampler(replacement=False, ...) dispatches to SliceSamplerWithoutReplacement and forwards the remaining keyword arguments (including drop_last and shuffle).
Note
SliceSampler can be slow to retrieve the trajectory indices. To speed it up, prefer end_key over traj_key, and consider the following keyword arguments: compile, cache_values and use_gpu.
- Keyword Arguments:
  - replacement (bool, optional) – if False, the call is dispatched to SliceSamplerWithoutReplacement (which accepts the same keyword arguments as well as drop_last and shuffle). Defaults to True.
  - num_slices (int) – the number of slices to be sampled. The batch size must be greater than or equal to num_slices. Exclusive with slice_len.
  - slice_len (int) – the length of the slices to be sampled. The batch size must be greater than or equal to slice_len and divisible by it. Exclusive with num_slices.
  - end_key (NestedKey, optional) – the key indicating the end of a trajectory (or episode). Defaults to ("next", "done").
  - traj_key (NestedKey, optional) – the key indicating the trajectories. Defaults to "episode" (commonly used across datasets in TorchRL).
  - ends (torch.Tensor, optional) – a 1d boolean tensor containing the end-of-run signals. To be used whenever the end_key or traj_key is expensive to get, or when this signal is readily available. Must be used with cache_values=True and cannot be used in conjunction with end_key or traj_key. If provided, the storage is assumed to be at capacity, and if the last element of the ends tensor is False, the same trajectory is assumed to span across the end and the beginning of the storage.
  - trajectories (torch.Tensor, optional) – a 1d integer tensor containing the run ids. To be used whenever the end_key or traj_key is expensive to get, or when this signal is readily available. Must be used with cache_values=True and cannot be used in conjunction with end_key or traj_key. If provided, the storage is assumed to be at capacity, and if the last element of the trajectories tensor is identical to the first, the same trajectory is assumed to span across the end and the beginning of the storage.
  - cache_values (bool, optional) – to be used with static datasets. Caches the start and end signals of the trajectories. This can be safely used even if the trajectory indices change during calls to extend, as this operation erases the cache.
    Warning
    cache_values=True will not work if the sampler is used with a storage that is extended by another buffer. For instance:
    >>> buffer0 = ReplayBuffer(storage=storage,
    ...                        sampler=SliceSampler(num_slices=8, cache_values=True),
    ...                        writer=ImmutableWriter())
    >>> buffer1 = ReplayBuffer(storage=storage,
    ...                        sampler=other_sampler)
    >>> # Wrong! Does not erase the cache of buffer0's sampler
    >>> buffer1.extend(data)
    Warning
    cache_values=True will not work as expected if the buffer is shared between processes, with one process responsible for writing and another for sampling, as erasing the cache can only be done locally.
  - truncated_key (NestedKey, optional) – if not None, this argument indicates where a truncated signal should be written in the output data. This is used to indicate to value estimators where the provided trajectory breaks. Defaults to ("next", "truncated"). This feature only works with TensorDictReplayBuffer instances (otherwise the truncated key is returned in the info dictionary returned by the sample() method).
  - strict_length (bool, optional) – if False, trajectories shorter than slice_len (or batch_size // num_slices) will be allowed to appear in the batch. If True, trajectories shorter than required will be filtered out. Be mindful that this can result in an effective batch size smaller than the one requested! Trajectories can be split using split_trajectories(). Defaults to True.
  - pad_output (bool, optional) – discouraged; prefer the default (False). When True (and strict_length=False), short trajectories are padded by duplicating their last real timestep up to slice_len, so that the output's B * T is a fixed product. The output is still a 1D batch of shape [B * T]; the sample is not reshaped to [B, T]. A 1D boolean mask of shape [B * T] is written to ("collector", "mask"), flagging real (True) vs duplicated-last-step (False) positions. TorchRL's primitives (recurrent modules under set_recurrent_mode(), mask-aware loss modules, split_trajectories, etc.) are all designed to consume concatenated variable-length slices directly via the is_init/truncated markers the sampler already emits, so padding is a niche escape hatch for downstream code that genuinely cannot accept a ragged batch (e.g. a custom op that requires a fixed time dimension before a manual reshape). Combining pad_output=True with strict_length=True raises ValueError. Defaults to False.
  - compile (bool or dict of kwargs, optional) – if True, the bottleneck of the sample() method will be compiled with torch.compile(). Keyword arguments can also be passed to torch.compile with this arg. Defaults to False.
  - span (bool, int, Tuple[bool | int, bool | int], optional) – if provided, the sampled trajectory will span across the left and/or the right. This means that possibly fewer elements will be provided than what was required. A boolean value means that at least one element will be sampled per trajectory. An integer i means that at least slice_len - i samples will be gathered for each sampled trajectory. Using tuples allows fine-grained control over the span on the left (beginning of the stored trajectory) and on the right (end of the stored trajectory).
  - use_gpu (bool or torch.device) – if True (or if a device is passed), an accelerator will be used to retrieve the indices of the trajectory starts. This can significantly speed up sampling when the buffer content is large. Defaults to False.
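Two of the options above are easier to picture on small lists. The sketch below uses plain Python (no torch) and hypothetical helper names (ends_from_ids, bounds_from_ends, pad_slices). It is not TorchRL's implementation. It assumes what the ends and trajectories entries state: the storage is at capacity, so a run whose last index carries no end signal wraps around to index 0. It also mimics the documented pad_output behaviour. Short slices are padded by repeating their last real step, and a flat mask marks real vs padded positions.

```python
from typing import List, Tuple


def ends_from_ids(ids: List[int]) -> List[bool]:
    """Hypothetical helper: derive end-of-run flags from run ids.

    Index i ends a run when the next id (wrapping to index 0) differs, so a run
    whose last id equals the first id is treated as spanning end and beginning.
    """
    n = len(ids)
    return [ids[i] != ids[(i + 1) % n] for i in range(n)]


def bounds_from_ends(ends: List[bool]) -> List[Tuple[int, int]]:
    """Hypothetical helper: (start, stop) pairs, both inclusive, one per run.

    If ends[-1] is False, the trailing run wraps around, so its start index
    is larger than its stop index.
    """
    n = len(ends)
    stops = [i for i, is_end in enumerate(ends) if is_end]
    if not stops:
        raise ValueError("at least one end-of-run signal is required")
    # each run starts right after the previous stop (the first one wraps around)
    return [((stops[k - 1] + 1) % n, stop) for k, stop in enumerate(stops)]


def pad_slices(slices: List[List[int]], slice_len: int) -> Tuple[List[int], List[bool]]:
    """Hypothetical helper mimicking pad_output=True on a flat [B * T] output."""
    flat: List[int] = []
    mask: List[bool] = []
    for s in slices:
        if not s or len(s) > slice_len:
            raise ValueError("each slice must hold between 1 and slice_len steps")
        pad = slice_len - len(s)
        # duplicate the last real timestep until the slice reaches slice_len
        flat.extend(s + [s[-1]] * pad)
        mask.extend([True] * len(s) + [False] * pad)
    return flat, mask


# storage at capacity: run 7 occupies indices 5, 0 and 1 (it wraps around)
ends = ends_from_ids([7, 7, 3, 3, 3, 7])
print(ends)                    # [False, True, False, False, True, False]
print(bounds_from_ends(ends))  # [(5, 1), (2, 4)]
flat, mask = pad_slices([[73, 74, 75, 76, 77], [0, 1, 2, 3]], slice_len=5)
print(flat)  # [73, 74, 75, 76, 77, 0, 1, 2, 3, 3]
print(mask)  # [True, True, True, True, True, True, True, True, True, False]
```

The padded output stays flat (length B * T = 10 here), matching the documented shape [B * T] rather than [B, T].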
Note
To recover the trajectory splits in the storage, SliceSampler will first attempt to find the traj_key entry in the storage. If it cannot be found, the end_key will be used to reconstruct the episodes.
Note
When using a multi-process collector (MultiSyncCollector or MultiAsyncCollector) with a shared replay buffer, adjacent transitions in the buffer may come from different workers and different episodes. A SliceSampler that relies on end_key can then sample slices that straddle unrelated trajectories.
To avoid this, either:
- set trajs_per_batch on the collector so that only complete trajectories (each ending with done=True) are written to the buffer (use ndim=1 on the storage, since ndim >= 2 is incompatible with the variable-length flat sequences that trajs_per_batch produces), or
- set set_truncated=True on the collector so that every batch boundary carries a done signal (note: this introduces artificial truncations that value estimators must account for).
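The toy buffer below shows why interleaved writes break end_key-based splitting, and why set_truncated=True helps. Two workers fill it with batches of 3 steps, written back to back. The helpers are hypothetical, and the fixed, aligned batch layout is an assumption. The only point is this: splitting on done flags merges steps from both workers unless every batch boundary also carries a done.

```python
from typing import List


def split_on_done(done: List[bool]) -> List[List[int]]:
    """Group buffer indices into segments that end at each done flag."""
    segments: List[List[int]] = []
    current: List[int] = []
    for i, is_done in enumerate(done):
        current.append(i)
        if is_done:
            segments.append(current)
            current = []
    if current:  # trailing, unfinished segment
        segments.append(current)
    return segments


def mark_batch_ends(done: List[bool], batch_size: int) -> List[bool]:
    """Mimic set_truncated=True: the last step of each batch also carries done."""
    return [d or (i + 1) % batch_size == 0 for i, d in enumerate(done)]


workers = ["A", "A", "A", "B", "B", "B"]          # who wrote each buffer slot
done = [False, False, False, False, True, False]  # worker A's episode is still running

naive = split_on_done(done)
fixed = split_on_done(mark_batch_ends(done, batch_size=3))
print(naive)  # [[0, 1, 2, 3, 4], [5]]  <- the first segment mixes A and B
print(fixed)  # [[0, 1, 2], [3, 4], [5]]
```

The extra done at index 2 is exactly the artificial truncation the note warns about: worker A's episode did not really end there.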
Note
When using strict_length=False, it is recommended to use split_trajectories() to split the sampled trajectories. However, if two samples from the same episode are placed next to each other, this may produce incorrect results. To avoid this issue, consider one of these solutions:
- using a TensorDictReplayBuffer instance with the slice sampler
  >>> import torch
  >>> from tensordict import TensorDict
  >>> from torchrl.collectors.utils import split_trajectories
  >>> from torchrl.data import TensorDictReplayBuffer, ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement
  >>>
  >>> rb = TensorDictReplayBuffer(storage=LazyTensorStorage(max_size=1000),
  ...                   sampler=SliceSampler(
  ...                       slice_len=5, traj_key="episode", strict_length=False,
  ...                   ))
  ...
  >>> ep_1 = TensorDict(
  ...     {"obs": torch.arange(100),
  ...     "episode": torch.zeros(100),},
  ...     batch_size=[100]
  ... )
  >>> ep_2 = TensorDict(
  ...     {"obs": torch.arange(4),
  ...     "episode": torch.ones(4),},
  ...     batch_size=[4]
  ... )
  >>> rb.extend(ep_1)
  >>> rb.extend(ep_2)
  >>>
  >>> s = rb.sample(50)
  >>> print(s)
  TensorDict(
      fields={
          episode: Tensor(shape=torch.Size([46]), device=cpu, dtype=torch.float32, is_shared=False),
          index: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.int64, is_shared=False),
          next: TensorDict(
              fields={
                  done: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                  terminated: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                  truncated: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
              batch_size=torch.Size([46]),
              device=cpu,
              is_shared=False),
          obs: Tensor(shape=torch.Size([46]), device=cpu, dtype=torch.int64, is_shared=False)},
      batch_size=torch.Size([46]),
      device=cpu,
      is_shared=False)
  >>> t = split_trajectories(s, done_key="truncated")
  >>> print(t["obs"])
  tensor([[73, 74, 75, 76, 77],
          [ 0,  1,  2,  3,  0],
          [ 0,  1,  2,  3,  0],
          [41, 42, 43, 44, 45],
          [ 0,  1,  2,  3,  0],
          [67, 68, 69, 70, 71],
          [27, 28, 29, 30, 31],
          [80, 81, 82, 83, 84],
          [17, 18, 19, 20, 21],
          [ 0,  1,  2,  3,  0]])
  >>> print(t["episode"])
  tensor([[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 0.],
          [0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 0.]])
- using a SliceSamplerWithoutReplacement
  >>> import torch
  >>> from tensordict import TensorDict
  >>> from torchrl.collectors.utils import split_trajectories
  >>> from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement
  >>>
  >>> rb = ReplayBuffer(storage=LazyTensorStorage(max_size=1000),
  ...                   sampler=SliceSamplerWithoutReplacement(
  ...                       slice_len=5, traj_key="episode", strict_length=False
  ...                   ))
  ...
  >>> ep_1 = TensorDict(
  ...     {"obs": torch.arange(100),
  ...     "episode": torch.zeros(100),},
  ...     batch_size=[100]
  ... )
  >>> ep_2 = TensorDict(
  ...     {"obs": torch.arange(4),
  ...     "episode": torch.ones(4),},
  ...     batch_size=[4]
  ... )
  >>> rb.extend(ep_1)
  >>> rb.extend(ep_2)
  >>>
  >>> s = rb.sample(50)
  >>> t = split_trajectories(s, trajectory_key="episode")
  >>> print(t["obs"])
  tensor([[75, 76, 77, 78, 79],
          [ 0,  1,  2,  3,  0]])
  >>> print(t["episode"])
  tensor([[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 0.]])
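The failure mode in this note can be reproduced on plain lists. Below, two slices of the same episode land next to each other in a flat sample. The first helper splits wherever the episode id changes, a simplified stand-in for splitting on a trajectory key. It fuses the two slices into one 6-step trajectory. Splitting on a per-slice end marker, which is what passing done_key="truncated" relies on in the first solution, recovers both 3-step slices. Helper names and values are hypothetical.

```python
from typing import List


def split_by_id(ids: List[int]) -> List[List[int]]:
    """Start a new segment whenever the trajectory id changes."""
    if not ids:
        return []
    segments: List[List[int]] = [[0]]
    for i in range(1, len(ids)):
        if ids[i] != ids[i - 1]:
            segments.append([])
        segments[-1].append(i)
    return segments


def split_by_marker(marker: List[bool]) -> List[List[int]]:
    """Close a segment at every True marker (e.g. a per-slice truncated flag)."""
    segments: List[List[int]] = []
    current: List[int] = []
    for i, flag in enumerate(marker):
        current.append(i)
        if flag:
            segments.append(current)
            current = []
    if current:
        segments.append(current)
    return segments


obs = [10, 11, 12, 40, 41, 42]   # two slices drawn from the same episode
episode = [0, 0, 0, 0, 0, 0]
truncated = [False, False, True, False, False, True]  # last step of each slice

fused = [[obs[i] for i in seg] for seg in split_by_id(episode)]
split = [[obs[i] for i in seg] for seg in split_by_marker(truncated)]
print(fused)  # [[10, 11, 12, 40, 41, 42]]  <- the two slices are merged
print(split)  # [[10, 11, 12], [40, 41, 42]]
```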
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
>>> from torchrl.data.replay_buffers.samplers import SliceSampler
>>> torch.manual_seed(0)
>>> rb = TensorDictReplayBuffer(
...     storage=LazyMemmapStorage(1_000_000),
...     sampler=SliceSampler(cache_values=True, num_slices=10),
...     batch_size=320,
... )
>>> episode = torch.zeros(1000, dtype=torch.int)
>>> episode[:300] = 1
>>> episode[300:550] = 2
>>> episode[550:700] = 3
>>> episode[700:] = 4
>>> data = TensorDict(
...     {
...         "episode": episode,
...         "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
...         "act": torch.randn((20,)).expand(1000, 20),
...         "other": torch.randn((20, 50)).expand(1000, 20, 50),
...     }, [1000]
... )
>>> rb.extend(data)
>>> sample = rb.sample()
>>> print("sample:", sample)
>>> print("episodes", sample.get("episode").unique())
episodes tensor([1, 2, 3, 4], dtype=torch.int32)
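The example above enables cache_values=True. The toy classes below sketch the caching behaviour described for that argument and the pitfall in its first warning. The classes (ToySliceSampler, ToyBuffer) are hypothetical stand-ins, not TorchRL's ReplayBuffer or SliceSampler. A sampler computes trajectory bounds once and reuses them. Only an extend through its own buffer clears that cache, so an extend through a second buffer sharing the storage leaves the first sampler's cache stale.

```python
from typing import List, Optional, Tuple


class ToySliceSampler:
    """Hypothetical stand-in: caches trajectory bounds until told to reset."""

    def __init__(self) -> None:
        self._cache: Optional[List[Tuple[int, int]]] = None
        self.computations = 0  # counts how often bounds were actually recomputed

    def bounds(self, ids: List[int]) -> List[Tuple[int, int]]:
        if self._cache is None:
            self.computations += 1
            bounds, start = [], 0
            for i in range(1, len(ids) + 1):
                # close the current run at the end of storage or when the id changes
                if i == len(ids) or ids[i] != ids[start]:
                    bounds.append((start, i - 1))
                    start = i
            self._cache = bounds
        return self._cache

    def reset_cache(self) -> None:
        self._cache = None


class ToyBuffer:
    """Hypothetical buffer: extending through it clears its own sampler's cache."""

    def __init__(self, storage: List[int], sampler: ToySliceSampler) -> None:
        self.storage = storage
        self.sampler = sampler

    def extend(self, ids: List[int]) -> None:
        self.storage.extend(ids)
        self.sampler.reset_cache()

    def bounds(self) -> List[Tuple[int, int]]:
        return self.sampler.bounds(self.storage)


storage: List[int] = []
buffer0 = ToyBuffer(storage, ToySliceSampler())
buffer1 = ToyBuffer(storage, ToySliceSampler())  # shares the same storage

buffer0.extend([1, 1, 2])
print(buffer0.bounds())  # [(0, 1), (2, 2)]
buffer1.extend([2, 3])   # only buffer1's sampler cache is cleared
print(buffer0.bounds())  # [(0, 1), (2, 2)]  <- stale: episode 3 is missing
buffer0.extend([3])      # extending through buffer0 clears its cache
print(buffer0.bounds())  # [(0, 1), (2, 3), (4, 5)]
```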
SliceSampler is default-compatible with most of TorchRL's datasets:
Examples
>>> import torch
>>>
>>> from torchrl.data.datasets import RobosetExperienceReplay
>>> from torchrl.data import SliceSampler
>>>
>>> torch.manual_seed(0)
>>> num_slices = 10
>>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
>>> data = RobosetExperienceReplay(dataid, batch_size=320, sampler=SliceSampler(num_slices=num_slices))
>>> for batch in data:
...     batch = batch.reshape(num_slices, -1)
...     break
>>> print("check that each batch only has one episode:", batch["episode"].unique(dim=1))
check that each batch only has one episode: tensor([[19],
        [14],
        [ 8],
        [10],
        [13],
        [ 4],
        [ 2],
        [ 3],
        [22],
        [ 8]])
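With batch_size=320 and num_slices=10, the reshape(num_slices, -1) call in this example yields rows of 320 // 10 = 32 consecutive steps. That is why each row should carry a single episode id. The hypothetical helper below spells out the documented constraints. num_slices and slice_len are mutually exclusive, the batch size must be at least as large as either, and it must be divisible by slice_len. A second helper runs the same one-episode-per-row check on a plain list. This is an illustration, not TorchRL code.

```python
from typing import List, Optional, Tuple


def resolve_slices(
    batch_size: int, num_slices: Optional[int] = None, slice_len: Optional[int] = None
) -> Tuple[int, int]:
    """Hypothetical helper: return (num_slices, slice_len) for a batch size."""
    if (num_slices is None) == (slice_len is None):
        raise ValueError("pass exactly one of num_slices and slice_len")
    if num_slices is not None:
        if batch_size < num_slices:
            raise ValueError("batch_size must be >= num_slices")
        return num_slices, batch_size // num_slices
    if batch_size < slice_len or batch_size % slice_len:
        raise ValueError("batch_size must be >= slice_len and divisible by it")
    return batch_size // slice_len, slice_len


def one_episode_per_row(episode_ids: List[int], slice_len: int) -> bool:
    """Cut a flat sample into rows of slice_len and check each row's episode id."""
    rows = [episode_ids[i:i + slice_len] for i in range(0, len(episode_ids), slice_len)]
    return all(len(set(row)) == 1 for row in rows)


print(resolve_slices(320, num_slices=10))  # (10, 32)
print(resolve_slices(320, slice_len=32))   # (10, 32)
# toy flat sample: 10 slices of 32 steps, slice k drawn from episode (k % 4) + 1
flat_ids = [(k % 4) + 1 for k in range(10) for _ in range(32)]
print(one_episode_per_row(flat_ids, slice_len=32))  # True
```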