Collector Internals
This page describes how Collector steps through an
environment. It is meant for
contributors and for users debugging unexpected rollout behaviour: device
casts, per-step bookkeeping, and trajectory tracking are implementation details
that are not visible from the public API.
The multi-process collectors (MultiSyncCollector and
MultiAsyncCollector) delegate their per-worker rollouts to
Collector, so the per-worker flow on this page applies to
them too.
Per-timestep flow
A single iteration of Collector.rollout() corresponds to one
environment step. A batch is made of frames_per_batch such iterations. Their
outputs are stacked into the batch yielded to the user or extended into a replay
buffer. When direct replay-buffer writes are enabled, they are instead written
one step at a time with replay_buffer.add(...).
for t in range(frames_per_batch):

    ┌─ carrier ─────────────────────────────────────────────┐
    │ TensorDict: observation + collector metadata;         │
    │ device-cleared when needed for cross-device stepping  │
    └───────────┬───────────────────────────────────────────┘
                │
                │ (1) cast to policy_device if needed
                │     → _sync_policy()
                ▼
          ┌──────────┐
          │  policy  │ ← reads obs, writes action + log_prob
          └─────┬────┘
                │
                │ carrier.update(policy_output)
                ▼
    ┌─ carrier (now has action) ────────────────────────────┐
    └───────────┬───────────────────────────────────────────┘
                │
                │ (2) cast to env_device if needed
                │     → _sync_env()
                ▼
    ┌──────────────────────────┐
    │ env.step_and_maybe_reset │ ← returns (env_output, env_next_output)
    │                          │   and auto-resets done envs
    └───────────┬──────────────┘
                │
                │ carrier.set("next", env_output["next"])
                ▼
    ┌─ carrier_for_out (snapshot for this step) ────────────┐
    └───────────┬───────────────────────────────────────────┘
                │
                │ (3a) direct writes: replay_buffer.add(carrier_for_out)
                │
                │ (3b) otherwise: cast to storing_device if needed
                │      → _sync_storage(), then append
                ▼
    direct replay-buffer write OR append to tensordicts
                │
                │ carrier = env_next_output (post-reset state)
                │ update traj_ids if any env finished
                └─→ next iteration
Implementation: Collector.rollout() in
torchrl/collectors/_single.py.
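The subtle part of the loop above is the last step. The stored frame keeps the pre-reset "next" entry, while the carrier is replaced by env_next_output, the post-reset state. The sketch below reproduces that handoff with plain dictionaries standing in for tensordicts. ToyEnv, policy and rollout are hypothetical stand-ins, not TorchRL classes. Device casts, sync points and trajectory IDs are left out.

```python
import copy


class ToyEnv:
    """Hypothetical single env: the observation grows by the action; episodes end at `horizon`."""

    def __init__(self, horizon):
        self.horizon = horizon

    def reset(self):
        return {"obs": 0}

    def step_and_maybe_reset(self, td):
        next_obs = td["obs"] + td["action"]
        done = next_obs >= self.horizon
        # env_output: the input plus the "next" entry for this transition.
        env_output = dict(td, next={"obs": next_obs, "done": done})
        # env_next_output: what the next policy call must see (a fresh reset if done).
        env_next_output = self.reset() if done else {"obs": next_obs}
        return env_output, env_next_output


def policy(td):
    # Hypothetical deterministic policy: always step by one.
    return {"action": 1}


def rollout(env, carrier, frames_per_batch):
    frames = []
    for _ in range(frames_per_batch):
        carrier.update(policy(carrier))                                   # (1) write the action
        env_output, env_next_output = env.step_and_maybe_reset(carrier)   # (2) step the env
        carrier["next"] = env_output["next"]
        frames.append(copy.deepcopy(carrier))                             # (3b) snapshot and append
        carrier = env_next_output                                         # post-reset state for t + 1
    return frames, carrier


env = ToyEnv(horizon=2)
frames, carrier = rollout(env, env.reset(), frames_per_batch=3)
```

At the episode boundary, the stored frame still reports the terminal "next" observation. The following frame starts from the reset observation, because that is what the carrier held.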
The carrier
The carrier is the TensorDictBase instance stored as
self._carrier. It persists across calls to next(iter(collector)) and
holds the post-reset result of the previous environment step, which is the
state that the next policy call must consume. It is initialized by
Collector._make_carrier() and then advanced at the end of every
timestep by assigning env_next_output back to self._carrier.
Why it exists
State persistence across batches. Collection may stop at a batch boundary while the environment trajectory continues. The carrier preserves the latest reset-aware environment output so the next rollout resumes from the correct observation and recurrent state.
Allocation amortization. Reusing the same tensordict-shaped state avoids allocating a fresh container for every policy/env exchange.
Device-neutral handoff. When the policy and environment cannot share a single device-owned tensordict, the carrier is cleared of its device with clear_device_(). The boolean flag self._carrier_has_no_device records whether this invariant must be preserved when new "next" data is merged.
Collector metadata. Trajectory IDs and other ("collector", ...) keys live on the carrier and persist across steps without round-tripping through the env.
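The first point, state persistence across batches, can be illustrated with a generator that keeps its carrier between yields. A trajectory cut by a batch boundary then resumes from the right observation. ToyCollector and its counting environment are hypothetical. The real collector keeps this state on self._carrier and also handles devices and trajectory IDs.

```python
class ToyCollector:
    """Hypothetical collector: the env state is an integer that increments each step."""

    def __init__(self, frames_per_batch, horizon):
        self.frames_per_batch = frames_per_batch
        self.horizon = horizon
        self._carrier = {"obs": 0}  # plays the role of the initial carrier

    def _step(self, carrier):
        next_obs = carrier["obs"] + 1
        done = next_obs >= self.horizon
        frame = {"obs": carrier["obs"], "next": {"obs": next_obs, "done": done}}
        # Post-reset state: start over at 0 when the episode is done.
        return frame, {"obs": 0 if done else next_obs}

    def __iter__(self):
        while True:
            batch = []
            for _ in range(self.frames_per_batch):
                frame, self._carrier = self._step(self._carrier)
                batch.append(frame)
            yield batch  # self._carrier survives until the next batch is requested


collector = iter(ToyCollector(frames_per_batch=3, horizon=5))
first = next(collector)
second = next(collector)
```

The second batch continues the trajectory at observation 3 instead of restarting from a reset.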
Reading the carrier
You should not normally touch self._carrier directly; it is an
implementation detail. If you need to instrument collected data, use
Collector.post_collect_hook or read the data yielded by
iteration. Mutating the carrier from a hook is undefined behaviour.
Sync points
Three explicit synchronisation functions are installed at construction time in
Collector._setup_devices(). They are called inside the rollout loop unless
explicit syncing is disabled with no_cuda_sync=True:
_sync_policy: called after copying the carrier to policy_device and before the policy reads it.
_sync_env: called after copying the carrier to env_device and before the environment reads it.
_sync_storage: called after copying carrier_for_out to storing_device on the append-to-list path. The direct replay_buffer.add(...) path does not perform this cast or sync.
What _sync_* actually is depends on the destination device; see
Collector._get_sync_fn():
| Destination | Sync function |
|---|---|
| | |
| Non-CUDA, CUDA available | |
| Non-CUDA, MPS available | |
| Non-CUDA, NPU available | |
| | |
| | |
Setting no_cuda_sync=True on the collector skips the explicit _sync_*
calls. Only do this if you know the transfers are already correctly ordered, or
if everything runs on CPU.
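The pattern above is sketched below. The sync callables are chosen once at construction, and each timestep only decides whether to call them. ToySyncedStep and its cast_to_* arguments are hypothetical. The sync_fn passed in takes a label purely so the example can count calls, and nothing here reproduces what Collector._get_sync_fn() returns for a given device.

```python
from collections import Counter


class ToySyncedStep:
    """Hypothetical stand-in showing where the three sync calls sit in one timestep."""

    def __init__(self, sync_fn, cast_to_policy, cast_to_env, cast_to_storage,
                 no_cuda_sync=False):
        # Chosen once at construction, like the _sync_* attributes of the collector.
        self._sync_policy = sync_fn
        self._sync_env = sync_fn
        self._sync_storage = sync_fn
        self.cast_to_policy = cast_to_policy
        self.cast_to_env = cast_to_env
        self.cast_to_storage = cast_to_storage
        self.no_cuda_sync = no_cuda_sync

    def run(self, direct_write=False):
        sync = not self.no_cuda_sync
        if self.cast_to_policy and sync:
            self._sync_policy("policy")    # after the copy to policy_device
        # ... the policy would run here ...
        if self.cast_to_env and sync:
            self._sync_env("env")          # after the copy to env_device
        # ... env.step_and_maybe_reset would run here ...
        if not direct_write and self.cast_to_storage and sync:
            self._sync_storage("storage")  # append-to-list path only


calls = Counter()
step = ToySyncedStep(lambda where: calls.update([where]), True, True, True)
step.run()                   # append path: three syncs
step.run(direct_write=True)  # direct replay_buffer.add(...) path: no storage sync
silent = ToySyncedStep(lambda where: calls.update(["skipped"]), True, True, True,
                       no_cuda_sync=True)
silent.run()                 # no_cuda_sync=True: no sync calls at all
```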
Device casting flags
Two cached booleans short-circuit the per-step device logic:
_cast_to_policy_device: set in Collector._setup_devices(). True iff policy_device != env_device. When True, the carrier is copied to policy_device before the policy is invoked.
_cast_to_env_device: set in Collector._apply_env_device(), after the environment device has been applied or inferred. It is True when _cast_to_policy_device is already True or when env.device != storing_device. When True, the carrier is copied to env_device before env.step_and_maybe_reset.
These are computed once so that the per-step branches degenerate into a single bool check when everything lives on the same device.
The companion flag _carrier_has_no_device (set in
Collector._make_carrier()) records whether the carrier was
stripped of its device. When True, any new "next" data merged into the
carrier after an env step is also device-stripped so the deviceless invariant
is preserved.
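The two flag rules can be written out directly. This sketch assumes devices are compared as plain strings and that env.device has already been applied or inferred. compute_cast_flags is a hypothetical helper, not a Collector method.

```python
def compute_cast_flags(policy_device, env_device, storing_device):
    """Hypothetical helper applying the documented rules for the two cached flags."""
    # Rule documented for Collector._setup_devices(): cast iff policy and env devices differ.
    cast_to_policy_device = policy_device != env_device
    # Rule documented for Collector._apply_env_device(): also cast if env and storage differ.
    cast_to_env_device = cast_to_policy_device or env_device != storing_device
    return cast_to_policy_device, cast_to_env_device


all_cpu = compute_cast_flags("cpu", "cpu", "cpu")
policy_on_gpu = compute_cast_flags("cuda:0", "cpu", "cpu")
storage_on_cpu = compute_cast_flags("cuda:0", "cuda:0", "cpu")
```

In the all-CPU case both flags are False, so every per-step device branch is a single check that never fires.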
Trajectory IDs
When track_traj_ids=True (the default), every frame carries a
("collector", "traj_ids") integer that uniquely identifies the trajectory
it belongs to. Two pieces of machinery cooperate:
Collector._traj_pool() returns a torchrl.collectors.utils._TrajectoryPool that hands out fresh IDs. By default the pool is process-local. In multi-process collectors, workers instead share a locked pool created by the parent collector, so IDs do not collide across worker resets.
Collector._update_traj_ids() runs after each env step. It reads the aggregated end-of-trajectory signal from ("next", "done") via _aggregate_end_of_traj() and draws as many fresh IDs from the pool as there are envs that finished. It then writes them into the per-env traj_ids tensor on the carrier with masked_scatter.
Setting track_traj_ids=False skips both the per-step bookkeeping and the
allocation of the traj_ids tensor. This is useful in throughput-sensitive
setups that do not need trajectory-aware sampling. Note that
split_trajs=True requires track_traj_ids=True; the constructor will
raise if you ask for the former without the latter.
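The update rule can be sketched with plain Python lists. Each env that reports an end of trajectory receives the next fresh ID from a shared counter, and the other envs keep theirs. On tensors, the same in-order write into masked positions is what torch.Tensor.masked_scatter does. ToyTrajectoryPool and update_traj_ids are hypothetical names. The done list stands in for the output of _aggregate_end_of_traj(), and the lock only mimics, within one process, the locked pool that multi-process collectors share.

```python
import threading


class ToyTrajectoryPool:
    """Hypothetical ID pool: hands out consecutive integers under a lock."""

    def __init__(self, start=0):
        self._next = start
        self._lock = threading.Lock()

    def take(self, n):
        with self._lock:
            ids = list(range(self._next, self._next + n))
            self._next += n
        return ids


def update_traj_ids(traj_ids, done, pool):
    """Give every finished env a fresh ID, in env order, like a masked scatter."""
    fresh = iter(pool.take(sum(done)))
    return [next(fresh) if d else t for t, d in zip(traj_ids, done)]


pool = ToyTrajectoryPool()
traj_ids = pool.take(4)  # one ID per env: [0, 1, 2, 3]
traj_ids = update_traj_ids(traj_ids, [False, True, False, True], pool)
traj_ids = update_traj_ids(traj_ids, [True, False, False, False], pool)
```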
Collection hooks
Two opt-in callbacks let you instrument collection without subclassing:
pre_collect_hook: called once at the top of rollout(), before the per-timestep loop starts and before any reset_at_each_iter reset. Receives no arguments. Use it to step a profiler, mark a section in NVTX, or update a worker-local counter.
post_collect_hook: called with the batch tensordict immediately before it is yielded to the consumer. Receives the TensorDictBase that will be yielded; its return value is ignored. Use it to log metrics derived from the batch.
Hooks are worker-local: in MultiSyncCollector /
MultiAsyncCollector they run inside each worker process, not in the
training process. Exceptions raised by a hook propagate up and stop collection;
they are not swallowed.
For batch transformations (rather than instrumentation), use postproc on
the collector constructor instead.
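The documented call order can be mirrored with a toy generator. pre_collect_hook runs before any reset or step. post_collect_hook sees the finished batch just before it is yielded, and its return value is discarded. An exception from a hook reaches the consumer. toy_iterator and its events list are hypothetical and only reproduce the ordering described above.

```python
def toy_iterator(frames_per_batch, pre_collect_hook=None, post_collect_hook=None,
                 reset_at_each_iter=False, events=None):
    """Hypothetical iterator mirroring the documented hook call order."""
    events = [] if events is None else events
    step = 0
    while True:
        if pre_collect_hook is not None:
            pre_collect_hook()              # no arguments, before any reset
        if reset_at_each_iter:
            events.append("reset")
        batch = []
        for _ in range(frames_per_batch):   # stand-in for the per-timestep loop
            batch.append(step)
            step += 1
        if post_collect_hook is not None:
            post_collect_hook(batch)        # return value is ignored
        yield batch


events = []
it = toy_iterator(
    2,
    pre_collect_hook=lambda: events.append("pre"),
    post_collect_hook=lambda batch: events.append(("post", list(batch))),
    reset_at_each_iter=True,
    events=events,
)
first = next(it)


def failing_hook():
    raise RuntimeError("hook failed")


try:
    next(toy_iterator(2, pre_collect_hook=failing_hook))
    hook_error_propagated = False
except RuntimeError:
    hook_error_propagated = True
```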
Where to look in the code
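| Concept | File / function |
|---|---|
| Per-step rollout loop | Collector.rollout() in torchrl/collectors/_single.py |
| Carrier initialization | Collector._make_carrier() |
| Device setup and policy cast flag | Collector._setup_devices() |
| Environment device application and env cast flag | Collector._apply_env_device() |
| Sync function dispatch | Collector._get_sync_fn() |
| Trajectory ID update | Collector._update_traj_ids() |
| Trajectory ID pool | Collector._traj_pool(), torchrl.collectors.utils._TrajectoryPool |
| Hooks | pre_collect_hook, Collector.post_collect_hook |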
See also
Collector Basics for the high-level API
Profiling collectors and envs for TORCHRL_PROFILING=1 instrumentation that emits named ranges on the carrier / policy / env transitions described above
Data layout: contiguous trajectories for the shape and key conventions the carrier follows