.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/evaluator.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_evaluator.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_evaluator.py:


Using the Evaluator
===================

**Author**: `Vincent Moens <https://github.com/vmoens>`_

.. _evaluator_tuto:

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn

       * How to run synchronous and asynchronous evaluations during training
       * How to pass updated weights to the evaluator
       * How to use the ``on_result`` callback for logging
       * How to run evaluation in a separate process

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites

       * `TorchRL <https://github.com/pytorch/rl>`_ and
         `gymnasium <https://gymnasium.farama.org/>`_ installed
       * Familiarity with :class:`~torchrl.envs.EnvBase` and
         :class:`~torchrl.collectors.Collector`

.. GENERATED FROM PYTHON SOURCE LINES 33-51

In RL training loops, evaluation is often done inline: you stop training, run a
few rollouts, log the metrics, then resume. This blocks the training loop while
rollouts are collected. For environments with expensive step functions (robotics
simulators, LLM generation, etc.), this can waste significant GPU time.

The :class:`~torchrl.collectors.Evaluator` decouples evaluation from training by
running rollouts in the background and letting you poll for metrics or react to
results via a callback.

In this tutorial we will cover:

1. :ref:`Synchronous evaluation <tuto_eval_sync>` — blocking calls
2. :ref:`Asynchronous evaluation <tuto_eval_async>` — fire-and-poll
3. :ref:`Weight updates <tuto_eval_weights>` — passing trained weights
4. :ref:`Process-based evaluation <tuto_eval_process>` — out-of-process
5. :ref:`Logging with callbacks <tuto_eval_logging>` — ``on_result``

.. GENERATED FROM PYTHON SOURCE LINES 51-61

.. code-block:: Python

    from functools import partial

    import torch
    from tensordict import from_module
    from tensordict.nn import TensorDictModule
    from torch import nn
    from torchrl.collectors import Evaluator, RandomPolicy
    from torchrl.envs import GymEnv

.. GENERATED FROM PYTHON SOURCE LINES 62-75

Synchronous evaluation
----------------------

.. _tuto_eval_sync:

The simplest way to use the Evaluator is to call
:meth:`~torchrl.collectors.Evaluator.evaluate`, which blocks until the rollout
completes and returns a metrics dict.

We start by creating an environment factory and a random policy. The Evaluator
accepts either a live environment or a callable that creates one — the callable
form is preferred because it lets the evaluator recreate the environment if
needed.

.. GENERATED FROM PYTHON SOURCE LINES 75-80

.. code-block:: Python

    env_maker = partial(GymEnv, "Pendulum-v1")
    policy = RandomPolicy(env_maker().action_spec)
    evaluator = Evaluator(env_maker, policy, num_trajectories=1)

.. GENERATED FROM PYTHON SOURCE LINES 81-84

Now we can run a blocking evaluation. The returned dict contains prefixed
metrics: reward, episode length, number of episodes, and frames per second.

.. GENERATED FROM PYTHON SOURCE LINES 84-88

.. code-block:: Python

    result = evaluator.evaluate()
    print("First eval:", result)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    First eval: {'eval/reward': -1641.96826171875, 'eval/reward_std': 0.0, 'eval/num_episodes': 1, 'eval/episode_length': 200.0, 'eval/fps': 708.2942105549897, 'eval/step': 0}

.. GENERATED FROM PYTHON SOURCE LINES 89-90

Each subsequent call increments the internal step counter:

.. GENERATED FROM PYTHON SOURCE LINES 90-94

.. code-block:: Python

    result = evaluator.evaluate()
    print("Second eval:", result)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Second eval: {'eval/reward': -1682.063720703125, 'eval/reward_std': 0.0, 'eval/num_episodes': 1, 'eval/episode_length': 200.0, 'eval/fps': 725.2587154674727, 'eval/step': 1}
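Because the result is a plain dict of scalars, it plugs directly into your own
bookkeeping. For instance, you could gate a checkpoint on the evaluation
reward. Here is a minimal sketch, where ``save_checkpoint`` is a hypothetical
user-defined helper rather than a TorchRL API:

.. code-block:: Python

    best_reward = float("-inf")

    result = evaluator.evaluate()
    if result["eval/reward"] > best_reward:
        best_reward = result["eval/reward"]
        # save_checkpoint(policy)  # hypothetical helper: persist your own state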
.. GENERATED FROM PYTHON SOURCE LINES 95-103

Asynchronous evaluation
-----------------------

.. _tuto_eval_async:

For non-blocking evaluation, use
:meth:`~torchrl.collectors.Evaluator.trigger_eval` to start a rollout in the
background, then :meth:`~torchrl.collectors.Evaluator.poll` or
:meth:`~torchrl.collectors.Evaluator.wait` to retrieve the result.

.. GENERATED FROM PYTHON SOURCE LINES 103-110

.. code-block:: Python

    evaluator.trigger_eval()

    # poll() is non-blocking: returns None if the result isn't ready yet
    result = evaluator.poll()
    print("poll() returned:", result)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    poll() returned: None

.. GENERATED FROM PYTHON SOURCE LINES 111-112

To wait for the result, pass a timeout to ``poll()`` or use ``wait()``:

.. GENERATED FROM PYTHON SOURCE LINES 112-116

.. code-block:: Python

    result = evaluator.poll(timeout=30)
    print("poll(timeout=30) returned:", result)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    poll(timeout=30) returned: {'eval/reward': -1457.2855224609375, 'eval/reward_std': 0.0, 'eval/num_episodes': 1, 'eval/episode_length': 200.0, 'eval/fps': 721.3010574578926, 'eval/step': 2}

.. GENERATED FROM PYTHON SOURCE LINES 117-120

By default, calling ``trigger_eval()`` while a previous evaluation is still
pending raises an error. This prevents stale requests from silently piling up:

.. GENERATED FROM PYTHON SOURCE LINES 120-131

.. code-block:: Python

    evaluator.trigger_eval()
    try:
        evaluator.trigger_eval()
    except RuntimeError as e:
        print(f"Errored with: {e}")

    # Clean up
    evaluator.wait(timeout=30)
    evaluator.shutdown()

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Errored with: Evaluation already pending. Wait for completion or set busy_policy='queue'.

.. GENERATED FROM PYTHON SOURCE LINES 132-146

If you prefer to enqueue evaluations instead, pass ``busy_policy="queue"`` when
creating the evaluator, as in the sketch below.
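Putting the pieces together, a training loop typically fires an evaluation
every N updates and polls once per iteration. The following is a minimal
sketch of this fire-and-poll pattern; the interval and loop length are
arbitrary placeholders, and a real loop would run an optimizer step where
indicated:

.. code-block:: Python

    evaluator_q = Evaluator(env_maker, policy, num_trajectories=1, busy_policy="queue")

    for step in range(600):
        # ... an optimizer update would go here ...
        if step % 200 == 0:
            evaluator_q.trigger_eval()  # fire-and-forget; queued if one is pending
        result = evaluator_q.poll()  # non-blocking: None until a rollout finishes
        if result is not None:
            print(f"step {step}: eval reward {result['eval/reward']:.1f}")

    # drain any evaluation still in flight before shutting down
    evaluator_q.wait(timeout=30)
    evaluator_q.shutdown()

If every queued result has already been polled by the end of the loop, the
final ``wait()`` may be unnecessary; it is kept here to make shutdown
deterministic.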
Weight updates
--------------

.. _tuto_eval_weights:

In a real training loop, you want to evaluate the *latest* trained weights,
not the initial ones. The :meth:`~torchrl.collectors.Evaluator.evaluate` and
:meth:`~torchrl.collectors.Evaluator.trigger_eval` methods accept a
``weights`` argument — either an ``nn.Module`` or a ``TensorDictBase``.

Let's create a simple MLP policy and an evaluator for it:

.. GENERATED FROM PYTHON SOURCE LINES 146-157

.. code-block:: Python

    env = env_maker()
    net = nn.Sequential(
        nn.Linear(env.observation_spec["observation"].shape[-1], 64),
        nn.Tanh(),
        nn.Linear(64, env.action_spec.shape[-1]),
    )
    real_policy = TensorDictModule(net, in_keys=["observation"], out_keys=["action"])

    evaluator_w = Evaluator(env_maker, real_policy, num_trajectories=1)

.. GENERATED FROM PYTHON SOURCE LINES 158-159

Evaluate with the initial (random) weights:

.. GENERATED FROM PYTHON SOURCE LINES 159-162

.. code-block:: Python

    print("Before weight update:", evaluator_w.evaluate())

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Before weight update: {'eval/reward': -1413.3218994140625, 'eval/reward_std': 0.0, 'eval/num_episodes': 1, 'eval/episode_length': 200.0, 'eval/fps': 661.9042895729661, 'eval/step': 0}

.. GENERATED FROM PYTHON SOURCE LINES 163-164

Simulate a "training step" by perturbing the weights:

.. GENERATED FROM PYTHON SOURCE LINES 164-169

.. code-block:: Python

    with torch.no_grad():
        for p in net.parameters():
            p.add_(torch.randn_like(p) * 0.1)

.. GENERATED FROM PYTHON SOURCE LINES 170-173

Now evaluate with the updated weights. You can pass the module directly — the
evaluator extracts and transfers the weights automatically:

.. GENERATED FROM PYTHON SOURCE LINES 173-176

.. code-block:: Python

    print("After weight update:", evaluator_w.evaluate(weights=real_policy))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    After weight update: {'eval/reward': -1208.4190673828125, 'eval/reward_std': 0.0, 'eval/num_episodes': 1, 'eval/episode_length': 200.0, 'eval/fps': 675.5085238299725, 'eval/step': 1}

.. GENERATED FROM PYTHON SOURCE LINES 177-179

You can also pass a ``TensorDictBase`` of weights, which is useful when you
already have detached weight snapshots:

.. GENERATED FROM PYTHON SOURCE LINES 179-184

.. code-block:: Python

    real_weights = from_module(real_policy)
    print("With TensorDict weights:", evaluator_w.evaluate(weights=real_weights))

    evaluator_w.shutdown()

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    With TensorDict weights: {'eval/reward': -1890.05029296875, 'eval/reward_std': 0.0, 'eval/num_episodes': 1, 'eval/episode_length': 200.0, 'eval/fps': 653.3084171956334, 'eval/step': 2}

.. GENERATED FROM PYTHON SOURCE LINES 185-197

Process-based evaluation
------------------------

.. _tuto_eval_process:

For full isolation (e.g. to place evaluation on a dedicated GPU or to avoid
GIL contention), use ``backend="process"``. This runs the environment and
policy inside a child process via
:class:`~torchrl.collectors.MultiSyncCollector`.

The process backend requires callable factories for both the environment and
the policy:

.. GENERATED FROM PYTHON SOURCE LINES 197-213

.. code-block:: Python

    env_maker = partial(GymEnv, "Pendulum-v1")
    action_spec = env_maker().action_spec
    policy_factory = partial(RandomPolicy, action_spec)

    evaluator_proc = Evaluator(
        env_maker,
        policy_factory=policy_factory,
        num_trajectories=1,
        backend="process",
    )

    result = evaluator_proc.evaluate()
    print("Process backend:", result)

    evaluator_proc.shutdown()

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Process backend: {'eval/reward': -1453.0191650390625, 'eval/reward_std': 0.0, 'eval/num_episodes': 1, 'eval/episode_length': 200.0, 'eval/fps': 505.6974491606256, 'eval/step': 0}
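The ``weights`` mechanism from the previous section is how trained parameters
would reach a policy living in a child process. A minimal sketch with a
learned ``TensorDictModule`` policy follows; it assumes the process backend
accepts the same ``weights`` argument as the default in-process backend, that
the factory is a picklable top-level callable, and that the factory's
architecture matches the trainer-side module (here, Pendulum's 3-dim
observation and 1-dim action):

.. code-block:: Python

    def make_mlp_policy():
        # runs inside the child process; must mirror real_policy's architecture
        mlp = nn.Sequential(nn.Linear(3, 64), nn.Tanh(), nn.Linear(64, 1))
        return TensorDictModule(mlp, in_keys=["observation"], out_keys=["action"])

    evaluator_pw = Evaluator(
        env_maker,
        policy_factory=make_mlp_policy,
        num_trajectories=1,
        backend="process",
    )
    # ship the trainer-side weights to the child process for this rollout
    print(evaluator_pw.evaluate(weights=real_policy))
    evaluator_pw.shutdown()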
.. GENERATED FROM PYTHON SOURCE LINES 214-225

Logging with callbacks
----------------------

.. _tuto_eval_logging:

Rather than manually logging after each ``poll()`` or ``wait()``, you can pass
an ``on_result`` callback to the evaluator. It receives a flat
:class:`~tensordict.TensorDictBase` with the same prefixed metric names.

Here we use TorchRL's :class:`~torchrl.record.loggers.csv.CSVLogger` to
automatically log every evaluation result to a CSV file:

.. GENERATED FROM PYTHON SOURCE LINES 225-243

.. code-block:: Python

    import tempfile

    from torchrl.record.loggers.csv import CSVLogger

    log_dir = tempfile.mkdtemp()
    logger = CSVLogger(exp_name="eval_demo", log_dir=log_dir)

    evaluator_log = Evaluator(
        env_maker,
        real_policy,
        num_trajectories=1,
        on_result=lambda result: logger.log_metrics(
            {k: v.item() for k, v in result.items() if k != "eval/step"},
            step=result["eval/step"].item(),
        ),
    )

.. GENERATED FROM PYTHON SOURCE LINES 244-245

Run a few evals. Each one automatically logs to CSV via the callback:

.. GENERATED FROM PYTHON SOURCE LINES 245-251

.. code-block:: Python

    for _ in range(3):
        evaluator_log.evaluate(weights=real_policy)

    evaluator_log.shutdown()

.. GENERATED FROM PYTHON SOURCE LINES 252-253

Let's verify what was logged:

.. GENERATED FROM PYTHON SOURCE LINES 253-260

.. code-block:: Python

    from pathlib import Path

    csv_path = next(Path(log_dir).rglob("*.csv"))
    print(f"Logged to: {csv_path}")
    print(csv_path.read_text())

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Logged to: /tmp/tmprby2c4z0/eval_demo/scalars/eval/reward.csv
    0,-1873.7410888671875
    1,-1070.4537353515625
    2,-1615.29541015625

.. GENERATED FROM PYTHON SOURCE LINES 261-289

The ``on_result`` callback works with both synchronous and asynchronous
evaluation. For async usage, the callback runs on the evaluator's background
thread — if your callback writes to a shared logger, handle any required
locking inside the callback.

Summary
-------

The :class:`~torchrl.collectors.Evaluator` provides a single, composable entry
point for evaluation:

* **Synchronous**: :meth:`~torchrl.collectors.Evaluator.evaluate` for blocking
  rollouts.
* **Asynchronous**: :meth:`~torchrl.collectors.Evaluator.trigger_eval` +
  :meth:`~torchrl.collectors.Evaluator.poll` /
  :meth:`~torchrl.collectors.Evaluator.wait` for background rollouts.
* **Weight sync**: pass ``weights`` (module or tensordict) to evaluate the
  latest trained parameters.
* **Process isolation**: ``backend="process"`` for dedicated-device eval.
* **Callbacks**: ``on_result`` for automatic logging or checkpointing.

Useful next resources
~~~~~~~~~~~~~~~~~~~~~

* The :class:`~torchrl.collectors.Evaluator` API reference — full parameter
  docs.
* The collector trajectory tutorial — a deep dive into how collectors assemble
  data.
* `TorchRL documentation <https://pytorch.org/rl/>`_

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 6.320 seconds)


.. _sphx_glr_download_tutorials_evaluator.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: evaluator.ipynb <evaluator.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: evaluator.py <evaluator.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: evaluator.zip <evaluator.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_