Replay Buffers#
Replay buffers are a central part of off-policy RL algorithms. TorchRL provides an efficient implementation of a few, widely used replay buffers:
Core Replay Buffer Classes#
|
A generic, composable replay buffer class. |
|
An ensemble of replay buffers. |
|
Prioritized replay buffer. |
|
TensorDict-specific wrapper around the |
|
TensorDict-specific wrapper around the |
|
A Ray implementation of the Replay Buffer that can be extended and sampled remotely. |
|
A remote invocation friendly ReplayBuffer class. |
Composable Replay Buffers#
We also give users the ability to compose a replay buffer. We provide a wide panel of solutions for replay buffer usage, including support for almost any data type; storage in memory, on device or on physical memory; several sampling strategies; usage of transforms etc.
Supported data types and choosing a storage#
In theory, replay buffers support any data type but we can’t guarantee that each
component will support any data type. The most crude replay buffer implementation
is made of a ReplayBuffer base with a
ListStorage storage. This is very inefficient
but it will allow you to store complex data structures with non-tensor data.
Storages in contiguous memory include TensorStorage,
LazyTensorStorage and
LazyMemmapStorage.
Sampling and indexing#
Replay buffers can be indexed and sampled.
Indexing and sampling collect data at given indices in the storage and then process them
through a series of transforms and collate_fn that can be passed to the __init__
function of the replay buffer.
The full physical storage can be read with rb[:]. This is useful when all
stored items must be processed in storage order, for example to recompute value
targets after collection. read_all_in_order()
is an explicit equivalent to rb[:], and
write_all() is an explicit equivalent to
rb[:] = data. Passing end=... to these helpers updates only the leading
storage entries.
>>> from tensordict import TensorDict
>>> import torch
>>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
>>> rb = TensorDictReplayBuffer(storage=LazyTensorStorage(10))
>>> rb.extend(TensorDict({"obs": torch.arange(3)}, [3]))
tensor([0, 1, 2])
>>> data = rb.read_all_in_order()
>>> assert (data == rb[:]).all()
>>> data["target"] = data["obs"] + 1
>>> rb.write_all(data)
>>> assert (rb[:] == data).all()
TED-format conversion#
The following helpers convert between the TorchRL Episode Data (TED) layout and a flat, storage-friendly representation when serializing or restoring a buffer: