Evaluator¶
- class torchrl.collectors.Evaluator(env: EnvBase | Callable[[], EnvBase], policy: TensorDictModuleBase | Callable | None = None, *, policy_factory: Callable[[...], Callable] | None = None, num_trajectories: int = 10, max_steps: int | None = None, frames_per_batch: int | None = None, collector_cls: type | str | None = None, collector_kwargs: dict | None = None, weight_sync_schemes: dict[str, Any] | None = None, log_prefix: str = 'eval', reward_keys: str | tuple[str, ...] = ('next', 'reward'), done_keys: str | tuple[str, ...] = ('next', 'done'), device: device | str | None = None, exploration_type: InteractionType = InteractionType.DETERMINISTIC, metrics_fn: Callable[[TensorDictBase], dict[str, float]] | None = None, dump_video: bool = True, on_result: Callable[[TensorDictBase], None] | None = None, busy_policy: str = 'error', backend: str = 'thread', init_fn: Callable[[], None] | None = None, num_gpus: int = 1, ray_kwargs: dict | None = None)[source]¶
Unified sync / async evaluator with a pluggable backend.
The evaluator wraps an environment and a policy and provides two modes of operation:
- Synchronous – call evaluate() to run a blocking evaluation and get metrics back immediately.
- Asynchronous – call trigger_eval() to kick off an evaluation in the background, then poll() (non-blocking) or wait() (blocking) to retrieve the result. Use the pending property to check whether an evaluation is currently in progress. Results can also be consumed via an on_result callback. (See the sketch below.)
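For orientation, a minimal sketch of both modes; make_env and policy here are placeholders for a user-supplied environment constructor and policy:

    from torchrl.collectors import Evaluator

    evaluator = Evaluator(make_env(), policy=policy, num_trajectories=5)

    # Synchronous: blocks until all trajectories are collected.
    metrics = evaluator.evaluate(step=0)
    print(metrics["eval/reward"], metrics["eval/episode_length"])

    # Asynchronous: trigger in the background, retrieve later.
    evaluator.trigger_eval(step=1)
    result = evaluator.wait()  # or evaluator.poll() to check without blocking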
Internally, a Collector is used with trajs_per_batch=num_trajectories to collect complete episodes. The collector pre-allocates buffers and writes in-place — O(1) GPU allocations vs O(n) per step — yielding significant speedups for batched eval environments.

Three backends are available:

- "thread" (default) – runs in a daemon thread. Low overhead, well suited for GPU-bound evaluation where the GIL is released by CUDA ops. When env is a callable and policy_factory is provided, both are created lazily inside the worker thread, which is useful for dedicated eval devices.
- "process" – runs in a child process (spawn context). The env and policy are always created inside the child process, giving full CUDA context isolation and avoiding the GIL entirely. Requires env to be a callable and policy_factory to be provided.
- "ray" – runs in a Ray actor, suitable for distributed setups. Requires env to be a callable and policy_factory to be provided.
Backpressure / overlap policy: calling trigger_eval() while a previous evaluation is still running either raises immediately (busy_policy="error", the default) or queues the new request (busy_policy="queue"). Use pending to conditionally skip trigger calls:

    if not evaluator.pending:
        evaluator.trigger_eval(weights, step=step)
Callback thread-safety: when on_result is provided, it is invoked from the evaluator’s async coordination thread after the rollout completes. If the callback writes to a logger, the callback is responsible for any locking it needs.
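For instance, a sketch of a callback that guards a shared logger with its own lock (the lock and print-based logging are illustrative, not part of the API):

    import threading

    log_lock = threading.Lock()

    def on_result(metrics_td):
        # Invoked from the evaluator's coordination thread, so we
        # serialize access to shared logging state ourselves.
        with log_lock:
            for key, value in metrics_td.items():
                print(f"{key}: {float(value):.3f}")

    evaluator = Evaluator(make_env(), policy=policy, on_result=on_result)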
Dedicated eval device (multi-GPU example):

    evaluator = Evaluator(
        lambda: make_env(device="cuda:7"),
        policy_factory=lambda env: make_policy(env).to("cuda:7"),
        max_steps=1000,
        backend="process",  # or "thread"
    )
Batched eval environments: for best results, add a RewardSum transform to the eval env so that per-episode returns are tracked. Without it, the evaluator falls back to summing raw rewards over each trajectory.
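A minimal sketch, assuming make_env and policy are defined elsewhere:

    from torchrl.envs import TransformedEnv
    from torchrl.envs.transforms import RewardSum

    # RewardSum accumulates per-episode returns under "episode_reward",
    # which the evaluator can read instead of re-summing raw rewards.
    eval_env = TransformedEnv(make_env(), RewardSum())
    evaluator = Evaluator(eval_env, policy=policy)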
- Parameters:
env – An EnvBase instance or a callable that returns one. For the "process" and "ray" backends the callable form is required. For the "thread" backend, when combined with policy_factory, passing a callable defers construction to the worker thread.
policy – The evaluation policy. Mutually exclusive with policy_factory.
- Keyword Arguments:
policy_factory – A callable (env) -> policy used to build the policy. Required for the "process" and "ray" backends. For "thread", if both env (callable) and policy_factory are provided, construction is deferred to the worker thread.
num_trajectories (int) – Number of complete episodes to collect per evaluation round. A Collector is used internally with trajs_per_batch=num_trajectories. Default: 10.
max_steps (int or None) – Maximum environment steps per episode, passed as max_frames_per_traj to the internal collector. When None, episodes run until done with no step limit. Default: None.
frames_per_batch (int or None) – Internal collection batch size (env steps per collector iteration). If None, defaults to max_steps. This is purely internal — output granularity is controlled by num_trajectories.
collector_cls – Which collector class to use. Accepts a class or a string name resolved from torchrl.collectors (e.g. "Collector"). Default: None (uses Collector).
collector_kwargs (dict or None) – Extra keyword arguments forwarded to the collector constructor.
log_prefix (str) – Prefix prepended to all logged metric names. Default: "eval".
reward_keys – Nested key(s) for reading the reward from the tensordict. Default: ("next", "reward").
done_keys – Nested key(s) for reading the done flag. Default: ("next", "done").
device – Device for the evaluation policy. If None, inferred from the policy parameters.
exploration_type – Exploration mode during evaluation. Default: ExplorationType.DETERMINISTIC.
metrics_fn – Optional (TensorDictBase) -> dict[str, float] called on every trajectory batch to extract custom metrics.
dump_video (bool) – Call dump() on VideoRecorder transforms after each evaluation (thread backend only). Default: True.
on_result – Optional (TensorDictBase) -> None invoked after each completed evaluation. The callback receives a flat tensordict with the same prefixed metric names returned by evaluate(), poll(), and wait().
busy_policy (str) – Behaviour when trigger_eval() is called while another async evaluation is still pending. "error" raises immediately (default; recommended). "queue" enqueues the new request and runs it when the current evaluation finishes.

Warning: with busy_policy="queue", each queued request stores a copy of the weights dict. For large models this can consume significant memory. Prefer checking pending and skipping triggers instead.
weight_sync_schemes (dict or None) – A dict mapping model IDs to WeightSyncScheme instances. When provided, a MultiSyncCollector with a single worker is used for process-level CUDA isolation and scheme-based weight transfer. Model IDs follow the collector convention: "policy" for the main policy, "env.transform[0]" for env transforms, etc. Example:

    from torchrl.weight_update import MultiProcessedWeightSyncScheme

    evaluator = Evaluator(
        env=make_eval_env,
        policy_factory=make_eval_policy,
        weight_sync_schemes={
            "policy": MultiProcessedWeightSyncScheme(),
            "env.transform[0]": MultiProcessedWeightSyncScheme(),
        },
        max_steps=1000,
    )
backend (str) – "thread" (default), "process", or "ray". The "process" backend is implemented as a thread backend with a MultiSyncCollector (1 worker) running in a child process. This provides full CUDA context isolation without custom queue management.
init_fn – Callable invoked at the start of the worker / actor process, before any evaluation work (and, ideally, before any torch import inside that process). Used by both the "process" and "ray" backends. Typical use case: start Isaac Lab's AppLauncher in headless mode. Ignored by the "thread" backend because no new process is spawned. See the sketch below.
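A hedged sketch (the HEADLESS flag is illustrative, not an Evaluator convention; make_eval_env and make_eval_policy are placeholders):

    import os

    def init_fn():
        # Runs in the child process/actor before evaluation starts,
        # so process-wide setup lands before torch/CUDA is touched.
        os.environ["HEADLESS"] = "1"

    evaluator = Evaluator(
        make_eval_env,                    # callable: built in the child process
        policy_factory=make_eval_policy,
        backend="process",
        init_fn=init_fn,
    )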
num_gpus (int) – (Ray only) GPUs requested for the actor. Default: 1.
ray_kwargs (dict) – (Ray only) Extra keyword arguments forwarded to ray.remote().
- evaluate(weights: TensorDictBase | Module | None = None, step: int | None = None, *, weights_dict: dict[str, TensorDictBase | Module] | None = None) → dict[str, Any][source]¶
Run a blocking evaluation rollout.
- Parameters:
weights – Policy weights to load before the rollout. Accepts a TensorDictBase (e.g. from TensorDict.from_module(policy).data) or an nn.Module (weights are extracted automatically). If None, the current policy weights are used.
step – Logging step. If None, an internal counter is used.
- Keyword Arguments:
weights_dict – A dict mapping model_id strings to weight sources (nn.Module or TensorDictBase). Use this to sync multiple models (e.g. policy + env transforms). When provided, weights is treated as weights_dict["policy"] if "policy" is not already in the dict.
- Returns:
A dict with at least the "<prefix>/reward" and "<prefix>/episode_length" keys.
- static extract_weights(policy: Module) → TensorDictBase[source]¶
Extract detached, cloned, CPU weights from a policy.
This is a convenience helper; the returned TensorDict is safe to pass across threads.
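For example (train_policy and global_step are placeholders):

    # Detached, cloned, CPU copy: safe to hand to the async worker.
    weights = Evaluator.extract_weights(train_policy)
    evaluator.trigger_eval(weights, step=global_step)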
- property pending: bool¶
Return True if an async evaluation is currently in progress.

This can be used to avoid triggering overlapping evaluations:

    if not evaluator.pending:
        evaluator.trigger_eval(weights, step=step)
- poll(timeout: float = 0) → dict[str, Any] | None[source]¶
Return the latest evaluation result if ready, else None.

- Parameters:
timeout – Seconds to wait. 0 means non-blocking.
- trigger_eval(weights: TensorDictBase | Module | None = None, step: int | None = None, *, weights_dict: dict[str, TensorDictBase | Module] | None = None) → None[source]¶
Start an async evaluation.
- Parameters:
weights – Policy weights to load. See evaluate().
step – Logging step. See evaluate().
weights_dict – Multi-model weights dict. See evaluate().
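A typical training-loop pattern combining trigger_eval(), pending, and poll(); train_one_step, total_steps, eval_interval, and policy are placeholders for user code:

    for step in range(total_steps):
        train_one_step()  # user-supplied optimization step

        # Trigger periodically, skipping if an eval is still running.
        if step % eval_interval == 0 and not evaluator.pending:
            evaluator.trigger_eval(Evaluator.extract_weights(policy), step=step)

        # Drain any finished result without blocking the loop.
        result = evaluator.poll()
        if result is not None:
            print(result["eval/reward"])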