.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/collector_trajectory_assembly.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_collector_trajectory_assembly.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_collector_trajectory_assembly.py:

Collectors Deep Dive: Trajectory Assembly
=========================================

**Author**: Vincent Moens

.. _collector_trajectory_assembly:

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn

        * Why collectors return fixed-size batches that mix multiple trajectories
        * How ``split_trajectories()`` reassembles them into padded, per-episode tensors
        * What ``("collector", "traj_ids")`` and ``("collector", "mask")`` mean
        * How ``done`` and ``truncated`` interact with trajectory splitting
        * When to use ``as_nested=True`` for memory-efficient ragged batches
        * How to request complete trajectories with ``trajs_per_batch``
        * How to store complete trajectories in a replay buffer

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites

        * TorchRL and gymnasium installed
        * Familiarity with :class:`~torchrl.collectors.SyncDataCollector`
          (see :ref:`the data-collection tutorial `)

.. GENERATED FROM PYTHON SOURCE LINES 28-39

.. code-block:: Python

    import torch

    from torchrl.collectors import SyncDataCollector
    from torchrl.collectors.utils import split_trajectories
    from torchrl.data import LazyTensorStorage, ReplayBuffer
    from torchrl.envs import GymEnv
    from torchrl.modules import RandomPolicy

    torch.manual_seed(0)

.. GENERATED FROM PYTHON SOURCE LINES 45-56

Why collectors return fixed-size chunks
---------------------------------------

In reinforcement learning, episodes can have wildly different lengths.
A CartPole episode may last 10 steps or 500, depending on the policy. To keep
training loops predictable, TorchRL collectors always return batches of exactly
``frames_per_batch`` transitions, regardless of how many episodes those
transitions span.

This means a single batch will typically contain **fragments of multiple
trajectories** stitched together. Let's see this in practice.

.. GENERATED FROM PYTHON SOURCE LINES 56-67

.. code-block:: Python

    env = GymEnv("CartPole-v1")
    env.set_seed(0)
    policy = RandomPolicy(env.action_spec)

    collector = SyncDataCollector(env, policy, frames_per_batch=200, total_frames=-1)

    for data in collector:
        print(data)
        break

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    TensorDict(
        fields={
            action: Tensor(shape=torch.Size([200, 2]), device=cpu, dtype=torch.int64, is_shared=False),
            collector: TensorDict(
                fields={
                    traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
                batch_size=torch.Size([200]),
                device=None,
                is_shared=False),
            done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
            next: TensorDict(
                fields={
                    done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    observation: Tensor(shape=torch.Size([200, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                    reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                    terminated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                batch_size=torch.Size([200]),
                device=None,
                is_shared=False),
            observation: Tensor(shape=torch.Size([200, 4]), device=cpu, dtype=torch.float32, is_shared=False),
            terminated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
            truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
        batch_size=torch.Size([200]),
        device=None,
        is_shared=False)

.. GENERATED FROM PYTHON SOURCE LINES 68-70

The batch has exactly 200 transitions. Let's inspect its trajectory IDs: each
integer labels which episode a given transition belongs to.

.. GENERATED FROM PYTHON SOURCE LINES 70-73

.. code-block:: Python

    print(data["collector", "traj_ids"])

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2,
            2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
            3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
            4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
            4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5,
            5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
            6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
            7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9,
            9, 9, 9, 9, 9, 9, 9, 9])

.. GENERATED FROM PYTHON SOURCE LINES 74-77

Multiple trajectory IDs appear because several short episodes were packed into
a single 200-frame batch. The ``("next", "done")`` key marks where each
episode ends:

.. GENERATED FROM PYTHON SOURCE LINES 77-80

.. code-block:: Python

    print(data["next", "done"].squeeze(-1))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor([False, False, False, False, False, False, False, False, False, False,
             True, False, False, False, False, False, False, False, False, False,
            False, False,  True, False, False, False, False, False, False, False,
            False, False, False, False, False, False, False, False, False,  True,
            False, False, False, False, False, False, False, False, False, False,
            False,  True, False, False, False, False, False, False, False, False,
            False, False, False, False, False, False, False, False, False, False,
            False, False, False, False, False, False, False, False, False, False,
            False, False, False, False, False, False, False, False, False, False,
            False, False, False, False, False, False, False, False, False, False,
            False, False, False, False, False, False, False, False, False, False,
            False, False, False, False,  True, False, False, False, False, False,
            False, False, False, False, False, False, False, False, False, False,
             True, False, False, False, False, False, False, False, False, False,
            False, False, False, False, False, False,  True, False, False, False,
            False, False, False, False, False, False, False, False, False, False,
            False, False, False, False, False, False, False, False, False, False,
            False, False, False, False, False, False,  True, False, False, False,
            False, False, False, False, False, False, False, False, False,  True,
            False, False, False, False, False, False, False, False, False, False])

.. GENERATED FROM PYTHON SOURCE LINES 81-82

We can count how many complete episodes fell within this batch:

.. GENERATED FROM PYTHON SOURCE LINES 82-86

.. code-block:: Python

    n_episodes = data["next", "done"].sum().item()
    print(f"This batch of {data.shape[0]} frames contains {n_episodes} episodes.")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    This batch of 200 frames contains 9 episodes.
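Even before any reassembly, per-episode statistics can be computed directly
from the flat batch by aggregating over the trajectory IDs. Here is a minimal
sketch with synthetic stand-ins for ``("collector", "traj_ids")`` and
``("next", "reward")`` (in the real batch these would come from ``data``):

```python
import torch

# Synthetic stand-ins: 9 transitions spanning 3 episodes (IDs 0, 1, 2)
# with a reward of 1.0 per step, as in CartPole.
traj_ids = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2])
rewards = torch.ones(9)

# Accumulate each step's reward into the slot of its episode:
# index_add_ sums rewards[i] into returns[traj_ids[i]].
n_trajs = int(traj_ids.max()) + 1
returns = torch.zeros(n_trajs).index_add_(0, traj_ids, rewards)
print(returns)  # tensor([3., 2., 4.])
```

With the batch collected above, the same pattern would aggregate
``data["next", "reward"].squeeze(-1)`` over ``data["collector", "traj_ids"]``.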
.. GENERATED FROM PYTHON SOURCE LINES 87-95

Reassembling trajectories with ``split_trajectories``
-----------------------------------------------------

For many algorithms (especially those involving recurrent networks or
episode-level returns), you need data organized **per episode**, not as a flat
interleaved stream. :func:`~torchrl.collectors.utils.split_trajectories` takes
a flat batch with ``("collector", "traj_ids")`` and returns a zero-padded
``TensorDict`` of shape ``(num_trajectories, max_length)``.

.. GENERATED FROM PYTHON SOURCE LINES 95-101

.. code-block:: Python

    split_data = split_trajectories(data)
    print(split_data)
    print(f"Shape: {split_data.shape} → (num_trajectories, max_episode_length)")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    TensorDict(
        fields={
            action: Tensor(shape=torch.Size([10, 63, 2]), device=cpu, dtype=torch.int64, is_shared=False),
            collector: TensorDict(
                fields={
                    mask: Tensor(shape=torch.Size([10, 63]), device=cpu, dtype=torch.bool, is_shared=False),
                    traj_ids: Tensor(shape=torch.Size([10, 63]), device=cpu, dtype=torch.int64, is_shared=False)},
                batch_size=torch.Size([10, 63]),
                device=None,
                is_shared=False),
            done: Tensor(shape=torch.Size([10, 63, 1]), device=cpu, dtype=torch.bool, is_shared=False),
            next: TensorDict(
                fields={
                    done: Tensor(shape=torch.Size([10, 63, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    observation: Tensor(shape=torch.Size([10, 63, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                    reward: Tensor(shape=torch.Size([10, 63, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                    terminated: Tensor(shape=torch.Size([10, 63, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    truncated: Tensor(shape=torch.Size([10, 63, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                batch_size=torch.Size([10, 63]),
                device=None,
                is_shared=False),
            observation: Tensor(shape=torch.Size([10, 63, 4]), device=cpu, dtype=torch.float32, is_shared=False),
            terminated: Tensor(shape=torch.Size([10, 63, 1]), device=cpu, dtype=torch.bool, is_shared=False),
            truncated: Tensor(shape=torch.Size([10, 63, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
        batch_size=torch.Size([10, 63]),
        device=None,
        is_shared=False)
    Shape: torch.Size([10, 63]) → (num_trajectories, max_episode_length)

.. GENERATED FROM PYTHON SOURCE LINES 102-105

Because episodes have different lengths, shorter ones are padded with zeros.
The ``("collector", "mask")`` key tells you which time-steps contain real data
(``True``) and which are padding (``False``):

.. GENERATED FROM PYTHON SOURCE LINES 105-108

.. code-block:: Python

    print(split_data["collector", "mask"])

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
               True,  True, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False],
            [ True,  True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False],
            [ True,  True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False],
            [ True,  True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False],
            [ True,  True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,  True],
            [ True,  True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False],
            [ True,  True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False],
            [ True,  True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False],
            [ True,  True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False],
            [ True,  True,  True,  True,  True,  True,  True,  True,  True,
               True, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False]])

.. GENERATED FROM PYTHON SOURCE LINES 109-112

When computing losses on this padded tensor, **always multiply by the mask**
(or index with it) so that padding does not leak into your gradients. This is
especially important for recurrent models.

.. GENERATED FROM PYTHON SOURCE LINES 114-131

``done`` vs ``truncated`` and the mask
--------------------------------------

TorchRL distinguishes two flavours of episode termination:

* ``("next", "done")`` is ``True`` whenever an episode ends, for any reason.
* ``("next", "truncated")`` is ``True`` only when the episode was cut short by
  an external limit (a time limit, or the collector running out of frames
  before the environment signalled a natural end).

When a trajectory is still in-flight at the edge of a batch, its last step
will be ``truncated=True, done=True``.
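The masking advice can be made concrete with a small sketch. Assuming a padded
batch shaped like ``split_data``, here is a mask-aware mean-squared error on
synthetic stand-in tensors (a toy ``(2, 3)`` batch instead of the real
``(10, 63)`` one):

```python
import torch

# Toy padded batch: 2 trajectories, max length 3; trailing zeros are padding.
values = torch.tensor([[1.0, 2.0, 3.0],
                       [4.0, 0.0, 0.0]])
targets = torch.tensor([[1.0, 1.0, 1.0],
                        [1.0, 0.0, 0.0]])
mask = torch.tensor([[True, True, True],
                     [True, False, False]])  # plays the role of ("collector", "mask")

# Zero out padded steps, then average over valid steps only.
per_step = (values - targets).pow(2)
loss = (per_step * mask).sum() / mask.sum()
print(loss)  # tensor(3.5000): (0 + 1 + 4 + 9) over 4 valid steps
```

Indexing gives the same result without the multiply: ``per_step[mask].mean()``.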
``split_trajectories`` handles this correctly: the mask covers exactly the
valid steps, and the ``done`` / ``truncated`` flags are preserved so that you
can treat natural terminations and artificial truncations differently in your
value-function bootstrap.

.. GENERATED FROM PYTHON SOURCE LINES 131-135

.. code-block:: Python

    print("done shape: ", split_data["next", "done"].shape)
    print("truncated shape:", split_data["next", "truncated"].shape)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    done shape:  torch.Size([10, 63, 1])
    truncated shape: torch.Size([10, 63, 1])

.. GENERATED FROM PYTHON SOURCE LINES 136-143

Padded vs nested tensors
------------------------

By default ``split_trajectories`` zero-pads to the length of the longest
trajectory. If your episodes vary a lot in length, this wastes memory. Passing
``as_nested=True`` returns a :class:`~tensordict.TensorDict` backed by nested
tensors instead:

.. GENERATED FROM PYTHON SOURCE LINES 143-150

.. code-block:: Python

    padded = split_trajectories(data, as_nested=False)
    nested = split_trajectories(data, as_nested=True)

    print(f"Padded shape : {padded.shape}")
    print(f"Nested result: {type(nested).__name__}, batch_size={nested.batch_size}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Padded shape : torch.Size([10, 63])
    Nested result: TensorDict, batch_size=torch.Size([10, -1])

.. GENERATED FROM PYTHON SOURCE LINES 151-154

**Recommendation:** use the default (padded) output for simplicity and broad
compatibility. Switch to ``as_nested=True`` when episode lengths are highly
variable and memory is a concern.

.. GENERATED FROM PYTHON SOURCE LINES 156-164

Getting complete trajectories with ``trajs_per_batch``
------------------------------------------------------

Sometimes you want the collector itself to hand you **complete episodes**
rather than fixed-frame chunks.
The ``trajs_per_batch`` argument tells the collector to buffer partial
trajectories internally and yield only once it has accumulated the requested
number of finished episodes.

.. GENERATED FROM PYTHON SOURCE LINES 164-178

.. code-block:: Python

    collector_trajs = SyncDataCollector(
        env,
        policy,
        frames_per_batch=200,
        total_frames=-1,
        trajs_per_batch=5,
    )

    for traj_data in collector_trajs:
        print(traj_data)
        break

    print(f"Shape: {traj_data.shape} → (trajs_per_batch, max_episode_length)")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    TensorDict(
        fields={
            action: Tensor(shape=torch.Size([5, 55, 2]), device=cpu, dtype=torch.int64, is_shared=False),
            collector: TensorDict(
                fields={
                    mask: Tensor(shape=torch.Size([5, 55]), device=cpu, dtype=torch.bool, is_shared=False),
                    traj_ids: Tensor(shape=torch.Size([5, 55]), device=cpu, dtype=torch.int64, is_shared=False)},
                batch_size=torch.Size([5, 55]),
                device=None,
                is_shared=False),
            done: Tensor(shape=torch.Size([5, 55, 1]), device=cpu, dtype=torch.bool, is_shared=False),
            next: TensorDict(
                fields={
                    done: Tensor(shape=torch.Size([5, 55, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    observation: Tensor(shape=torch.Size([5, 55, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                    reward: Tensor(shape=torch.Size([5, 55, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                    terminated: Tensor(shape=torch.Size([5, 55, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    truncated: Tensor(shape=torch.Size([5, 55, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                batch_size=torch.Size([5, 55]),
                device=None,
                is_shared=False),
            observation: Tensor(shape=torch.Size([5, 55, 4]), device=cpu, dtype=torch.float32, is_shared=False),
            terminated: Tensor(shape=torch.Size([5, 55, 1]), device=cpu, dtype=torch.bool, is_shared=False),
            truncated: Tensor(shape=torch.Size([5, 55, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
        batch_size=torch.Size([5, 55]),
        device=None,
        is_shared=False)
    Shape: torch.Size([5, 55]) → (trajs_per_batch, max_episode_length)

.. GENERATED FROM PYTHON SOURCE LINES 179-182

Every row is a **complete** episode. The mask confirms this: each trajectory
starts at step 0 and runs until the episode's natural (or truncated) end.

.. GENERATED FROM PYTHON SOURCE LINES 182-185

.. code-block:: Python

    print(traj_data["collector", "mask"])

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
              False, False, False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False, False, False],
            [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,  True, False, False,
              False, False, False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False, False, False],
            [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False, False, False],
            [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
            [ True,  True,  True,  True,  True,  True,  True,  True, False, False, False,
              False, False, False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False, False, False,
              False, False, False, False, False, False, False, False, False, False, False]])

.. GENERATED FROM PYTHON SOURCE LINES 186-204

Storing transitions and sampling trajectory slices
--------------------------------------------------

In off-policy training the standard pattern is to store **flat transitions**
in a :class:`~torchrl.data.ReplayBuffer` and let a
:class:`~torchrl.data.SliceSampler` carve out contiguous sub-sequences that
respect episode boundaries. The sampler uses ``("next", "done")`` to locate
where episodes end, so you never get a slice that straddles two unrelated
trajectories. This is the approach used in the :ref:`Recurrent DQN tutorial `.

.. seealso::

    The :ref:`replay buffer tutorial ` covers trajectory storage in more
    depth, including alternative samplers such as
    :class:`~torchrl.data.PrioritizedSliceSampler` and
    :class:`~torchrl.data.SliceSamplerWithoutReplacement`.

.. GENERATED FROM PYTHON SOURCE LINES 204-216

.. code-block:: Python

    from torchrl.data import SliceSampler

    rb = ReplayBuffer(
        storage=LazyTensorStorage(max_size=10_000),
        sampler=SliceSampler(
            slice_len=16,
            end_key=("next", "done"),
        ),
        batch_size=32,
    )

.. GENERATED FROM PYTHON SOURCE LINES 217-221

We extend the buffer with the **flat** collector batch (``data``, shape
``(200,)``), not with the pre-assembled trajectory tensor. The
``SliceSampler`` reads the ``("next", "done")`` flags in this flat storage to
figure out where episodes start and stop.

.. GENERATED FROM PYTHON SOURCE LINES 221-229

.. code-block:: Python

    rb.extend(data)
    print(f"Buffer length after one batch: {len(rb)}")

    sample = rb.sample()
    print(sample)

.. rst-class:: sphx-glr-script-out
.. code-block:: none

    Buffer length after one batch: 200
    TensorDict(
        fields={
            action: Tensor(shape=torch.Size([32, 2]), device=cpu, dtype=torch.int64, is_shared=False),
            collector: TensorDict(
                fields={
                    traj_ids: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.int64, is_shared=False)},
                batch_size=torch.Size([32]),
                device=cpu,
                is_shared=False),
            done: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
            next: TensorDict(
                fields={
                    done: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    observation: Tensor(shape=torch.Size([32, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                    reward: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                    terminated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    truncated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                batch_size=torch.Size([32]),
                device=cpu,
                is_shared=False),
            observation: Tensor(shape=torch.Size([32, 4]), device=cpu, dtype=torch.float32, is_shared=False),
            terminated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
            truncated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
        batch_size=torch.Size([32]),
        device=cpu,
        is_shared=False)

.. GENERATED FROM PYTHON SOURCE LINES 230-232

With ``batch_size=32`` and ``slice_len=16``, the sampler must draw exactly
``32 // 16 = 2`` contiguous trajectory slices per call:

.. GENERATED FROM PYTHON SOURCE LINES 232-236

.. code-block:: Python

    traj_ids = sample["collector", "traj_ids"]
    print(f"Unique trajectories in sample: {traj_ids.unique().numel()}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Unique trajectories in sample: 2

.. GENERATED FROM PYTHON SOURCE LINES 237-256

Each sampled batch contains contiguous slices of 16 steps drawn from the
stored transitions. A typical training loop looks like this:
.. code-block:: python

    collector = SyncDataCollector(env, policy, frames_per_batch=200, ...)
    rb = ReplayBuffer(
        storage=LazyTensorStorage(max_size=100_000),
        sampler=SliceSampler(slice_len=16, end_key=("next", "done")),
        batch_size=64,
    )

    for batch in collector:
        rb.extend(batch)
        for _ in range(n_optim):
            sample = rb.sample()
            loss = loss_fn(sample)
            optim.zero_grad()
            loss.backward()
            optim.step()

.. GENERATED FROM PYTHON SOURCE LINES 258-266

Asynchronous collection with ``collector.start()``
--------------------------------------------------

When a replay buffer is passed directly to the collector, you can decouple
collection from training entirely using
:meth:`~torchrl.collectors.Collector.start`. The collector runs in a
background thread and writes flat transitions into the buffer continuously
while your training loop samples from it.

.. GENERATED FROM PYTHON SOURCE LINES 266-290

.. code-block:: Python

    import time

    from torchrl.collectors import Collector

    rb_async = ReplayBuffer(
        storage=LazyTensorStorage(max_size=10_000),
        sampler=SliceSampler(
            slice_len=16,
            end_key=("next", "done"),
        ),
        shared=True,
    )

    collector_async = Collector(
        env,
        policy,
        replay_buffer=rb_async,
        frames_per_batch=200,
        total_frames=-1,
    )
    collector_async.start()

.. GENERATED FROM PYTHON SOURCE LINES 291-294

The collector is now filling ``rb_async`` in the background with flat
transitions. The ``SliceSampler`` will carve contiguous 16-step slices out of
this flat storage, respecting episode boundaries.

.. GENERATED FROM PYTHON SOURCE LINES 294-306

.. code-block:: Python

    for _ in range(10):
        time.sleep(0.1)
        if len(rb_async) > 0:
            break

    print(f"Buffer length after background collection: {len(rb_async)}")

    if len(rb_async) >= 16:
        sample = rb_async.sample(batch_size=32)
        print(sample)

.. rst-class:: sphx-glr-script-out
.. code-block:: none

    Buffer length after background collection: 200
    TensorDict(
        fields={
            action: Tensor(shape=torch.Size([32, 2]), device=cpu, dtype=torch.int64, is_shared=False),
            collector: TensorDict(
                fields={
                    traj_ids: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.int64, is_shared=False)},
                batch_size=torch.Size([32]),
                device=cpu,
                is_shared=False),
            done: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
            next: TensorDict(
                fields={
                    done: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    observation: Tensor(shape=torch.Size([32, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                    reward: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                    terminated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    truncated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                batch_size=torch.Size([32]),
                device=cpu,
                is_shared=False),
            observation: Tensor(shape=torch.Size([32, 4]), device=cpu, dtype=torch.float32, is_shared=False),
            terminated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
            truncated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
        batch_size=torch.Size([32]),
        device=cpu,
        is_shared=False)

.. GENERATED FROM PYTHON SOURCE LINES 307-308

When you are done, shut the collector down:

.. GENERATED FROM PYTHON SOURCE LINES 308-311

.. code-block:: Python

    collector_async.async_shutdown()

.. GENERATED FROM PYTHON SOURCE LINES 312-315

This pattern is especially useful when environment stepping is slow (e.g.
physics simulators or LLM inference): the training loop never idles waiting
for new data, and the buffer is always fresh.
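The producer/consumer decoupling described here can be illustrated with plain
Python threading. This is a toy stand-in for the collector/buffer handshake,
not the TorchRL API: a background "collector" appends transitions while the
main loop samples concurrently.

```python
import random
import threading
import time

# Toy shared buffer guarded by a lock, plus a stop signal for the producer.
buffer, lock, stop = [], threading.Lock(), threading.Event()

def producer():
    step = 0
    while not stop.is_set():
        with lock:
            buffer.append(step)  # stand-in for writing one transition
        step += 1
        time.sleep(0.001)

t = threading.Thread(target=producer)
t.start()

# "Training" loop: sample whenever data exists, while collection keeps running.
samples = []
while len(samples) < 5:
    with lock:
        if buffer:
            samples.append(random.choice(buffer))
    time.sleep(0.002)

stop.set()
t.join()
print(f"drew {len(samples)} samples while collection kept running")
```

The lock plays the role of the buffer's internal synchronization; in TorchRL
the replay buffer handles this for you.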
.. GENERATED FROM PYTHON SOURCE LINES 317-345

Conclusion
----------

In this tutorial we covered how TorchRL collectors handle trajectories:

* Collectors return **fixed-size batches** that interleave fragments of
  multiple episodes.
* :func:`~torchrl.collectors.utils.split_trajectories` reassembles them into a
  ``(num_trajectories, max_length)`` padded tensor with a mask.
* ``done`` marks any episode end; ``truncated`` flags artificial cut-offs. The
  mask covers valid time-steps only.
* ``as_nested=True`` gives memory-efficient ragged tensors.
* ``trajs_per_batch`` makes the collector yield complete episodes directly.
* Flat transitions plus a :class:`~torchrl.data.SliceSampler` let a
  :class:`~torchrl.data.ReplayBuffer` serve contiguous per-episode slices.
* Passing a replay buffer and calling
  :meth:`~torchrl.collectors.Collector.start` enables fully asynchronous
  collection in a background thread.

Useful next resources
~~~~~~~~~~~~~~~~~~~~~

* :ref:`Get started with data collection `: basic collector and replay-buffer
  workflow.
* :ref:`Recurrent DQN tutorial `: training a recurrent policy where
  per-episode data is essential.
* TorchRL documentation

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 0.282 seconds)

.. _sphx_glr_download_tutorials_collector_trajectory_assembly.py:

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: collector_trajectory_assembly.ipynb <collector_trajectory_assembly.ipynb>`

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: collector_trajectory_assembly.py <collector_trajectory_assembly.py>`

        .. container:: sphx-glr-download sphx-glr-download-zip

            :download:`Download zipped: collector_trajectory_assembly.zip <collector_trajectory_assembly.zip>`

.. only:: html

    .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_