Source code for torchrl.collectors._async_batched

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import queue
import threading
from collections import deque, OrderedDict
from collections.abc import Callable, Iterator, Sequence
from typing import Literal

import torch
from tensordict import lazy_stack, TensorDictBase

from torchrl._utils import logger as torchrl_logger
from torchrl.collectors._base import BaseCollector
from torchrl.envs import AsyncEnvPool, EnvBase
from torchrl.modules.inference_server import InferenceServer, ThreadingTransport
from torchrl.modules.inference_server._transport import InferenceTransport

_ENV_IDX_KEY = "env_index"

_POLICY_BACKENDS = ("threading", "multiprocessing", "ray", "monarch")
_ENV_BACKENDS = ("threading", "multiprocessing")


def _make_transport(
    policy_backend: str, num_slots: int | None = None
) -> InferenceTransport:
    """Create an :class:`InferenceTransport` from a backend name.

    Args:
        policy_backend: one of ``"threading"``, ``"multiprocessing"``,
            ``"ray"``, or ``"monarch"``.
        num_slots: when set and ``policy_backend="threading"``, a
            :class:`~torchrl.modules.SlotTransport` is created instead of
            the generic :class:`~torchrl.modules.ThreadingTransport`.
    """
    if policy_backend == "threading":
        if num_slots is not None:
            from torchrl.modules.inference_server._slot import SlotTransport

            return SlotTransport(num_slots)
        return ThreadingTransport()
    if policy_backend == "multiprocessing":
        from torchrl.modules.inference_server._mp import MPTransport

        return MPTransport()
    if policy_backend == "ray":
        from torchrl.modules.inference_server._ray import RayTransport

        return RayTransport()
    if policy_backend == "monarch":
        from torchrl.modules.inference_server._monarch import MonarchTransport

        return MonarchTransport()
    raise ValueError(
        f"Unknown policy_backend {policy_backend!r}. "
        f"Expected one of {_POLICY_BACKENDS}."
    )
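

# A minimal usage sketch (illustrative, not part of the module): resolving a
# transport from a backend name. ``num_slots=4`` is an arbitrary example value;
# the collector passes its number of environments here. ``"gloo"`` is just an
# example of an unsupported name.
#
#   >>> _make_transport("threading", num_slots=4)  # -> SlotTransport(4)
#   >>> _make_transport("threading")               # -> ThreadingTransport()
#   >>> _make_transport("gloo")                    # raises ValueError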


def _env_loop(
    pool: AsyncEnvPool,
    env_id: int,
    transport: InferenceTransport,
    result_queue: queue.Queue,
    shutdown_event: threading.Event,
):
    """Per-env worker thread using pool slot for env execution and InferenceServer for policy.

    Each thread owns one slot in the :class:`~torchrl.envs.AsyncEnvPool` and
    one inference client.  The pool handles the actual environment execution in
    whatever backend it was configured with (threading, multiprocessing, etc.),
    while this thread coordinates the send/recv cycle and inference submission.

        reset -> infer (blocking) -> step_send -> step_recv -> put transition -> infer -> ...
    """
    client = transport.client()

    try:
        pool.async_reset_send(env_index=env_id)
        obs = pool.async_reset_recv(env_index=env_id)
        action_td = client(obs)

        while not shutdown_event.is_set():
            pool.async_step_and_maybe_reset_send(action_td, env_index=env_id)
            cur_td, next_obs = pool.async_step_and_maybe_reset_recv(env_index=env_id)
            cur_td.set(_ENV_IDX_KEY, env_id)
            result_queue.put(cur_td)
            if shutdown_event.is_set():
                break
            action_td = client(next_obs)
    except Exception as exc:
        if not shutdown_event.is_set():
            result_queue.put(exc)
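

# Illustrative sketch (not part of the module): wiring a single coordinator
# thread by hand, mirroring what ``AsyncBatchedCollector._ensure_started``
# does below. ``pool`` and ``transport`` are assumed to be pre-built
# AsyncEnvPool and InferenceTransport instances.
#
#   >>> rq, stop = queue.Queue(), threading.Event()
#   >>> t = threading.Thread(
#   ...     target=_env_loop,
#   ...     kwargs={"pool": pool, "env_id": 0, "transport": transport,
#   ...             "result_queue": rq, "shutdown_event": stop},
#   ...     daemon=True,
#   ... )
#   >>> t.start()
#   >>> item = rq.get()  # a transition TensorDict, or an exception to re-raise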


class AsyncBatchedCollector(BaseCollector):
    """Asynchronous collector that pairs per-env threads with an
    :class:`~torchrl.envs.AsyncEnvPool` and an :class:`~torchrl.modules.InferenceServer`.

    Unlike :class:`~torchrl.collectors.Collector`, this collector fully
    decouples environment stepping from policy inference:

    * An :class:`~torchrl.envs.AsyncEnvPool` runs *N* environments using
      whatever backend the user chooses (``"threading"``,
      ``"multiprocessing"``).
    * *N* lightweight coordinator threads -- one per environment -- each own a
      slot in the pool and an inference client. A thread sends its env's
      observation to the :class:`~torchrl.modules.InferenceServer`, blocks
      until the batched action is returned, then sends the action back to the
      pool for stepping.
    * The :class:`~torchrl.modules.InferenceServer`, running in a background
      thread, continuously drains observation submissions, batches them, runs
      a single forward pass, and fans actions back out.

    There is **no global synchronisation barrier**: fast environments keep
    stepping while slow ones wait for inference, and the server always
    processes whatever observations have accumulated. The user simply provides
    env factories and a policy; the collector handles all wiring internally.

    Args:
        create_env_fn (list[Callable[[], EnvBase]]): a list of callables, each
            returning an :class:`~torchrl.envs.EnvBase` instance. The list
            length determines the number of parallel environments.

    Keyword Args:
        policy (nn.Module or Callable, optional): the policy module. Mutually
            exclusive with ``policy_factory``.
        policy_factory (Callable[[], Callable], optional): a zero-argument
            callable that returns the policy. Useful when the policy cannot be
            pickled. Mutually exclusive with ``policy``.
        frames_per_batch (int): number of environment frames to collect per
            batch. Required.
        total_frames (int, optional): total number of frames the collector
            should return during its lifespan. ``-1`` means endless. Defaults
            to ``-1``.
        max_batch_size (int, optional): upper bound on the number of requests
            the inference server processes in a single forward pass. Defaults
            to ``64``.
        min_batch_size (int, optional): minimum number of requests the
            inference server accumulates before dispatching a batch. After the
            first request arrives, the server keeps draining for up to
            ``server_timeout`` seconds until this many items are collected.
            ``1`` (default) dispatches immediately.
        server_timeout (float, optional): seconds the server waits for work
            before dispatching a partial batch. Defaults to ``0.01``.
        transport (InferenceTransport, optional): a pre-built transport
            object. When provided, it takes precedence over
            ``policy_backend``. When ``None`` (default), a transport is
            created automatically from the resolved ``policy_backend``.
        device (torch.device or str, optional): device for policy inference.
            Passed to the inference server. Defaults to ``None``.
        backend (str, optional): global default backend for both environments
            and policy inference. The more specific ``env_backend`` and
            ``policy_backend`` arguments take precedence when set. One of
            ``"threading"``, ``"multiprocessing"``, ``"ray"``, or
            ``"monarch"``. Defaults to ``"threading"``.
        env_backend (str, optional): backend for the
            :class:`~torchrl.envs.AsyncEnvPool` that runs environments. One of
            ``"threading"`` or ``"multiprocessing"``. Falls back to
            ``backend`` when ``None``. The coordinator threads are always
            Python threads regardless of this setting. Defaults to ``None``.
        policy_backend (str, optional): backend for the inference transport
            used to communicate with the
            :class:`~torchrl.modules.InferenceServer`. One of ``"threading"``,
            ``"multiprocessing"``, ``"ray"``, or ``"monarch"``. Falls back to
            ``backend`` when ``None``. Defaults to ``None``.
        reset_at_each_iter (bool, optional): whether to reset all envs at the
            start of every collection batch. Defaults to ``False``.
        postproc (Callable, optional): post-processing transform applied to
            each collected batch before yielding. Defaults to ``None``.
        yield_completed_trajectories (bool, optional): if ``True``, the
            collector yields individual completed trajectories as they finish
            rather than fixed-size batches. ``frames_per_batch`` acts as the
            *minimum* number of frames to accumulate before yielding. Defaults
            to ``False``.
        weight_sync (WeightSyncScheme, optional): an optional
            :class:`~torchrl.weight_update.WeightSyncScheme` forwarded to the
            inference server for receiving weight updates.
        weight_sync_model_id (str, optional): model id for weight sync.
            Defaults to ``"policy"``.
        verbose (bool, optional): if ``True``, log progress messages. Defaults
            to ``False``.
        create_env_kwargs (dict or list[dict], optional): keyword arguments
            forwarded to each environment factory. A single dict is broadcast
            to all factories.

    Examples:
        >>> from torchrl.collectors import AsyncBatchedCollector
        >>> from torchrl.envs import GymEnv
        >>> from tensordict.nn import TensorDictModule
        >>> import torch.nn as nn
        >>> policy = TensorDictModule(
        ...     nn.Linear(4, 2), in_keys=["observation"], out_keys=["action"]
        ... )
        >>> collector = AsyncBatchedCollector(
        ...     create_env_fn=[lambda: GymEnv("CartPole-v1")] * 4,
        ...     policy=policy,
        ...     frames_per_batch=200,
        ...     total_frames=1000,
        ... )
        >>> for batch in collector:
        ...     print(batch.shape)
        ...     break
        >>> collector.shutdown()
    """

    def __init__(
        self,
        create_env_fn: list[Callable[[], EnvBase]],
        *,
        policy: Callable | None = None,
        policy_factory: Callable[[], Callable] | None = None,
        frames_per_batch: int,
        total_frames: int = -1,
        max_batch_size: int = 64,
        min_batch_size: int = 1,
        server_timeout: float = 0.01,
        transport: InferenceTransport | None = None,
        device: torch.device | str | None = None,
        backend: Literal[
            "threading", "multiprocessing", "ray", "monarch"
        ] = "threading",
        env_backend: Literal["threading", "multiprocessing"] | None = None,
        policy_backend: Literal["threading", "multiprocessing", "ray", "monarch"]
        | None = None,
        reset_at_each_iter: bool = False,
        postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
        yield_completed_trajectories: bool = False,
        weight_sync=None,
        weight_sync_model_id: str = "policy",
        verbose: bool = False,
        create_env_kwargs: dict | list[dict] | None = None,
    ):
        if policy is not None and policy_factory is not None:
            raise TypeError("policy and policy_factory are mutually exclusive.")
        if policy is None and policy_factory is None:
            raise TypeError("One of policy or policy_factory must be provided.")

        # ---- resolve policy ---------------------------------------------------
        if policy_factory is not None:
            policy = policy_factory()
        self._policy = policy

        # ---- env config -------------------------------------------------------
        if not isinstance(create_env_fn, Sequence):
            raise TypeError("create_env_fn must be a list of env factories.")
        self._create_env_fn = list(create_env_fn)
        self._num_envs = len(create_env_fn)
        self._create_env_kwargs = create_env_kwargs

        # ---- resolve backends -------------------------------------------------
        effective_env_backend = env_backend if env_backend is not None else backend
        effective_policy_backend = (
            policy_backend if policy_backend is not None else backend
        )
        if effective_env_backend not in _ENV_BACKENDS:
            raise ValueError(
                f"env_backend={effective_env_backend!r} is not supported. "
                f"Expected one of {_ENV_BACKENDS}."
            )
        self._env_backend = effective_env_backend
        self._policy_backend = effective_policy_backend

        # ---- build transport --------------------------------------------------
        if transport is None:
            transport = _make_transport(
                effective_policy_backend, num_slots=self._num_envs
            )
        self._transport = transport

        # ---- build inference server ---------------------------------------------
        self._server = InferenceServer(
            model=policy,
            transport=transport,
            max_batch_size=max_batch_size,
            min_batch_size=min_batch_size,
            timeout=server_timeout,
            device=device,
            weight_sync=weight_sync,
            weight_sync_model_id=weight_sync_model_id,
        )

        # ---- collector settings -----------------------------------------------
        self.requested_frames_per_batch = frames_per_batch
        self.frames_per_batch = frames_per_batch
        self.total_frames = total_frames
        self.reset_at_each_iter = reset_at_each_iter
        self.yield_completed_trajectories = yield_completed_trajectories
        self._postproc = postproc
        self.verbose = verbose
        self._frames = 0
        self._iter = -1

        # ---- runtime state (created lazily) -------------------------------------
        self._shutdown_event: threading.Event | None = None
        self._result_queue: queue.Queue | None = None
        self._env_pool: AsyncEnvPool | None = None
        self._workers: list[threading.Thread] = []
        # Per-env trajectory accumulators (for yield_completed_trajectories)
        self._yield_queues: list[deque] = [deque() for _ in range(self._num_envs)]
        self._trajectory_queue: deque = deque()

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def _ensure_started(self) -> None:
        """Create the env pool, start the server and the per-env threads."""
        if self._workers and all(w.is_alive() for w in self._workers):
            return
        # Build env pool
        kwargs = {}
        if self._create_env_kwargs is not None:
            kwargs["create_env_kwargs"] = self._create_env_kwargs
        self._env_pool = AsyncEnvPool(
            self._create_env_fn,
            backend=self._env_backend,
            **kwargs,
        )
        # Start inference server
        if not self._server.is_alive:
            self._server.start()
        # Start per-env coordinator threads
        self._result_queue = queue.Queue()
        self._shutdown_event = threading.Event()
        self._workers = []
        for i in range(self._num_envs):
            t = threading.Thread(
                target=_env_loop,
                kwargs={
                    "pool": self._env_pool,
                    "env_id": i,
                    "transport": self._transport,
                    "result_queue": self._result_queue,
                    "shutdown_event": self._shutdown_event,
                },
                daemon=True,
                name=f"AsyncBatchedCollector-env-{i}",
            )
            self._workers.append(t)
            t.start()

    @property
    def env(self) -> AsyncEnvPool:
        """The underlying :class:`AsyncEnvPool`."""
        self._ensure_started()
        return self._env_pool

    @property
    def policy(self) -> Callable:
        """The policy passed to the inference server."""
        return self._policy

    # ------------------------------------------------------------------
    # Rollout: drain the result queue
    # ------------------------------------------------------------------

    @staticmethod
    def _check_worker_result(item):
        """Re-raise exceptions propagated from coordinator threads."""
        if isinstance(item, BaseException):
            raise RuntimeError(
                "A collector worker thread raised an exception."
            ) from item

    def _rollout_frames(self) -> TensorDictBase:
        """Drain ``frames_per_batch`` transitions from the workers."""
        rq = self._result_queue
        collected = 0
        transitions: list[TensorDictBase] = []
        while collected < self.frames_per_batch:
            # Block for at least one transition
            td = rq.get()
            self._check_worker_result(td)
            transitions.append(td)
            collected += td.numel()
            # Batch-drain any additional items already in the queue
            while collected < self.frames_per_batch:
                try:
                    td = rq.get_nowait()
                except queue.Empty:
                    break
                self._check_worker_result(td)
                transitions.append(td)
                collected += td.numel()
            if self.verbose:
                torchrl_logger.debug(
                    f"AsyncBatchedCollector: {collected}/{self.frames_per_batch} frames"
                )
        return lazy_stack(transitions)

    def _rollout_yield_trajs(self) -> TensorDictBase:
        """Drain transitions until a complete trajectory is available."""
        rq = self._result_queue
        while not self._trajectory_queue:
            td = rq.get()
            self._check_worker_result(td)
            env_id = 0
            eid = td.get(_ENV_IDX_KEY, default=None)
            if eid is not None:
                # Unwrap NonTensorData / NonTensorStack / list wrappers
                if hasattr(eid, "data"):
                    eid = eid.data
                while isinstance(eid, list) and len(eid) == 1:
                    eid = eid[0]
                env_id = int(eid)
            self._yield_queues[env_id].append(td)
            if td["next", "done"].any():
                self._trajectory_queue.append(
                    lazy_stack(list(self._yield_queues[env_id]), -1)
                )
                self._yield_queues[env_id].clear()
        result = self._trajectory_queue.popleft()
        return result.reshape(-1)

    @property
    def rollout(self) -> Callable[[], TensorDictBase]:
        if self.yield_completed_trajectories:
            return self._rollout_yield_trajs
        return self._rollout_frames

    # ------------------------------------------------------------------
    # BaseCollector interface
    # ------------------------------------------------------------------

    def iterator(self) -> Iterator[TensorDictBase]:
        """Iterate over collected batches."""
        self._ensure_started()
        total = self.total_frames
        while total < 0 or self._frames < total:
            self._iter += 1
            td = self.rollout()
            self._frames += td.numel()
            if self._postproc is not None:
                td = self._postproc(td)
            yield td

    def shutdown(
        self,
        timeout: float | None = None,
        close_env: bool = True,
        raise_on_error: bool = True,
    ) -> None:
        """Shut down the collector, inference server, threads and env pool."""
        if self._shutdown_event is not None:
            self._shutdown_event.set()
        _timeout = timeout or 5.0
        for w in self._workers:
            w.join(timeout=_timeout)
        self._workers = []
        self._server.shutdown(timeout=_timeout)
        if close_env and self._env_pool is not None:
            self._env_pool.close(raise_if_closed=raise_on_error)
            self._env_pool = None

    def set_seed(self, seed: int, static_seed: bool = False) -> int:
        """Set the seed (no-op; envs are created inside the pool)."""
        return seed

    def state_dict(self) -> OrderedDict:
        return OrderedDict()

    def load_state_dict(self, state_dict: OrderedDict) -> None:
        pass

    def __del__(self) -> None:
        if getattr(self, "_workers", None):
            try:
                self.shutdown(timeout=2.0, raise_on_error=False)
            except Exception:
                pass
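

# Illustrative sketch (not part of the module): trajectory-mode collection.
# With ``yield_completed_trajectories=True`` the collector yields one finished
# trajectory at a time. ``policy`` and ``GymEnv`` are as in the class
# docstring's Examples block; the env name and frame counts are arbitrary.
#
#   >>> collector = AsyncBatchedCollector(
#   ...     create_env_fn=[lambda: GymEnv("CartPole-v1")] * 2,
#   ...     policy=policy,
#   ...     frames_per_batch=64,
#   ...     total_frames=256,
#   ...     yield_completed_trajectories=True,
#   ... )
#   >>> for traj in collector:
#   ...     assert traj["next", "done"].any()  # each yielded traj ends in done
#   ...     break
#   >>> collector.shutdown()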
