
Single Node Collectors

TorchRL provides several collector classes for single-node data collection, each with different execution strategies.

BaseCollector()
    Base class for data collectors.
Collector(create_env_fn[, policy, ...])
    Generic data collector for RL problems.
AsyncCollector(*args[, sync])
    Runs a single DataCollector on a separate process.
AsyncBatchedCollector(create_env_fn, *[, ...])
    Asynchronous collector that pairs per-env threads with an AsyncEnvPool and an InferenceServer.
MultiCollector(*args[, sync])
    Runs a given number of DataCollectors on separate processes.
MultiSyncCollector(*args[, sync])
    Runs a given number of DataCollectors on separate processes, synchronously.
MultiAsyncCollector(*args[, sync])
    Runs a given number of DataCollectors on separate processes, asynchronously.

Trajectory batching

Pass trajs_per_batch=N to any collector to receive batches of exactly N complete, zero-padded trajectories instead of fixed-frame batches. Trajectories that span multiple internal collection steps are automatically reassembled. Each yielded TensorDict has shape (N, max_traj_len) and includes a ("collector", "mask") boolean tensor marking valid time steps.

frames_per_batch still controls how frequently the environment is polled internally; it does not determine the output batch size when trajs_per_batch is set.

from torchrl.collectors import Collector
from torchrl.envs import GymEnv
from tensordict.nn import TensorDictModule
import torch.nn as nn

# A simple policy mapping CartPole observations to action values
my_policy = TensorDictModule(
    nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2)),
    in_keys=["observation"],
    out_keys=["action"],
)

collector = Collector(
    GymEnv("CartPole-v1"),
    policy=my_policy,
    frames_per_batch=200,  # controls internal polling frequency
    total_frames=10000,
    trajs_per_batch=4,     # each batch holds 4 complete, padded trajectories
)

for batch in collector:
    # batch.shape == (4, max_traj_len)
    valid = batch[("collector", "mask")]  # (4, max_traj_len) bool
    loss = compute_loss(batch, valid)  # compute_loss: your training objective
    collector.update_policy_weights_()

Replay buffer integration: when a replay_buffer is also provided, complete trajectories are written to the buffer as flat 1-D sequences (no padding) instead of being yielded. This is the recommended pattern for off-policy training with SliceSampler, especially with multi-process collectors where fixed-frame batches can silently mix episodes. See Complete trajectory collection with trajs_per_batch for full details and examples.
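
A minimal sketch of this pattern with a single-process Collector (reusing my_policy from the example above; storage size, slice_len and batch_size are illustrative, not prescriptive):

from torchrl.collectors import Collector
from torchrl.data import LazyTensorStorage, ReplayBuffer, SliceSampler
from torchrl.envs import GymEnv

rb = ReplayBuffer(
    storage=LazyTensorStorage(10_000),
    sampler=SliceSampler(slice_len=32, end_key=("next", "done")),
    batch_size=128,
)

collector = Collector(
    GymEnv("CartPole-v1"),
    policy=my_policy,
    frames_per_batch=200,
    total_frames=10_000,
    trajs_per_batch=4,  # complete episodes are written to rb, unpadded
    replay_buffer=rb,   # trajectories go to the buffer instead of being yielded
)

for _ in collector:      # iterating drives collection; data lands in rb
    batch = rb.sample()  # slices never cross episode boundaries

The Tip below shows the same idea with a multi-process collector running fully in the background via start().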

Note

The following legacy names are also available for backward compatibility:

  • DataCollectorBase → BaseCollector

  • SyncDataCollector → Collector

  • aSyncDataCollector → AsyncCollector

  • _MultiDataCollector → MultiCollector

  • MultiSyncDataCollector → MultiSyncCollector

  • MultiaSyncDataCollector → MultiAsyncCollector
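
When migrating old code, the legacy and current names should be interchangeable. A minimal sanity check, assuming the legacy names are plain re-exports of the new classes (if they are deprecation wrappers instead, this identity check would not hold):

from torchrl.collectors import Collector, SyncDataCollector

# Assumption: the legacy name is a re-export of the new class.
assert SyncDataCollector is Collector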

Using AsyncBatchedCollector

The AsyncBatchedCollector pairs an AsyncEnvPool with an InferenceServer to pipeline environment stepping and batched GPU inference. You only need to supply env factories and a policy – all internal wiring is handled automatically:

from torchrl.collectors import AsyncBatchedCollector
from torchrl.envs import GymEnv
from tensordict.nn import TensorDictModule
import torch.nn as nn

policy = TensorDictModule(
    nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2)),
    in_keys=["observation"],
    out_keys=["action"],
)

collector = AsyncBatchedCollector(
    create_env_fn=[lambda: GymEnv("CartPole-v1")] * 8,
    policy=policy,
    frames_per_batch=200,
    total_frames=10000,
    max_batch_size=8,
)

for data in collector:
    # data is a lazy-stacked TensorDict of collected transitions
    pass

collector.shutdown()

Key advantages over Collector:

  • The inference server automatically batches policy forward passes from all environments, maximising GPU utilisation.

  • Environment stepping and inference run in overlapping fashion, reducing idle time.

  • Supports yield_completed_trajectories=True for episode-level yields, as sketched below.
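
A minimal sketch of the episode-level mode, reusing the policy defined above; the exact shape of each yield is collector-dependent, so treat the loop body as illustrative:

from torchrl.collectors import AsyncBatchedCollector
from torchrl.envs import GymEnv

collector = AsyncBatchedCollector(
    create_env_fn=[lambda: GymEnv("CartPole-v1")] * 8,
    policy=policy,  # the TensorDictModule defined above
    frames_per_batch=200,
    total_frames=10000,
    max_batch_size=8,
    yield_completed_trajectories=True,  # yield episodes as they complete
)

for traj in collector:
    # Each yield corresponds to completed episode(s) rather than a fixed
    # number of frames; summing reward over time gives the episode return.
    episode_return = traj["next", "reward"].sum()

collector.shutdown()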

Using MultiCollector

The MultiCollector class is the recommended way to run parallel data collection. It uses a sync parameter to dispatch to either MultiSyncCollector or MultiAsyncCollector:

from torchrl.collectors import MultiCollector
from torchrl.envs import GymEnv

def make_env():
    return GymEnv("CartPole-v1")

# Synchronous multi-worker collection (recommended for on-policy algorithms)
sync_collector = MultiCollector(
    create_env_fn=[make_env] * 4,  # 4 parallel workers
    policy=my_policy,  # any TensorDictModule policy, e.g. as defined above
    frames_per_batch=1000,
    total_frames=100000,
    sync=True,  # ← all workers complete before a batch is delivered
)

# Asynchronous multi-worker collection (recommended for off-policy algorithms)
async_collector = MultiCollector(
    create_env_fn=[make_env] * 4,
    policy=my_policy,
    frames_per_batch=1000,
    total_frames=100000,
    sync=False,  # ← first-come, first-served delivery
)

# Iterate over collected data
for data in sync_collector:
    # Train on data...
    pass

sync_collector.shutdown()
async_collector.shutdown()

Comparison:

Feature               sync=True                            sync=False
Batch delivery        All workers complete first           First available worker
Policy consistency    All data from same policy version    Data may be from an older policy
Best for              On-policy (PPO, A2C)                 Off-policy (SAC, DQN)
Throughput            Limited by slowest worker            Higher throughput

Running the Collector Asynchronously

Passing a replay buffer to a collector lets you run collection in the background instead of iterating over the collector. To launch a data collector in the background, simply call start():

>>> import time
>>> collector = Collector(..., replay_buffer=rb)  # pass your replay buffer
>>> collector.start()
>>> # give collection a brief head start before training
>>> time.sleep(10)
>>> # Start training
>>> for i in range(optim_steps):
...     data = rb.sample()  # sampling from the replay buffer
...     # rest of the training loop
>>> collector.async_shutdown()  # required for collectors started with start()

Single-process collectors (Collector) run collection on a background thread, so be mindful of Python’s GIL and the usual multithreading restrictions.

Multiprocess collectors, on the other hand, let each child process fill the buffer on its own, which truly decouples data collection from training.

Data collectors that have been started with start() should be shut down using async_shutdown().

Tip

For maximum throughput with trajectory-based training (e.g. recurrent policies, decision transformers), combine start() with trajs_per_batch and a SliceSampler:

from torchrl.collectors import MultiCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer, SliceSampler

rb = ReplayBuffer(
    storage=LazyTensorStorage(100_000),
    sampler=SliceSampler(slice_len=32, end_key=("next", "done")),
    batch_size=256,
    shared=True,  # the buffer is written to by the collector workers
)
collector = MultiCollector(
    [make_env] * 4,
    policy,
    replay_buffer=rb,
    frames_per_batch=200,
    total_frames=-1,  # collect indefinitely until shut down
    trajs_per_batch=8,
    sync=False,
)
collector.start()
for step in range(train_steps):
    batch = rb.sample()  # clean trajectory slices
    # ...
collector.async_shutdown()

Each worker writes only complete trajectories to the buffer, so the sampler never draws slices that cross episode boundaries. See Complete trajectory collection with trajs_per_batch for a full discussion.

Warning

Running a collector asynchronously decouples collection from training, which means that training performance may differ drastically depending on hardware, load and other factors (although it is generally expected to provide significant speed-ups). Make sure you understand how this may affect your algorithm and whether it is a legitimate thing to do! (For example, on-policy algorithms such as PPO should not be run asynchronously unless properly benchmarked.)
