StalenessAwareSampler
- class torchrl.data.replay_buffers.StalenessAwareSampler(max_staleness: int = -1, staleness_weight_fn: callable | None = None, version_key: NestedKey = 'policy_version')
A sampler that weights entries by freshness and filters stale entries.
This sampler is designed for asynchronous training setups (e.g., async PPO) where collected data may come from older policy versions. It reads a policy_version field from the storage and uses it to:

- Hard gate: Exclude entries whose staleness exceeds max_staleness.
- Freshness weighting: Sample proportionally to a weight that decays with staleness (default: 1 / (staleness + 1)).
The training loop is responsible for updating consumer_version (typically after each optimizer step or weight update) so the sampler can compute staleness = consumer_version - policy_version (see the sketch below the parameter list).

- Parameters:
  - max_staleness (int, optional) – Hard cutoff. Entries with staleness > max_staleness are excluded from sampling. -1 (default) means no cutoff.
  - staleness_weight_fn (callable, optional) – A callable that maps a staleness tensor (int) to a weight tensor (float). Defaults to lambda s: 1.0 / (s.float() + 1.0).
  - version_key (NestedKey, optional) – The key in the storage holding the policy version. Defaults to "policy_version".
Examples
>>> from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage
>>> from torchrl.data.replay_buffers.samplers import StalenessAwareSampler
>>> sampler = StalenessAwareSampler(max_staleness=5)
>>> buffer = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(1000),
...     sampler=sampler,
...     batch_size=32,
... )
>>> # In training loop:
>>> # sampler.consumer_version = current_training_step
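The default hyperbolic decay can be swapped for any callable mapping a staleness tensor to weights. For instance, an exponential decay that suppresses stale entries more aggressively (the 0.8 base, a half-life of roughly three versions, is an arbitrary illustrative choice):

>>> sampler = StalenessAwareSampler(
...     max_staleness=10,
...     staleness_weight_fn=lambda s: 0.8 ** s.float(),
... )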
Integration with Collector and PolicyVersion:

from torchrl.collectors import Collector
from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage
from torchrl.envs.transforms import PolicyVersion

sampler = StalenessAwareSampler(max_staleness=10)
buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(10_000),
    sampler=sampler,
    batch_size=256,
)

# Create the version-stamping transform before the collector that uses it
policy_version = PolicyVersion()
collector = Collector(
    env,  # env and policy are assumed to be defined elsewhere
    policy,
    frames_per_batch=1000,
    total_frames=100_000,
    env_transforms=[policy_version],
)

for step, data in enumerate(collector):
    buffer.extend(data)
    sampler.consumer_version = step
    batch = buffer.sample()
    # ... train on batch ...
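Note that staleness is measured in whatever unit consumer_version advances by: incrementing it once per collection batch (as above) measures staleness in batches, while incrementing it per optimizer step measures it in weight updates. Either works, as long as the versions stamped by PolicyVersion advance in the same unit, so that the difference consumer_version - policy_version is meaningful.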
Note
StalenessAwareSampler intentionally does not inherit from PrioritizedSampler. PrioritizedSampler maintains a segment tree over per-transition TD-error priorities that are updated after each training step. Staleness weighting is fundamentally different: weights are derived from a single scalar (consumer_version) and per-entry policy_version stamps, and are recomputed on every sample() call rather than maintained incrementally. Sharing the segment-tree machinery would add complexity without benefit.

- property consumer_version: int
The current training iteration / consumer version.
- property max_staleness: int
The maximum allowed staleness. -1 means no limit.