Source code for torchrl.collectors.distributed.ray_eval_worker

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Ray-based asynchronous evaluation worker.

This module provides :class:`RayEvalWorker`, a generic helper that runs an
environment and policy inside a dedicated Ray actor process.  This is useful
when the evaluation environment requires special process-level initialisation
(e.g. Isaac Lab's ``AppLauncher`` must run before ``import torch``) or when
evaluation should happen concurrently with training on a separate GPU.

Typical usage::

    from torchrl.collectors.distributed import RayEvalWorker

    worker = RayEvalWorker(
        init_fn=my_init,          # called first in the actor process
        env_maker=make_eval_env,  # returns a TorchRL env
        policy_maker=make_policy, # returns a TorchRL policy module
        num_gpus=1,
        name="my_eval_worker",    # optional: allows others to connect
    )

    # Non-blocking: submit weights and start a rollout
    weights = TensorDict.from_module(policy).data.detach().cpu()
    worker.submit(weights, max_steps=500)

    # Later -- check if the rollout finished
    result = worker.poll()       # None while still running
    if result is not None:
        print(result["reward"])  # scalar mean episode reward
        print(result["frames"]) # (T, H, W, 3) uint8 tensor or None

    # From another process, connect to the same actor by name:
    worker2 = RayEvalWorker.from_name("my_eval_worker")
"""
from __future__ import annotations

import importlib.util
import logging
from collections.abc import Callable
from typing import Any

import numpy as np

_has_ray = importlib.util.find_spec("ray") is not None

logger = logging.getLogger(__name__)


class RayEvalWorker:
    """Asynchronous evaluation worker backed by a Ray actor.

    The worker creates a **new Python process** (via Ray) and inside it:

    1. Calls *init_fn* -- use this for any process-level setup that must
       happen before other imports (e.g. Isaac Lab ``AppLauncher``).
    2. Creates the environment via *env_maker*.
    3. Creates the policy via *policy_maker(env)*.

    Thereafter, :meth:`submit` sends new policy weights and triggers an
    evaluation rollout. :meth:`poll` returns the result (reward and optional
    video frames) when the rollout finishes, or ``None`` if it is still
    running.

    If a *name* is provided the actor is registered with Ray under that name,
    allowing other processes (or a later session) to reconnect to the same
    running actor via :meth:`from_name`.

    Args:
        init_fn: Optional callable invoked at the very start of the actor
            process, before *env_maker* or *policy_maker*. All imports should
            be **local** inside this callable so that the actor's fresh Python
            process can control import order. Set to ``None`` to skip.
        env_maker: Callable that returns a TorchRL environment. Called once
            inside the actor after *init_fn*. If the underlying environment
            supports ``render_mode="rgb_array"``, the actor will call
            ``render()`` on each evaluation step and return the frames.
        policy_maker: Callable ``(env) -> policy`` that builds the policy
            module given the environment. Called once inside the actor after
            the environment has been created.
        num_gpus: Number of GPUs to request from Ray for this actor.
            Defaults to 1.
        reward_keys: Nested key(s) used to read the reward from the rollout
            tensordict. Defaults to ``("next", "reward")``.
        name: Optional name for the Ray actor. When set, the actor is
            registered under this name and can be retrieved later with
            :meth:`from_name`.
        **remote_kwargs: Extra keyword arguments forwarded to ``ray.remote()``
            when creating the actor class (e.g. ``num_cpus``, ``runtime_env``).
    """

    def __init__(
        self,
        init_fn: Callable[[], None] | None,
        env_maker: Callable[[], Any],
        policy_maker: Callable[[Any], Any],
        *,
        num_gpus: int = 1,
        reward_keys: tuple[str, ...] = ("next", "reward"),
        name: str | None = None,
        **remote_kwargs: Any,
    ) -> None:
        if not _has_ray:
            raise RuntimeError(
                "Ray is required for RayEvalWorker but could not be found. "
                "Install it with: pip install ray"
            )
        import ray

        self._reward_keys = reward_keys

        # Build the remote actor class dynamically so that the caller does not
        # need to depend on Ray at import time.
        actor_cls = ray.remote(num_gpus=num_gpus, **remote_kwargs)(_EvalActor)
        actor_kwargs = {}
        if name is not None:
            actor_kwargs["name"] = name
            actor_kwargs["lifetime"] = "detached"
        self._actor = actor_cls.options(**actor_kwargs).remote(
            init_fn, env_maker, policy_maker
        )
        self._pending_ref: ray.ObjectRef | None = None

    # ------------------------------------------------------------------
    # Alternative constructors
    # ------------------------------------------------------------------
    @classmethod
    def from_name(
        cls,
        name: str,
        *,
        reward_keys: tuple[str, ...] = ("next", "reward"),
    ) -> RayEvalWorker:
        """Connect to an existing named :class:`RayEvalWorker` actor.

        This is useful when one process creates the worker (with a *name*)
        and another process wants to submit evaluations or poll results on
        the same actor.

        Args:
            name: The actor name that was passed to the constructor.
            reward_keys: Nested key(s) used to read the reward from the
                rollout tensordict. Defaults to ``("next", "reward")``.
        """
        if not _has_ray:
            raise RuntimeError(
                "Ray is required for RayEvalWorker but could not be found. "
                "Install it with: pip install ray"
            )
        import ray

        worker = object.__new__(cls)
        worker._reward_keys = reward_keys
        worker._actor = ray.get_actor(name)
        worker._pending_ref = None
        return worker
    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def submit(
        self,
        weights: Any,
        max_steps: int,
        *,
        deterministic: bool = True,
        break_when_any_done: bool = True,
    ) -> None:
        """Start an asynchronous evaluation rollout.

        If a previous rollout is still running, its result is silently
        discarded (fire-and-forget semantics).

        Args:
            weights: Policy weights, typically obtained via
                ``TensorDict.from_module(policy).data.detach().cpu()``.
            max_steps: Maximum number of environment steps per rollout.
            deterministic: If ``True``, use deterministic exploration.
            break_when_any_done: If ``True``, stop the rollout as soon as any
                sub-environment reports ``done``.
        """
        # Discard any previous un-polled result
        self._pending_ref = self._actor.eval.remote(
            weights,
            max_steps,
            self._reward_keys,
            deterministic,
            break_when_any_done,
        )
    def poll(self, timeout: float = 0) -> dict | None:
        """Return the evaluation result if ready, otherwise ``None``.

        The returned dict contains:

        - ``"reward"`` -- scalar mean episode reward.
        - ``"frames"`` -- ``(1, T, C, H, W)`` uint8 CPU tensor of rendered
          frames, or ``None`` if the environment does not render.

        Args:
            timeout: Seconds to wait for the result. ``0`` means non-blocking
                (return immediately if not ready).
        """
        if self._pending_ref is None:
            return None
        import ray

        ready, _ = ray.wait([self._pending_ref], timeout=timeout)
        if not ready:
            return None
        result = ray.get(self._pending_ref)
        self._pending_ref = None
        return result
    def shutdown(self) -> None:
        """Close the environment and kill the actor."""
        import ray

        try:
            ray.get(self._actor.shutdown.remote())
        except Exception:
            logger.warning("RayEvalWorker: error during shutdown", exc_info=True)
        ray.kill(self._actor)
        self._actor = None
        self._pending_ref = None
# ======================================================================
# Inner actor -- runs inside the Ray worker process
# ======================================================================
class _EvalActor:
    """Plain class turned into a Ray actor by :class:`RayEvalWorker`.

    **Why local imports?** Environments like Isaac Lab **require** their
    ``AppLauncher`` to be initialised before ``import torch`` even happens.
    The *init_fn* callback (called first in ``__init__``) takes care of that.
    Every ``import torch`` and ``from torchrl ...`` therefore lives inside a
    method body so that the actor process can control import order via
    *init_fn*. Only stdlib / pure-Python imports (``numpy``, ``logging``,
    etc.) are safe at module level.
    """

    def __init__(
        self,
        init_fn: Callable[[], None] | None,
        env_maker: Callable[[], Any],
        policy_maker: Callable[[Any], Any],
    ) -> None:
        # --- process-level initialisation ---
        # This MUST run before any torch import. For Isaac Lab the init_fn
        # calls AppLauncher which configures the GPU, the Omniverse runtime,
        # and various environment variables that torch and CUDA rely on.
        if init_fn is not None:
            init_fn()

        # --- now safe to import torch / torchrl ---
        # (kept local: see class docstring for rationale)
        import torch  # noqa: F401

        self.env = env_maker()
        self.policy = policy_maker(self.env)
        # Cache device before any to_module call can replace nn.Parameter
        # with plain tensors (which makes .parameters() empty).
        self._device = next(self.policy.parameters()).device

    def eval(
        self,
        weights,
        max_steps: int,
        reward_keys: tuple[str, ...],
        deterministic: bool,
        break_when_any_done: bool,
    ) -> dict:
        """Run an evaluation rollout with the given weights."""
        # Local imports: torch/torchrl must not be imported before init_fn
        # has run (see class docstring -- this is critical for Isaac Lab).
        import torch

        from torchrl.envs.utils import (
            ExplorationType,
            set_exploration_type,
            step_mdp,
        )

        # Load weights into the eval policy (move to policy device first)
        weights.to(self._device).to_module(self.policy)

        frames = []
        total_reward = 0.0
        num_steps = 0
        exploration = (
            ExplorationType.DETERMINISTIC if deterministic else ExplorationType.RANDOM
        )

        with set_exploration_type(exploration), torch.no_grad():
            td = self.env.reset()
            for _i in range(max_steps):
                td = self.policy(td)
                td = self.env.step(td)
                total_reward += td[reward_keys].mean().item()
                num_steps += 1

                frame = self._try_render()
                if frame is not None:
                    frames.append(frame)

                done = td.get(("next", "done"), None)
                if break_when_any_done and done is not None and done.any():
                    break
                td = step_mdp(td)

        mean_reward = total_reward / max(1, num_steps)

        # Format video: (1, T, C, H, W) uint8 CPU tensor
        video = None
        if frames:
            video = torch.stack(frames, dim=0).unsqueeze(0).cpu()

        return {"reward": mean_reward, "frames": video}

    def _try_render(self):
        """Render one frame from the underlying environment.

        Walks the wrapper chain to find a callable ``render()`` method and
        returns the result as a ``(C, H, W)`` uint8 tensor, or ``None`` if
        rendering is unavailable.
        """
        # Local import: torch must not be imported at module level
        # (see class docstring -- this is critical for Isaac Lab).
        import torch

        # Walk through TransformedEnv / wrapper chain to the base env.
        env = self.env
        while hasattr(env, "base_env"):
            env = env.base_env

        render_fn = getattr(env, "render", None)
        # If the base env delegates to a gymnasium env, prefer that.
        if hasattr(env, "_env") and hasattr(env._env, "render"):
            render_fn = env._env.render
        if render_fn is None:
            return None

        raw = render_fn()
        if raw is None:
            return None
        if isinstance(raw, np.ndarray):
            raw = torch.from_numpy(raw.copy())
        # (H, W, C) -> (C, H, W)
        if raw.ndim == 3 and raw.shape[-1] in (3, 4):
            raw = raw[..., :3]
            raw = raw.permute(2, 0, 1)
        return raw.to(torch.uint8)

    def shutdown(self) -> None:
        """Shut down the environment."""
        if hasattr(self, "env") and not self.env.is_closed:
            self.env.close()
