
Collectors and replay buffers interoperability

In the simplest scenario, where single transitions are sampled from the replay buffer, little attention needs to be paid to the way the collector is built: flattening the data after collection is a sufficient preprocessing step before populating the storage:

>>> from torchrl.data import LazyTensorStorage, ReplayBuffer
>>> # the collector and the buffer capacity N are assumed to be defined elsewhere
>>> memory = ReplayBuffer(
...     storage=LazyTensorStorage(N),
...     transform=lambda data: data.reshape(-1))
>>> for data in collector:
...     memory.extend(data)
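
Once populated, the buffer can be sampled as usual. A minimal sketch, where the batch size of 256 is an arbitrary choice:

>>> batch = memory.sample(256)  # a flat batch of 256 independent transitions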

If trajectory slices are to be sampled, the recommended approach is to build a multidimensional buffer and sample it with the SliceSampler class. One must ensure that the data passed to the buffer is properly shaped, with the time and batch dimensions clearly separated. In practice, the following configurations will work:

>>> from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
>>> from torchrl.data import SliceSampler
>>> from torchrl.envs import ParallelEnv
>>> # env, policy and make_env are assumed to be defined elsewhere
>>> # Single environment: no need for a multi-dimensional buffer
>>> memory = ReplayBuffer(
...     storage=LazyTensorStorage(N),
...     sampler=SliceSampler(num_slices=4, traj_key=("collector", "traj_ids"))
... )
>>> collector = SyncDataCollector(env, policy, frames_per_batch=N, total_frames=-1)
>>> for data in collector:
...     memory.extend(data)
>>> # Batched environments: a multi-dim buffer is required
>>> memory = ReplayBuffer(
...     storage=LazyTensorStorage(N, ndim=2),
...     sampler=SliceSampler(num_slices=4, traj_key=("collector", "traj_ids"))
... )
>>> env = ParallelEnv(4, make_env)
>>> collector = SyncDataCollector(env, policy, frames_per_batch=N, total_frames=-1)
>>> for data in collector:
...     memory.extend(data)
>>> # MultiSyncDataCollector + regular env: behaves like a ParallelEnv if cat_results="stack"
>>> memory = ReplayBuffer(
...     storage=LazyTensorStorage(N, ndim=2),
...     sampler=SliceSampler(num_slices=4, traj_key=("collector", "traj_ids"))
... )
>>> collector = MultiSyncDataCollector([make_env] * 4,
...     policy,
...     frames_per_batch=N,
...     total_frames=-1,
...     cat_results="stack")
>>> for data in collector:
...     memory.extend(data)
>>> # MultiSyncDataCollector + parallel env: the ndim must be adapted accordingly
>>> memory = ReplayBuffer(
...     storage=LazyTensorStorage(N, ndim=3),
...     sampler=SliceSampler(num_slices=4, traj_key=("collector", "traj_ids"))
... )
>>> collector = MultiSyncDataCollector([ParallelEnv(2, make_env)] * 4,
...     policy,
...     frames_per_batch=N,
...     total_frames=-1,
...     cat_results="stack")
>>> for data in collector:
...     memory.extend(data)
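
In all of the configurations above, sampling works the same way. A hedged sketch, assuming a sample batch size of 32 (any multiple of num_slices=4 works): consecutive elements of the returned batch belong to the same slice, so the flat sample can be rearranged into a [num_slices, slice_len] layout:

>>> sample = memory.sample(32)
>>> # each of the 4 slices spans 8 consecutive time steps
>>> sample = sample.reshape(4, -1)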

Using replay buffers that sample trajectories with MultiSyncDataCollector isn’t currently fully supported: data batches can come from any worker, and in most cases consecutive batches written to the buffer will not come from the same source, thereby interrupting the trajectories.

Helper functions

split_trajectories(rollout_tensordict, *[, ...])

A utility function for trajectory separation.
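
A hedged usage sketch, assuming split_trajectories is imported from torchrl.collectors.utils and applied to a flat collector batch:

>>> from torchrl.collectors.utils import split_trajectories
>>> for data in collector:
...     # reshape the flat batch into [n_trajs, max_traj_len], padding shorter trajectories
...     trajs = split_trajectories(data)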
