.. currentmodule:: torchrl.collectors Collectors and Replay Buffers ============================= Collectors and replay buffers interoperability ---------------------------------------------- In the simplest scenario where single transitions have to be sampled from the replay buffer, little attention has to be given to the way the collector is built. Flattening the data after collection will be a sufficient preprocessing step before populating the storage: >>> memory = ReplayBuffer( ... storage=LazyTensorStorage(N), ... transform=lambda data: data.reshape(-1)) >>> for data in collector: ... memory.extend(data) If trajectory slices have to be collected, the recommended way to achieve this is to create a multidimensional buffer and sample using the :class:`~torchrl.data.replay_buffers.SliceSampler` sampler class. One must ensure that the data passed to the buffer is properly shaped, with the ``time`` and ``batch`` dimensions clearly separated. In practice, the following configurations will work: >>> # Single environment: no need for a multi-dimensional buffer >>> memory = ReplayBuffer( ... storage=LazyTensorStorage(N), ... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids")) ... ) >>> collector = SyncDataCollector(env, policy, frames_per_batch=N, total_frames=-1) >>> for data in collector: ... memory.extend(data) >>> # Batched environments: a multi-dim buffer is required >>> memory = ReplayBuffer( ... storage=LazyTensorStorage(N, ndim=2), ... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids")) ... ) >>> env = ParallelEnv(4, make_env) >>> collector = SyncDataCollector(env, policy, frames_per_batch=N, total_frames=-1) >>> for data in collector: ... memory.extend(data) >>> # MultiSyncDataCollector + regular env: behaves like a ParallelEnv if cat_results="stack" >>> memory = ReplayBuffer( ... storage=LazyTensorStorage(N, ndim=2), ... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids")) ... ) >>> collector = MultiSyncDataCollector([make_env] * 4, ... policy, ... frames_per_batch=N, ... total_frames=-1, ... cat_results="stack") >>> for data in collector: ... memory.extend(data) >>> # MultiSyncDataCollector + parallel env: the ndim must be adapted accordingly >>> memory = ReplayBuffer( ... storage=LazyTensorStorage(N, ndim=3), ... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids")) ... ) >>> collector = MultiSyncDataCollector([ParallelEnv(2, make_env)] * 4, ... policy, ... frames_per_batch=N, ... total_frames=-1, ... cat_results="stack") >>> for data in collector: ... memory.extend(data) Using replay buffers that sample trajectories with :class:`~torchrl.collectors.MultiSyncDataCollector` isn't currently fully supported as the data batches can come from any worker and in most cases consecutive batches written in the buffer won't come from the same source (thereby interrupting the trajectories). Helper functions ---------------- .. currentmodule:: torchrl.collectors.utils .. autosummary:: :toctree: generated/ :template: rl_template_fun.rst split_trajectories