.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "tutorials/rb_tutorial.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note :ref:`Go to the end ` to download the full example code. .. rst-class:: sphx-glr-example-title .. _sphx_glr_tutorials_rb_tutorial.py: Using Replay Buffers ==================== **Author**: `Vincent Moens `_ .. _rb_tuto: .. GENERATED FROM PYTHON SOURCE LINES 11-57 Replay buffers are a central piece of any RL or control algorithm. Supervised learning methods are usually characterized by a training loop where data is randomly pulled from a static dataset and fed successively to the model and loss function. In RL, things are often slightly different: the data is gathered using the model, then temporarily stored in a dynamic structure (the experience replay buffer), which serves as a dataset for the loss module. As always, the context in which the buffer is used strongly conditions how it is built: some may wish to store trajectories, while others will want to store single transitions. Specific sampling strategies may be preferable in certain contexts: some items can have a higher priority than others, or it can be important to sample with or without replacement. Computational factors may also come into play, such as the size of the buffer, which may exceed the available RAM. For these reasons, TorchRL's replay buffers are fully composable: although they come with "batteries included", requiring minimal effort to build, they also support many customizations, such as the storage type, sampling strategy, or data transforms. In this tutorial, you will learn: - How to build a :ref:`Replay Buffer (RB) ` and use it with any datatype; - How to customize the :ref:`buffer's storage `; - How to use :ref:`RBs with TensorDict `; - How to :ref:`sample from or iterate over a replay buffer `, and how to define the sampling strategy; - How to use :ref:`prioritized replay buffers `; - How to :ref:`transform data ` coming in and out of the buffer; - How to store :ref:`trajectories ` in the buffer. Basics: building a vanilla replay buffer ---------------------------------------- .. _tuto_rb_vanilla: TorchRL's replay buffers are designed to prioritize modularity, composability, efficiency, and simplicity. For instance, creating a basic replay buffer is a straightforward process, as shown in the following example: .. GENERATED FROM PYTHON SOURCE LINES 57-66 .. code-block:: Python import tempfile from torchrl.data import ReplayBuffer buffer = ReplayBuffer() .. GENERATED FROM PYTHON SOURCE LINES 87-91 By default, this replay buffer will have a size of 1000. Let's check this by populating our buffer using the :meth:`~torchrl.data.ReplayBuffer.extend` method: .. GENERATED FROM PYTHON SOURCE LINES 91-98 .. code-block:: Python print("length before adding elements:", len(buffer)) buffer.extend(range(2000)) print("length after adding elements:", len(buffer)) .. rst-class:: sphx-glr-script-out .. code-block:: none length before adding elements: 0 length after adding elements: 1000 .. GENERATED FROM PYTHON SOURCE LINES 99-146 We used the :meth:`~torchrl.data.ReplayBuffer.extend` method, which is designed to add multiple items at once. If the object that is passed to ``extend`` has more than one dimension, its first dimension is the one that is split into separate elements in the buffer.
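For instance, the following minimal sketch (not part of the generated example; the buffer name is arbitrary and it reuses the ``ReplayBuffer`` import from above) shows that extending with a two-dimensional tensor adds one element per entry along the first dimension, whereas :meth:`~torchrl.data.ReplayBuffer.add` inserts a single element:

.. code-block:: Python

    import torch

    # A small demonstration buffer, separate from the one built above.
    rb_demo = ReplayBuffer()           # default storage, maximum size 1000
    rb_demo.extend(torch.zeros(5, 3))  # split along dim 0: five elements of shape [3]
    print(len(rb_demo))                # 5
    rb_demo.add(torch.zeros(3))        # add() inserts a single element
    print(len(rb_demo))                # 6
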
In other words, when adding multidimensional tensors or tensordicts to the buffer, the buffer only looks at the first dimension when counting the elements it holds in memory. If the object passed is not iterable, an exception will be thrown. To add items one at a time, the :meth:`~torchrl.data.ReplayBuffer.add` method should be used instead. Customizing the storage ----------------------- .. _tuto_rb_storage: We see that the buffer has been capped to the first 1000 elements that we passed to it. To change the size, we need to customize our storage. TorchRL offers three types of storage: - The :class:`~torchrl.data.ListStorage` stores elements independently in a list. It supports any data type, but this flexibility comes at the cost of efficiency; - The :class:`~torchrl.data.LazyTensorStorage` stores tensor data structures contiguously. It works naturally with :class:`~tensordict.TensorDict` (or :class:`~torchrl.data.tensorclass`) objects. The storage is contiguous on a per-tensor basis, meaning that sampling will be more efficient than when using a list, but the implicit restriction is that any data passed to it must have the same basic properties (such as shape and dtype) as the first batch of data that was used to instantiate the buffer. Passing data that does not match this requirement will either raise an exception or lead to undefined behaviour. - The :class:`~torchrl.data.LazyMemmapStorage` works like the :class:`~torchrl.data.LazyTensorStorage` in that it is lazy (i.e., the storage is only instantiated when the first batch of data arrives), and it requires data that match in shape and dtype for each stored batch. What makes this storage unique is that it points to disk files (or uses the filesystem storage), meaning that it can support very large datasets while still accessing data in a contiguous manner. Let us see how we can use each of these storages: .. GENERATED FROM PYTHON SOURCE LINES 146-153 .. code-block:: Python from torchrl.data import LazyMemmapStorage, LazyTensorStorage, ListStorage # We define the maximum size of the buffer size = 100 .. GENERATED FROM PYTHON SOURCE LINES 154-156 A buffer with a list storage can store any kind of data (but we must change the ``collate_fn`` since the default expects numerical data): .. GENERATED FROM PYTHON SOURCE LINES 156-160 .. code-block:: Python buffer_list = ReplayBuffer(storage=ListStorage(size), collate_fn=lambda x: x) buffer_list.extend(["a", 0, "b"]) print(buffer_list.sample(3)) .. rst-class:: sphx-glr-script-out .. code-block:: none ['a', 0, 'a'] .. GENERATED FROM PYTHON SOURCE LINES 161-167 Because it is the storage that makes the fewest assumptions about the data, the :class:`~torchrl.data.ListStorage` is the default storage in TorchRL. A :class:`~torchrl.data.LazyTensorStorage` can store data contiguously. This should be the preferred option when dealing with complicated but unchanging data structures of medium size: .. GENERATED FROM PYTHON SOURCE LINES 167-170 .. code-block:: Python buffer_lazytensor = ReplayBuffer(storage=LazyTensorStorage(size)) .. GENERATED FROM PYTHON SOURCE LINES 171-174 Let us create a batch of data of size ``torch.Size([3])`` with 2 tensors stored in it: .. GENERATED FROM PYTHON SOURCE LINES 174-187 .. code-block:: Python import torch from tensordict import TensorDict data = TensorDict( { "a": torch.arange(12).view(3, 4), ("b", "c"): torch.arange(15).view(3, 5), }, batch_size=[3], ) print(data) .. rst-class:: sphx-glr-script-out ..
code-block:: none TensorDict( fields={ a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.int64, is_shared=False), b: TensorDict( fields={ c: Tensor(shape=torch.Size([3, 5]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 188-191 The first call to :meth:`~torchrl.data.ReplayBuffer.extend` will instantiate the storage. The first dimension of the data is unbound into separate datapoints: .. GENERATED FROM PYTHON SOURCE LINES 191-195 .. code-block:: Python buffer_lazytensor.extend(data) print(f"The buffer has {len(buffer_lazytensor)} elements") .. rst-class:: sphx-glr-script-out .. code-block:: none The buffer has 3 elements .. GENERATED FROM PYTHON SOURCE LINES 196-198 Let us sample from the buffer, and print the data: .. GENERATED FROM PYTHON SOURCE LINES 198-202 .. code-block:: Python sample = buffer_lazytensor.sample(5) print("samples", sample["a"], sample["b", "c"]) .. rst-class:: sphx-glr-script-out .. code-block:: none samples tensor([[ 8, 9, 10, 11], [ 8, 9, 10, 11], [ 8, 9, 10, 11], [ 8, 9, 10, 11], [ 0, 1, 2, 3]]) tensor([[10, 11, 12, 13, 14], [10, 11, 12, 13, 14], [10, 11, 12, 13, 14], [10, 11, 12, 13, 14], [ 0, 1, 2, 3, 4]]) .. GENERATED FROM PYTHON SOURCE LINES 203-205 A :class:`~torchrl.data.LazyMemmapStorage` is created in the same manner: .. GENERATED FROM PYTHON SOURCE LINES 205-212 .. code-block:: Python buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size)) buffer_lazymemmap.extend(data) print(f"The buffer has {len(buffer_lazymemmap)} elements") sample = buffer_lazytensor.sample(5) print("samples: a=", sample["a"], "\n('b', 'c'):", sample["b", "c"]) .. rst-class:: sphx-glr-script-out .. code-block:: none The buffer has 3 elements samples: a= tensor([[ 8, 9, 10, 11], [ 8, 9, 10, 11], [ 4, 5, 6, 7], [ 0, 1, 2, 3], [ 0, 1, 2, 3]]) ('b', 'c'): tensor([[10, 11, 12, 13, 14], [10, 11, 12, 13, 14], [ 5, 6, 7, 8, 9], [ 0, 1, 2, 3, 4], [ 0, 1, 2, 3, 4]]) .. GENERATED FROM PYTHON SOURCE LINES 213-215 We can also customize the storage location on disk: .. GENERATED FROM PYTHON SOURCE LINES 215-226 .. code-block:: Python tempdir = tempfile.TemporaryDirectory() buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size, scratch_dir=tempdir)) buffer_lazymemmap.extend(data) print(f"The buffer has {len(buffer_lazymemmap)} elements") print("the 'a' tensor is stored in", buffer_lazymemmap._storage._storage["a"].filename) print( "the ('b', 'c') tensor is stored in", buffer_lazymemmap._storage._storage["b", "c"].filename, ) .. rst-class:: sphx-glr-script-out .. code-block:: none The buffer has 3 elements the 'a' tensor is stored in /a.memmap the ('b', 'c') tensor is stored in /b/c.memmap .. GENERATED FROM PYTHON SOURCE LINES 227-246 Integration with TensorDict --------------------------- .. _tuto_rb_td: The tensor location follows the same structure as the TensorDict that contains them: this makes it easy to save and load buffers during training. To use :class:`~tensordict.TensorDict` as a data carrier at its fullest potential, the :class:`~torchrl.data.TensorDictReplayBuffer` class can be used. One of its key benefits is its ability to handle the organization of sampled data, along with any additional information that may be required (such as sample indices). It can be built in the same manner as a standard :class:`~torchrl.data.ReplayBuffer` and can generally be used interchangeably. .. 
GENERATED FROM PYTHON SOURCE LINES 246-259 .. code-block:: Python from torchrl.data import TensorDictReplayBuffer tempdir = tempfile.TemporaryDirectory() buffer_lazymemmap = TensorDictReplayBuffer( storage=LazyMemmapStorage(size, scratch_dir=tempdir), batch_size=12 ) buffer_lazymemmap.extend(data) print(f"The buffer has {len(buffer_lazymemmap)} elements") sample = buffer_lazymemmap.sample() print("sample:", sample) .. rst-class:: sphx-glr-script-out .. code-block:: none The buffer has 3 elements sample: TensorDict( fields={ a: Tensor(shape=torch.Size([12, 4]), device=cpu, dtype=torch.int64, is_shared=False), b: TensorDict( fields={ c: Tensor(shape=torch.Size([12, 5]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([12]), device=cpu, is_shared=False), index: Tensor(shape=torch.Size([12]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([12]), device=cpu, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 260-263 Our sample now has an extra ``"index"`` key that indicates what indices were sampled. Let us have a look at these indices: .. GENERATED FROM PYTHON SOURCE LINES 263-266 .. code-block:: Python print(sample["index"]) .. rst-class:: sphx-glr-script-out .. code-block:: none tensor([0, 2, 0, 0, 2, 0, 2, 2, 0, 1, 1, 2]) .. GENERATED FROM PYTHON SOURCE LINES 267-273 Integration with tensorclass ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The ReplayBuffer class and associated subclasses also work natively with :class:`~tensordict.tensorclass` classes, which can conveniently be used to encode datasets in a more explicit manner: .. GENERATED FROM PYTHON SOURCE LINES 273-302 .. code-block:: Python from tensordict import tensorclass @tensorclass class MyData: images: torch.Tensor labels: torch.Tensor data = MyData( images=torch.randint( 255, (10, 64, 64, 3), ), labels=torch.randint(100, (10,)), batch_size=[10], ) tempdir = tempfile.TemporaryDirectory() buffer_lazymemmap = ReplayBuffer( storage=LazyMemmapStorage(size, scratch_dir=tempdir), batch_size=12 ) buffer_lazymemmap.extend(data) print(f"The buffer has {len(buffer_lazymemmap)} elements") sample = buffer_lazymemmap.sample() print("sample:", sample) .. rst-class:: sphx-glr-script-out .. code-block:: none The buffer has 10 elements sample: MyData( images=Tensor(shape=torch.Size([12, 64, 64, 3]), device=cpu, dtype=torch.int64, is_shared=False), labels=Tensor(shape=torch.Size([12]), device=cpu, dtype=torch.int64, is_shared=False), batch_size=torch.Size([12]), device=cpu, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 303-320 As expected. the data has the proper class and shape! Integration with other tensor structures (PyTrees) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TorchRL's replay buffers also work with any pytree data structure. A PyTree is a nested structure of arbitrary depth made of dicts, lists and/or tuples where the leaves are tensors. This means that one can store in contiguous memory any such tree structure! Various storages can be used: :class:`~torchrl.data.replay_buffers.TensorStorage`, :class:`~torchrl.data.replay_buffers.LazyMemmapStorage` or :class:`~torchrl.data.replay_buffers.LazyTensorStorage` all accept this kind of data. Here is a brief demonstration of what this feature looks like: .. GENERATED FROM PYTHON SOURCE LINES 320-324 .. code-block:: Python from torch.utils._pytree import tree_map .. GENERATED FROM PYTHON SOURCE LINES 325-326 Let's build our replay buffer on disk: .. GENERATED FROM PYTHON SOURCE LINES 326-338 .. 
code-block:: Python rb = ReplayBuffer(storage=LazyMemmapStorage(size)) data = { "a": torch.randn(3), "b": {"c": (torch.zeros(2), [torch.ones(1)])}, 30: -torch.ones(()), # non-string keys also work } rb.add(data) # The sample has a similar structure to the data (with a leading dimension of 10 for each tensor) sample = rb.sample(10) .. GENERATED FROM PYTHON SOURCE LINES 339-340 With pytrees, any callable can be used as a transform: .. GENERATED FROM PYTHON SOURCE LINES 340-351 .. code-block:: Python def transform(x): # Zeros all the data in the pytree return tree_map(lambda y: y * 0, x) rb.append_transform(transform) sample = rb.sample(batch_size=12) .. GENERATED FROM PYTHON SOURCE LINES 352-353 let's check that our transform did its job: .. GENERATED FROM PYTHON SOURCE LINES 353-360 .. code-block:: Python def assert0(x): assert (x == 0).all() tree_map(assert0, sample) .. rst-class:: sphx-glr-script-out .. code-block:: none {'a': None, 'b': {'c': (None, [None])}, 30: None} .. GENERATED FROM PYTHON SOURCE LINES 361-387 Sampling and iterating over buffers ----------------------------------- .. _tuto_rb_sampling: Replay Buffers support multiple sampling strategies: - If the batch-size is fixed and can be defined at construction time, it can be passed as keyword argument to the buffer; - With a fixed batch-size, the replay buffer can be iterated over to gather samples; - If the batch-size is dynamic, it can be passed to the :class:`~torchrl.data.ReplayBuffer.sample` method on-the-fly. Sampling can be done using multithreading, but this is incompatible with the last option (at it requires the buffer to know in advance the size of the next batch). Let us see a few examples: Fixed batch-size ~~~~~~~~~~~~~~~~ If the batch-size is passed during construction, it should be omitted when sampling: .. GENERATED FROM PYTHON SOURCE LINES 387-402 .. code-block:: Python data = MyData( images=torch.randint( 255, (200, 64, 64, 3), ), labels=torch.randint(100, (200,)), batch_size=[200], ) buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size), batch_size=128) buffer_lazymemmap.extend(data) buffer_lazymemmap.sample() .. rst-class:: sphx-glr-script-out .. code-block:: none MyData( images=Tensor(shape=torch.Size([128, 64, 64, 3]), device=cpu, dtype=torch.int64, is_shared=False), labels=Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False), batch_size=torch.Size([128]), device=cpu, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 403-409 This batch of data has the size that we wanted it to have (128). To enable multithreaded sampling, just pass a positive integer to the ``prefetch`` keyword argument during construction. This should speed up sampling considerably whenever sampling is time consuming (e.g., when using prioritized samplers): .. GENERATED FROM PYTHON SOURCE LINES 409-418 .. code-block:: Python buffer_lazymemmap = ReplayBuffer( storage=LazyMemmapStorage(size), batch_size=128, prefetch=10 ) # creates a queue of 10 elements to be prefetched in the background buffer_lazymemmap.extend(data) print(buffer_lazymemmap.sample()) .. rst-class:: sphx-glr-script-out .. code-block:: none MyData( images=Tensor(shape=torch.Size([128, 64, 64, 3]), device=cpu, dtype=torch.int64, is_shared=False), labels=Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False), batch_size=torch.Size([128]), device=cpu, is_shared=False) .. 
GENERATED FROM PYTHON SOURCE LINES 419-424 Iterating over the buffer with a fixed batch-size ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ We can also iterate over the buffer like we would do with a regular dataloader, as long as the batch-size is predefined: .. GENERATED FROM PYTHON SOURCE LINES 424-432 .. code-block:: Python for i, data in enumerate(buffer_lazymemmap): if i == 3: print(data) break .. rst-class:: sphx-glr-script-out .. code-block:: none MyData( images=Tensor(shape=torch.Size([128, 64, 64, 3]), device=cpu, dtype=torch.int64, is_shared=False), labels=Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False), batch_size=torch.Size([128]), device=cpu, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 433-439 Due to the fact that our sampling technique is entirely random and does not prevent replacement, the iterator in question is infinite. However, we can make use of the :class:`~torchrl.data.replay_buffers.SamplerWithoutReplacement` instead, which will transform our buffer into a finite iterator: .. GENERATED FROM PYTHON SOURCE LINES 439-445 .. code-block:: Python from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement buffer_lazymemmap = ReplayBuffer( storage=LazyMemmapStorage(size), batch_size=32, sampler=SamplerWithoutReplacement() ) .. GENERATED FROM PYTHON SOURCE LINES 446-447 we create a data that is big enough to get a couple of samples .. GENERATED FROM PYTHON SOURCE LINES 447-461 .. code-block:: Python data = TensorDict( { "a": torch.arange(64).view(16, 4), ("b", "c"): torch.arange(128).view(16, 8), }, batch_size=[16], ) buffer_lazymemmap.extend(data) for _i, _ in enumerate(buffer_lazymemmap): continue print(f"A total of {_i+1} batches have been collected") .. rst-class:: sphx-glr-script-out .. code-block:: none A total of 1 batches have been collected .. GENERATED FROM PYTHON SOURCE LINES 462-467 Dynamic batch-size ~~~~~~~~~~~~~~~~~~ In contrast to what we have seen earlier, the ``batch_size`` keyword argument can be omitted and passed directly to the ``sample`` method: .. GENERATED FROM PYTHON SOURCE LINES 467-476 .. code-block:: Python buffer_lazymemmap = ReplayBuffer( storage=LazyMemmapStorage(size), sampler=SamplerWithoutReplacement() ) buffer_lazymemmap.extend(data) print("sampling 3 elements:", buffer_lazymemmap.sample(3)) print("sampling 5 elements:", buffer_lazymemmap.sample(5)) .. rst-class:: sphx-glr-script-out .. code-block:: none sampling 3 elements: TensorDict( fields={ a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.int64, is_shared=False), b: TensorDict( fields={ c: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([3]), device=cpu, is_shared=False)}, batch_size=torch.Size([3]), device=cpu, is_shared=False) sampling 5 elements: TensorDict( fields={ a: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False), b: TensorDict( fields={ c: Tensor(shape=torch.Size([5, 8]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([5]), device=cpu, is_shared=False)}, batch_size=torch.Size([5]), device=cpu, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 477-494 Prioritized Replay buffers -------------------------- .. _tuto_rb_prb: TorchRL also provides an interface for `prioritized replay buffers `_. This buffer class samples data according to a priority signal that is passed through the data. 
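For reference, and before diving into the API, recall how prioritized sampling typically works (this follows the original prioritized experience replay formulation; TorchRL's internal normalization may differ slightly in its details). An item with priority :math:`p_i` is sampled with probability

.. math::

    P(i) = \frac{p_i^{\alpha}}{\sum_k p_k^{\alpha}},

and the returned importance-sampling weight has the form :math:`w_i = \left(\frac{1}{N \cdot P(i)}\right)^{\beta}`, which is what the ``alpha`` and ``beta`` arguments used below control.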
Although this tool is compatible with non-tensordict data, we encourage using TensorDict instead as it makes it possible to carry meta-data in and out from the buffer with little effort. Let us first see how to build a prioritized replay buffer in the generic case. The :math:`\alpha` and :math:`\beta` hyperparameters have to be manually set: .. GENERATED FROM PYTHON SOURCE LINES 494-506 .. code-block:: Python from torchrl.data.replay_buffers.samplers import PrioritizedSampler size = 100 rb = ReplayBuffer( storage=ListStorage(size), sampler=PrioritizedSampler(max_capacity=size, alpha=0.8, beta=1.1), collate_fn=lambda x: x, ) .. GENERATED FROM PYTHON SOURCE LINES 507-509 Extending the replay buffer returns the items indices, which we will need later to update the priority: .. GENERATED FROM PYTHON SOURCE LINES 509-512 .. code-block:: Python indices = rb.extend([1, "foo", None]) .. GENERATED FROM PYTHON SOURCE LINES 513-522 The sampler expects to have a priority for each element. When added to the buffer, the priority is set to a default value of 1. Once the priority has been computed (usually through the loss), it must be updated in the buffer. This is done via the :meth:`~torchrl.data.ReplayBuffer.update_priority` method, which requires the indices as well as the priority. We assign an artificially high priority to the second sample in the dataset to observe its effect on sampling: .. GENERATED FROM PYTHON SOURCE LINES 522-524 .. code-block:: Python rb.update_priority(index=indices, priority=torch.tensor([0, 1_000, 0.1])) .. GENERATED FROM PYTHON SOURCE LINES 525-528 We observe that sampling from the buffer returns mostly the second sample (``"foo"``): .. GENERATED FROM PYTHON SOURCE LINES 528-532 .. code-block:: Python sample, info = rb.sample(10, return_info=True) print(sample) .. rst-class:: sphx-glr-script-out .. code-block:: none ['foo', 'foo', 'foo', 'foo', 'foo', 'foo', 'foo', 'foo', 'foo', 'foo'] .. GENERATED FROM PYTHON SOURCE LINES 533-534 The info contains the relative weights of the items as well as the indices. .. GENERATED FROM PYTHON SOURCE LINES 534-537 .. code-block:: Python print(info) .. rst-class:: sphx-glr-script-out .. code-block:: none {'_weight': tensor([2.0893e-10, 2.0893e-10, 2.0893e-10, 2.0893e-10, 2.0893e-10, 2.0893e-10, 2.0893e-10, 2.0893e-10, 2.0893e-10, 2.0893e-10]), 'index': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])} .. GENERATED FROM PYTHON SOURCE LINES 538-576 We see that using a prioritized replay buffer requires a series of extra steps in the training loop compared with a regular buffer: - After collecting data and extending the buffer, the priority of the items must be updated; - After computing the loss and getting a "priority signal" from it, we must update again the priority of the items in the buffer. This requires us to keep track of the indices. This drastically hampers the reusability of the buffer: if one is to write a training script where both a prioritized and a regular buffer can be created, she must add a considerable amount of control flow to make sure that the appropriate methods are called at the appropriate place, if and only if a prioritized buffer is being used. Let us see how we can improve this with :class:`~tensordict.TensorDict`. We saw that the :class:`~torchrl.data.TensorDictReplayBuffer` returns data augmented with their relative storage indices. One feature we did not mention is that this class also ensures that the priority signal is automatically parsed to the prioritized sampler if present during extension. 
The combination of these features simplifies things in several ways: - When extending the buffer, the priority signal will automatically be parsed if present and the priority will accurately be assigned; - The indices will be stored in the sampled tensordicts, making it easy to update the priority after the loss computation. - When computing the loss, the priority signal will be registered in the tensordict passed to the loss module, making it possible to update the weights without effort: >>> data = replay_buffer.sample() >>> loss_val = loss_module(data) >>> replay_buffer.update_tensordict_priority(data) The following code illustrates these concepts. We build a replay buffer with a prioritized sampler, and indicate in the constructor the entry where the priority signal should be fetched: .. GENERATED FROM PYTHON SOURCE LINES 576-585 .. code-block:: Python rb = TensorDictReplayBuffer( storage=ListStorage(size), sampler=PrioritizedSampler(size, alpha=0.8, beta=1.1), priority_key="td_error", batch_size=1024, ) .. GENERATED FROM PYTHON SOURCE LINES 586-588 Let us choose a priority signal that is proportional to the storage index: .. GENERATED FROM PYTHON SOURCE LINES 588-594 .. code-block:: Python data["td_error"] = torch.arange(data.numel()) rb.extend(data) sample = rb.sample() .. GENERATED FROM PYTHON SOURCE LINES 595-596 higher indices should occur more frequently: .. GENERATED FROM PYTHON SOURCE LINES 596-601 .. code-block:: Python from matplotlib import pyplot as plt plt.hist(sample["index"].numpy()) .. image-sg:: /tutorials/images/sphx_glr_rb_tutorial_001.png :alt: rb tutorial :srcset: /tutorials/images/sphx_glr_rb_tutorial_001.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-script-out .. code-block:: none (array([ 30., 30., 82., 61., 68., 160., 84., 161., 101., 247.]), array([ 1. , 2.4, 3.8, 5.2, 6.6, 8. , 9.4, 10.8, 12.2, 13.6, 15. ]), ) .. GENERATED FROM PYTHON SOURCE LINES 602-608 Once we have worked with our sample, we update the priority key using the :meth:`torchrl.data.TensorDictReplayBuffer.update_tensordict_priority` method. For the sake of showing how this works, let us revert the priority of the sampled items: .. GENERATED FROM PYTHON SOURCE LINES 608-612 .. code-block:: Python sample = rb.sample() sample["td_error"] = data.numel() - sample["index"] rb.update_tensordict_priority(sample) .. GENERATED FROM PYTHON SOURCE LINES 613-614 Now, higher indices should occur less frequently: .. GENERATED FROM PYTHON SOURCE LINES 614-620 .. code-block:: Python sample = rb.sample() from matplotlib import pyplot as plt plt.hist(sample["index"].numpy()) .. image-sg:: /tutorials/images/sphx_glr_rb_tutorial_002.png :alt: rb tutorial :srcset: /tutorials/images/sphx_glr_rb_tutorial_002.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-script-out .. code-block:: none (array([220., 117., 196., 93., 73., 114., 58., 93., 29., 31.]), array([ 1. , 2.4, 3.8, 5.2, 6.6, 8. , 9.4, 10.8, 12.2, 13.6, 15. ]), ) .. GENERATED FROM PYTHON SOURCE LINES 621-646 Using transforms ---------------- .. _tuto_rb_transform: The data stored in a replay buffer may not be ready to be presented to a loss module. In some cases, the data produced by a collector can be too heavy to be saved as-is. Examples of this include converting images from ``uint8`` to floating point tensors, or concatenating successive frames when using decision transformers. Data can be processed in and out of a buffer just by appending the appropriate transform to it. 
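As a quick illustration before the fuller examples below (a minimal sketch with made-up frame data, reusing the pytree machinery introduced earlier), a plain callable appended to the buffer can convert stored ``uint8`` frames to floating point at sampling time:

.. code-block:: Python

    import torch
    from torch.utils._pytree import tree_map

    from torchrl.data import LazyMemmapStorage, ReplayBuffer

    # Store a raw uint8 frame; convert it to float32 in [0, 1] only when sampling.
    raw_rb = ReplayBuffer(storage=LazyMemmapStorage(100))
    raw_rb.add({"pixels": torch.randint(256, (84, 84, 3), dtype=torch.uint8)})


    def to_float(pytree):
        # cast uint8 leaves only, leave everything else untouched
        return tree_map(
            lambda x: x.float() / 255 if x.dtype == torch.uint8 else x, pytree
        )


    raw_rb.append_transform(to_float)
    print(raw_rb.sample(2)["pixels"].dtype)  # torch.float32
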
Here are a few complete examples: Saving raw images ~~~~~~~~~~~~~~~~~ ``uint8``-typed tensors consume far less memory than the floating point tensors we usually feed to our models. For this reason, it can be useful to save the raw images. The following script shows how one can build a collector that returns only the raw images but uses the transformed ones for inference, and how these transformations can be recycled in the replay buffer: .. GENERATED FROM PYTHON SOURCE LINES 646-668 .. code-block:: Python from torchrl.collectors import SyncDataCollector from torchrl.envs.libs.gym import GymEnv from torchrl.envs.transforms import ( Compose, GrayScale, Resize, ToTensorImage, TransformedEnv, ) from torchrl.envs.utils import RandomPolicy env = TransformedEnv( GymEnv("CartPole-v1", from_pixels=True), Compose( ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]), Resize(in_keys=["pixels_trsf"], w=64, h=64), GrayScale(in_keys=["pixels_trsf"]), ), ) .. GENERATED FROM PYTHON SOURCE LINES 669-670 Let us have a look at a rollout: .. GENERATED FROM PYTHON SOURCE LINES 670-674 .. code-block:: Python print(env.rollout(3)) .. rst-class:: sphx-glr-script-out .. code-block:: none TensorDict( fields={ action: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.int64, is_shared=False), done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), pixels: Tensor(shape=torch.Size([3, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False), pixels_trsf: Tensor(shape=torch.Size([3, 1, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False), reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([3]), device=cpu, is_shared=False), pixels: Tensor(shape=torch.Size([3, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False), pixels_trsf: Tensor(shape=torch.Size([3, 1, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([3]), device=cpu, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 675-680 We have just created an environment that produces pixels. These images are processed to be fed to a policy. We would like to store the raw images, not their transformed versions. To do this, we will append a transform to the collector to select only the keys we want to keep: .. GENERATED FROM PYTHON SOURCE LINES 680-692 .. code-block:: Python from torchrl.envs.transforms import ExcludeTransform collector = SyncDataCollector( env, RandomPolicy(env.action_spec), frames_per_batch=10, total_frames=1000, postproc=ExcludeTransform("pixels_trsf", ("next", "pixels_trsf"), "collector"), ) .. GENERATED FROM PYTHON SOURCE LINES 693-695 Let us have a look at a batch of data, and check that the ``"pixels_trsf"`` keys have been discarded: .. GENERATED FROM PYTHON SOURCE LINES 695-702 .. code-block:: Python for data in collector: print(data) break .. rst-class:: sphx-glr-script-out ..
code-block:: none TensorDict( fields={ action: Tensor(shape=torch.Size([10, 2]), device=cpu, dtype=torch.int64, is_shared=False), done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), pixels: Tensor(shape=torch.Size([10, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False), reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([10]), device=None, is_shared=False), pixels: Tensor(shape=torch.Size([10, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False), terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([10]), device=None, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 703-712 We create a replay buffer with the same transform as the environment. There is, however, a detail that needs to be addressed: transforms used without environments are oblivious to the data structure. When appending a transform to an environment, the data in the ``"next"`` nested tensordict is transformed first and then copied at the root during the rollout execution. When working with static data, this is not the case. Nevertheless, our data comes with a nested "next" tensordict that will be ignored by our transform if we don't explicitly instruct it to take care of it. We manually add these keys to the transform: .. GENERATED FROM PYTHON SOURCE LINES 712-726 .. code-block:: Python t = Compose( ToTensorImage( in_keys=["pixels", ("next", "pixels")], out_keys=["pixels_trsf", ("next", "pixels_trsf")], ), Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64), GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]), ) rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000), transform=t, batch_size=16) rb.extend(data) .. rst-class:: sphx-glr-script-out .. code-block:: none tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) .. GENERATED FROM PYTHON SOURCE LINES 727-729 We can check that a ``sample`` method sees the transformed images reappear: .. GENERATED FROM PYTHON SOURCE LINES 729-732 .. code-block:: Python print(rb.sample()) .. rst-class:: sphx-glr-script-out .. 
code-block:: none TensorDict( fields={ action: Tensor(shape=torch.Size([16, 2]), device=cpu, dtype=torch.int64, is_shared=False), done: Tensor(shape=torch.Size([16, 1]), device=cpu, dtype=torch.bool, is_shared=False), index: Tensor(shape=torch.Size([16]), device=cpu, dtype=torch.int64, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([16, 1]), device=cpu, dtype=torch.bool, is_shared=False), pixels: Tensor(shape=torch.Size([16, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False), pixels_trsf: Tensor(shape=torch.Size([16, 1, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False), reward: Tensor(shape=torch.Size([16, 1]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([16, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([16, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([16]), device=cpu, is_shared=False), pixels: Tensor(shape=torch.Size([16, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False), pixels_trsf: Tensor(shape=torch.Size([16, 1, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([16, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([16, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([16]), device=cpu, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 733-744 A more complex example: using CatFrames ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The :class:`~torchrl.envs.transforms.CatFrames` transform unfolds the observations through time, creating an n-back memory of past events that allows the model to take these past events into account (in the case of POMDPs or with recurrent policies such as Decision Transformers). Storing these concatenated frames can consume a considerable amount of memory. It can also be problematic when the n-back window needs to be different (usually longer) during training and inference. We solve this problem by executing the ``CatFrames`` transform separately in the two phases. .. GENERATED FROM PYTHON SOURCE LINES 744-747 .. code-block:: Python from torchrl.envs import CatFrames, UnsqueezeTransform .. GENERATED FROM PYTHON SOURCE LINES 748-750 We create a standard list of transforms for environments that return pixel-based observations: .. GENERATED FROM PYTHON SOURCE LINES 750-770 .. code-block:: Python env = TransformedEnv( GymEnv("CartPole-v1", from_pixels=True), Compose( ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]), Resize(in_keys=["pixels_trsf"], w=64, h=64), GrayScale(in_keys=["pixels_trsf"]), UnsqueezeTransform(-4, in_keys=["pixels_trsf"]), CatFrames(dim=-4, N=4, in_keys=["pixels_trsf"]), ), ) collector = SyncDataCollector( env, RandomPolicy(env.action_spec), frames_per_batch=10, total_frames=1000, ) for data in collector: print(data) break .. rst-class:: sphx-glr-script-out ..
code-block:: none TensorDict( fields={ action: Tensor(shape=torch.Size([10, 2]), device=cpu, dtype=torch.int64, is_shared=False), collector: TensorDict( fields={ traj_ids: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([10]), device=None, is_shared=False), done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), pixels: Tensor(shape=torch.Size([10, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False), pixels_trsf: Tensor(shape=torch.Size([10, 4, 1, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False), reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([10]), device=None, is_shared=False), pixels: Tensor(shape=torch.Size([10, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False), pixels_trsf: Tensor(shape=torch.Size([10, 4, 1, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([10]), device=None, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 771-774 The buffer transform looks pretty much like the environment one, but with extra ``("next", ...)`` keys like before: .. GENERATED FROM PYTHON SOURCE LINES 774-789 .. code-block:: Python t = Compose( ToTensorImage( in_keys=["pixels", ("next", "pixels")], out_keys=["pixels_trsf", ("next", "pixels_trsf")], ), Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64), GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]), UnsqueezeTransform(-4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]), CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]), ) rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(size), transform=t, batch_size=16) data_exclude = data.exclude("pixels_trsf", ("next", "pixels_trsf")) rb.add(data_exclude) .. rst-class:: sphx-glr-script-out .. code-block:: none 0 .. GENERATED FROM PYTHON SOURCE LINES 790-794 Let us sample one batch from the buffer. The shape of the transformed pixel keys should have a length of 4 along the 4th dimension starting from the end: .. GENERATED FROM PYTHON SOURCE LINES 794-798 .. code-block:: Python s = rb.sample(1) # the buffer has only one element print(s) .. rst-class:: sphx-glr-script-out .. 
code-block:: none TensorDict( fields={ action: Tensor(shape=torch.Size([1, 10, 2]), device=cpu, dtype=torch.int64, is_shared=False), collector: TensorDict( fields={ traj_ids: Tensor(shape=torch.Size([1, 10]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([1, 10]), device=cpu, is_shared=False), done: Tensor(shape=torch.Size([1, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False), index: Tensor(shape=torch.Size([1, 10]), device=cpu, dtype=torch.int64, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([1, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False), pixels: Tensor(shape=torch.Size([1, 10, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False), pixels_trsf: Tensor(shape=torch.Size([1, 10, 4, 1, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False), reward: Tensor(shape=torch.Size([1, 10, 1]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([1, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([1, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([1, 10]), device=cpu, is_shared=False), pixels: Tensor(shape=torch.Size([1, 10, 400, 600, 3]), device=cpu, dtype=torch.uint8, is_shared=False), pixels_trsf: Tensor(shape=torch.Size([1, 10, 4, 1, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([1, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([1, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([1, 10]), device=cpu, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 799-801 After a bit of processing (excluding unused keys, etc.), we see that the data generated online and offline match! .. GENERATED FROM PYTHON SOURCE LINES 801-804 .. code-block:: Python assert (data.exclude("collector") == s.squeeze(0).exclude("index", "collector")).all() .. GENERATED FROM PYTHON SOURCE LINES 805-825 Storing trajectories -------------------- .. _tuto_rb_traj: In many cases, it is desirable to access trajectories from the buffer rather than simple transitions. TorchRL offers multiple ways of achieving this. The preferred way is currently to store trajectories along the first dimension of the buffer and use a :class:`~torchrl.data.SliceSampler` to sample these batches of data. This class only needs a couple of pieces of information about your data structure to do its job (note that as of now it is only compatible with tensordict-structured data): the number of slices or their length and some information about where the separation between the episodes can be found (e.g. :ref:`recall that ` with a :ref:`DataCollector `, the trajectory id is stored in ``("collector", "traj_ids")``). In this simple example, we construct data with 4 consecutive short trajectories and sample 4 slices out of it, each of length 2 (since the batch size is 8, and 8 items // 4 slices = 2 time steps). We mark the steps as well. .. GENERATED FROM PYTHON SOURCE LINES 825-854 ..
code-block:: Python from torchrl.data import SliceSampler rb = TensorDictReplayBuffer( storage=LazyMemmapStorage(size), sampler=SliceSampler(traj_key="episode", num_slices=4), batch_size=8, ) episode = torch.zeros(10, dtype=torch.int) episode[:3] = 1 episode[3:5] = 2 episode[5:7] = 3 episode[7:] = 4 steps = torch.cat([torch.arange(3), torch.arange(2), torch.arange(2), torch.arange(3)]) data = TensorDict( { "episode": episode, "obs": torch.randn((3, 4, 5)).expand(10, 3, 4, 5), "act": torch.randn((20,)).expand(10, 20), "other": torch.randn((20, 50)).expand(10, 20, 50), "steps": steps, }, [10], ) rb.extend(data) sample = rb.sample() print("episode are grouped", sample["episode"]) print("steps are successive", sample["steps"]) .. rst-class:: sphx-glr-script-out .. code-block:: none episode are grouped tensor([4, 4, 4, 4, 1, 1, 4, 4], dtype=torch.int32) steps are successive tensor([1, 2, 1, 2, 1, 2, 0, 1]) .. GENERATED FROM PYTHON SOURCE LINES 855-877 Conclusion ---------- We have seen how a replay buffer can be used in TorchRL, from its simplest usage to more advanced ones where the data need to be transformed or stored in particular ways. You should now be able to: - Create a Replay Buffer, customize its storage, sampler and transforms; - Choose the best storage type for your problem (list, memory or disk-based); - Minimize the memory footprint of your buffer. Next steps ---------- - Check the data API reference to learn about offline datasets in TorchRL, which are based on our Replay Buffer API; - Check other samplers such as :class:`~torchrl.data.SamplerWithoutReplacement`, :class:`~torchrl.data.PrioritizedSliceSampler` and :class:`~torchrl.data.SliceSamplerWithoutReplacement`, or other writers such as :class:`~torchrl.data.TensorDictMaxValueWriter`. .. rst-class:: sphx-glr-timing **Total running time of the script:** (2 minutes 49.025 seconds) **Estimated memory usage:** 211 MB .. _sphx_glr_download_tutorials_rb_tutorial.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: rb_tutorial.ipynb ` .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: rb_tutorial.py ` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_