Shortcuts

Source code for torchrl.collectors._multi_base

from __future__ import annotations

import _pickle
import abc

import contextlib
import warnings
from collections import OrderedDict
from collections.abc import Callable, Mapping, Sequence
from typing import Any

import numpy as np
import torch
from tensordict import TensorDict, TensorDictBase
from tensordict.nn import CudaGraphModule, TensorDictModule
from tensordict.utils import _zip_strict
from torch import multiprocessing as mp, nn
from torchrl import logger as torchrl_logger
from torchrl._utils import (
    _check_for_faulty_process,
    _get_mp_ctx,
    _ProcessNoWarn,
    _set_mp_start_method_if_unset,
    RL_WARNINGS,
)
from torchrl.collectors._base import BaseCollector
from torchrl.collectors._constants import (
    _InterruptorManager,
    _is_osx,
    DEFAULT_EXPLORATION_TYPE,
    ExplorationType,
    INSTANTIATE_TIMEOUT,
)
from torchrl.collectors._runner import _main_async_collector
from torchrl.collectors._single import Collector
from torchrl.collectors.utils import _make_meta_policy, _TrajectoryPool
from torchrl.collectors.weight_update import WeightUpdaterBase
from torchrl.data import ReplayBuffer
from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
from torchrl.envs import EnvBase, EnvCreator
from torchrl.envs.llm.transforms import PolicyVersion
from torchrl.weight_update import (
    MultiProcessWeightSyncScheme,
    SharedMemWeightSyncScheme,
    WeightSyncScheme,
)
from torchrl.weight_update.utils import _resolve_model


class _MultiCollectorMeta(abc.ABCMeta):
    """Metaclass routing ``MultiCollector(..., sync=...)`` to a concrete subclass.

    Instantiating a class literally named ``MultiCollector`` with an explicit
    ``sync`` value does not build that class; the call is forwarded to:

    - ``MultiSyncCollector`` when ``sync`` is truthy,
    - ``MultiAsyncCollector`` when ``sync`` is falsy (but not ``None``).

    In every other situation (a class with a different name, or ``sync`` left
    as ``None``) the regular instantiation machinery runs. ``sync`` is a
    keyword-only argument consumed here: it is never forwarded to the
    constructor that ends up being called.
    """

    def __call__(cls, *args, sync: bool | None = None, **kwargs):
        # Guard clause: dispatch only for a direct MultiCollector(...) call
        # that carries an explicit sync flag; otherwise build ``cls`` as usual.
        wants_dispatch = sync is not None and cls.__name__ == "MultiCollector"
        if not wants_dispatch:
            return super().__call__(*args, **kwargs)

        # The concrete collectors are imported lazily, at call time, rather
        # than at module import.
        if sync:
            from torchrl.collectors._multi_sync import (
                MultiSyncCollector as target_cls,
            )
        else:
            from torchrl.collectors._multi_async import (
                MultiAsyncCollector as target_cls,
            )
        return target_cls(*args, **kwargs)


[docs]class MultiCollector(BaseCollector, metaclass=_MultiCollectorMeta): """Runs a given number of DataCollectors on separate processes. Args: create_env_fn (List[Callabled]): list of Callables, each returning an instance of :class:`~torchrl.envs.EnvBase`. policy (Callable): Policy to be executed in the environment. Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. If ``None`` is provided (default), the policy used will be a :class:`~torchrl.collectors.RandomPolicy` instance with the environment ``action_spec``. Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. This is the recommended usage of the collector. Other callables are accepted too: If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` instances) it will be wrapped in a `nn.Module` first. Then, the collector will try to assess if these modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. - If the policy forward signature matches any of ``forward(self, tensordict)``, ``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or any typing with a single argument typed as a subclass of ``TensorDictBase``) then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / pickled directly), the ``policy_factory`` should be used instead. .. note:: When using ``weight_sync_schemes``, both ``policy`` and ``policy_factory`` can be provided together. In this case, the ``policy`` is used ONLY for weight extraction (via ``TensorDict.from_module()``) to set up weight synchronization, but it is NOT sent to workers and its weights are NOT depopulated. 
The ``policy_factory`` is what actually gets passed to workers to create their local policy instances. This is useful when the policy is hard to serialize but you have a copy on the main node for weight synchronization purposes. Keyword Args: sync (bool, optional): if ``True``, the collector will run in sync mode (:class:`~torchrl.collectors.MultiSyncCollector`). If `False`, the collector will run in async mode (:class:`~torchrl.collectors.MultiAsyncCollector`). policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable (or list of callables) that returns a policy instance. When not using ``weight_sync_schemes``, this is mutually exclusive with the ``policy`` argument. When using ``weight_sync_schemes``, both ``policy`` and ``policy_factory`` can be provided: the ``policy`` is used for weight extraction only, while ``policy_factory`` creates policies on workers. .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. .. warning:: `policy_factory` is currently not compatible with multiprocessed data collectors. num_workers (int, optional): number of workers to use. If `create_env_fn` is a list, this will be ignored. Defaults to `None` (workers determined by the `create_env_fn` length). frames_per_batch (int, Sequence[int]): A keyword-only argument representing the total number of elements in a batch. If a sequence is provided, represents the number of elements in a batch per worker. Total number of elements in a batch is then the sum over the sequence. total_frames (int, optional): A keyword-only argument representing the total number of frames returned by the collector during its lifespan. If the ``total_frames`` is not divisible by ``frames_per_batch``, an exception is raised. Endless collectors can be created by passing ``total_frames=-1``. Defaults to ``-1`` (never ending collector). device (int, str or torch.device, optional): The generic device of the collector. 
The ``device`` args fills any non-specified device: if ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or ``env_device`` is not specified, its value will be set to ``device``. Defaults to ``None`` (No default device). Supports a list of devices if one wishes to indicate a different device for each worker. The list must be as long as the number of workers. storing_device (int, str or torch.device, optional): The device on which the output :class:`~tensordict.TensorDict` will be stored. If ``device`` is passed and ``storing_device`` is ``None``, it will default to the value indicated by ``device``. For long trajectories, it may be necessary to store the data on a different device than the one where the policy and env are executed. Defaults to ``None`` (the output tensordict isn't on a specific device, leaf tensors sit on the device where they were created). Supports a list of devices if one wishes to indicate a different device for each worker. The list must be as long as the number of workers. env_device (int, str or torch.device, optional): The device on which the environment should be cast (or executed if that functionality is supported). If not specified and the env has a non-``None`` device, ``env_device`` will default to that value. If ``device`` is passed and ``env_device=None``, it will default to ``device``. If the value as such specified of ``env_device`` differs from ``policy_device`` and one of them is not ``None``, the data will be cast to ``env_device`` before being passed to the env (i.e., passing different devices to policy and env is supported). Defaults to ``None``. Supports a list of devices if one wishes to indicate a different device for each worker. The list must be as long as the number of workers. policy_device (int, str or torch.device, optional): The device on which the policy should be cast. If ``device`` is passed and ``policy_device=None``, it will default to ``device``. 
If the value as such specified of ``policy_device`` differs from ``env_device`` and one of them is not ``None``, the data will be cast to ``policy_device`` before being passed to the policy (i.e., passing different devices to policy and env is supported). Defaults to ``None``. Supports a list of devices if one wishes to indicate a different device for each worker. The list must be as long as the number of workers. create_env_kwargs (dict, optional): A dictionary with the keyword arguments used to create an environment. If a list is provided, each of its elements will be assigned to a sub-collector. collector_class (Python class or constructor): a collector class to be remotely instantiated. Can be :class:`~torchrl.collectors.Collector`, :class:`~torchrl.collectors.MultiSyncCollector`, :class:`~torchrl.collectors.MultiAsyncCollector` or a derived class of these. Defaults to :class:`~torchrl.collectors.Collector`. max_frames_per_traj (int, optional): Maximum steps per trajectory. Note that a trajectory can span across multiple batches (unless ``reset_at_each_iter`` is set to ``True``, see below). Once a trajectory reaches ``n_steps``, the environment is reset. If the environment wraps multiple environments together, the number of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. Defaults to ``None`` (i.e. no maximum number of steps). init_random_frames (int, optional): Number of frames for which the policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. If provided, it will be rounded up to the closest multiple of frames_per_batch. Defaults to ``None`` (i.e. no random frames). reset_at_each_iter (bool, optional): Whether environments should be reset at the beginning of a batch collection. Defaults to ``False``. 
postproc (Callable, optional): A post-processing transform, such as a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep` instance. Defaults to ``None``. split_trajs (bool, optional): Boolean indicating whether the resulting TensorDict should be split according to the trajectories. See :func:`~torchrl.collectors.utils.split_trajectories` for more information. Defaults to ``False``. exploration_type (ExplorationType, optional): interaction mode to be used when collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``, ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``. reset_when_done (bool, optional): if ``True`` (default), an environment that return a ``True`` value in its ``"done"`` or ``"truncated"`` entry will be reset at the corresponding indices. update_at_each_batch (boolm optional): if ``True``, :meth:`update_policy_weights_()` will be called before (sync) or after (async) each data collection. Defaults to ``False``. preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers that will be allowed to finished collecting their rollout before the rest are forced to end early. num_threads (int, optional): number of threads for this process. Defaults to the number of workers. num_sub_threads (int, optional): number of threads of the subprocesses. Should be equal to one plus the number of processes launched within each subprocess (or one if a single process is launched). Defaults to 1 for safety: if none is indicated, launching multiple workers may charge the cpu load too much and harm performance. cat_results (str, int or None): (:class:`~torchrl.collectors.MultiSyncCollector` exclusively). If ``"stack"``, the data collected from the workers will be stacked along the first dimension. This is the preferred behavior as it is the most compatible with the rest of the library. 
If ``0``, results will be concatenated along the first dimension of the outputs, which can be the batched dimension if the environments are batched or the time dimension if not. A ``cat_results`` value of ``-1`` will always concatenate results along the time dimension. This should be preferred over the default. Intermediate values are also accepted. Defaults to ``"stack"``. .. note:: From v0.5, this argument will default to ``"stack"`` for a better interoperability with the rest of the library. set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of a rollout is reached. If no ``"truncated"`` key is found, an exception is raised. Truncated keys can be set through ``env.add_truncated_keys``. Defaults to ``False``. use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data. This isn't compatible with environments with dynamic specs. Defaults to ``True`` for envs without dynamic specs, ``False`` for others. replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts but populate the buffer instead. Defaults to ``None``. extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not with single steps. Defaults to `True` for multiprocessed data collectors. local_init_rb (bool, optional): if ``False``, the collector will use fake data to initialize the replay buffer in the main process (legacy behavior). If ``True``, the storage-level coordination will handle initialization with real data from worker processes. Defaults to ``None``, which maintains backward compatibility but shows a deprecation warning. This parameter is deprecated and will be removed in v0.12. trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be assumed to be compatible with the collector. 
This defaults to ``True`` for CudaGraphModules and ``False`` otherwise. compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled using :func:`~torch.compile` default behaviour. If a dictionary of kwargs is passed, it will be used to compile the policy. cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped in :class:`~tensordict.nn.CudaGraphModule` with default kwargs. If a dictionary of kwargs is passed, it will be used to wrap the policy. no_cuda_sync (bool): if ``True``, explicit CUDA synchronizations calls will be bypassed. For environments running directly on CUDA (`IsaacLab <https://github.com/isaac-sim/IsaacLab/>`_ or `ManiSkills <https://github.com/haosulab/ManiSkill/>`_) cuda synchronization may cause unexpected crashes. Defaults to ``False``. weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase` or its subclass, responsible for updating the policy weights on remote inference workers. If not provided, a :class:`~torchrl.collectors.MultiProcessedWeightUpdater` will be used by default, which handles weight synchronization across multiple processes. Consider using a constructor if the updater needs to be serialized. weight_sync_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for SENDING weights to worker sub-collectors. Keys are model identifiers (e.g., "policy") and values are WeightSyncScheme instances configured to send weights to child processes. If not provided, a :class:`~torchrl.collectors.MultiProcessWeightSyncScheme` will be used by default. This is for propagating weights DOWN the hierarchy (parent -> children). weight_recv_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for RECEIVING weights from parent collectors. Keys are model identifiers (e.g., "policy") and values are WeightSyncScheme instances configured to receive weights. 
This enables cascading in hierarchies like: RPCDataCollector -> MultiSyncCollector -> Collector. Received weights are automatically propagated to sub-collectors if matching model_ids exist. Defaults to ``None``. track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track the policy version. Defaults to `False`. worker_idx (int, optional): the index of the worker. Examples: >>> from torchrl.collectors import MultiCollector >>> from torchrl.envs import GymEnv >>> >>> def make_env(): ... return GymEnv("CartPole-v1") >>> >>> # Synchronous collection (for on-policy algorithms like PPO) >>> sync_collector = MultiCollector( ... create_env_fn=[make_env] * 4, # 4 parallel workers ... policy=my_policy, ... frames_per_batch=1000, ... total_frames=100000, ... sync=True, # All workers complete before batch is delivered ... ) >>> >>> # Asynchronous collection (for off-policy algorithms like SAC) >>> async_collector = MultiCollector( ... create_env_fn=[make_env] * 4, ... policy=my_policy, ... frames_per_batch=1000, ... total_frames=100000, ... sync=False, # First-come-first-serve delivery ... ) >>> >>> # Iterate over collected data >>> for data in sync_collector: ... # data is a TensorDict with collected transitions ... 
pass >>> sync_collector.shutdown() """ def __init__( self, create_env_fn: Sequence[Callable[[], EnvBase]], policy: None | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, *, num_workers: int | None = None, policy_factory: Callable[[], Callable] | list[Callable[[], Callable]] | None = None, frames_per_batch: int | Sequence[int], total_frames: int | None = -1, device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, create_env_kwargs: Sequence[dict] | None = None, collector_class: type | Callable[[], BaseCollector] | None = None, max_frames_per_traj: int | None = None, init_random_frames: int | None = None, reset_at_each_iter: bool = False, postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, split_trajs: bool | None = None, exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, reset_when_done: bool = True, update_at_each_batch: bool = False, preemptive_threshold: float | None = None, num_threads: int | None = None, num_sub_threads: int = 1, cat_results: str | int | None = None, set_truncated: bool = False, use_buffers: bool | None = None, replay_buffer: ReplayBuffer | None = None, extend_buffer: bool = True, replay_buffer_chunk: bool | None = None, local_init_rb: bool | None = None, trust_policy: bool | None = None, compile_policy: bool | dict[str, Any] | None = None, cudagraph_policy: bool | dict[str, Any] | None = None, no_cuda_sync: bool = False, weight_updater: WeightUpdaterBase | Callable[[], WeightUpdaterBase] | None = None, weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, weight_recv_schemes: dict[str, WeightSyncScheme] | None = None, track_policy_version: bool = False, worker_idx: int | None = None, ): self.closed = True self.worker_idx = worker_idx # Set up workers and environment 
functions create_env_fn, total_frames_per_batch = self._setup_workers_and_env_fns( create_env_fn, num_workers, frames_per_batch ) # Set up basic configuration self.set_truncated = set_truncated self.num_sub_threads = num_sub_threads self.num_threads = num_threads self.create_env_fn = create_env_fn self._read_compile_kwargs(compile_policy, cudagraph_policy) # Set up environment kwargs self.create_env_kwargs = self._setup_env_kwargs(create_env_kwargs) # Set up devices storing_devices, policy_devices, env_devices = self._get_devices( storing_device=storing_device, env_device=env_device, policy_device=policy_device, device=device, ) self.storing_device = storing_devices self.policy_device = policy_devices self.env_device = env_devices self.collector_class = collector_class del storing_device, env_device, policy_device, device self.no_cuda_sync = no_cuda_sync # Set up replay buffer self._use_buffers = use_buffers self.replay_buffer = replay_buffer self._setup_multi_replay_buffer( local_init_rb, replay_buffer, replay_buffer_chunk, extend_buffer ) # Set up policy and weights if trust_policy is None: trust_policy = policy is not None and isinstance(policy, CudaGraphModule) self.trust_policy = trust_policy policy_factory = self._setup_policy_factory(policy_factory) # Set up weight synchronization if weight_sync_schemes is None and weight_updater is None: weight_sync_schemes = {} elif weight_sync_schemes is not None and weight_updater is not None: raise TypeError( "Cannot specify both weight_sync_schemes and weight_updater." ) if ( weight_sync_schemes is not None and not weight_sync_schemes and weight_updater is None and (isinstance(policy, nn.Module) or any(policy_factory)) ): # Set up a default local shared-memory sync scheme for the policy. # This is used to propagate weights from the orchestrator policy # (possibly combined with a policy_factory) down to worker policies. 
weight_sync_schemes["policy"] = SharedMemWeightSyncScheme() self._setup_multi_weight_sync(weight_updater, weight_sync_schemes) # Store policy and policy_factory - temporary set to make them visible to the receiver self.policy = policy self.policy_factory = policy_factory # Set up weight receivers if provided if weight_recv_schemes is not None: self.register_scheme_receiver(weight_recv_schemes) self._setup_multi_policy_and_weights( self.policy, self.policy_factory, weight_updater, weight_sync_schemes ) # Set up policy version tracking self._setup_multi_policy_version_tracking(track_policy_version) # # Set up fallback policy for weight extraction # self._setup_fallback_policy(policy, policy_factory, weight_sync_schemes) # Set up total frames and other parameters self._setup_multi_total_frames( total_frames, total_frames_per_batch, frames_per_batch ) self.reset_at_each_iter = reset_at_each_iter self.postprocs = postproc self.max_frames_per_traj = ( int(max_frames_per_traj) if max_frames_per_traj is not None else 0 ) # Set up split trajectories self.requested_frames_per_batch = total_frames_per_batch self.reset_when_done = reset_when_done self._setup_split_trajs(split_trajs, reset_when_done) # Set up other parameters self.init_random_frames = ( int(init_random_frames) if init_random_frames is not None else 0 ) self.update_at_each_batch = update_at_each_batch self.exploration_type = exploration_type self.frames_per_worker = np.inf # Set up preemptive threshold self._setup_preemptive_threshold(preemptive_threshold) # Run worker processes try: self._run_processes() except Exception as e: self.shutdown(raise_on_error=False) raise e # Set up frame tracking and other options self._exclude_private_keys = True self._frames = 0 self._iter = -1 # Validate cat_results self._validate_cat_results(cat_results) def _setup_workers_and_env_fns( self, create_env_fn: Sequence[Callable] | Callable, num_workers: int | None, frames_per_batch: int | Sequence[int], ) -> tuple[list[Callable], 
int]: """Set up workers and environment functions.""" if isinstance(create_env_fn, Sequence): self.num_workers = len(create_env_fn) else: self.num_workers = num_workers create_env_fn = [create_env_fn] * self.num_workers if ( isinstance(frames_per_batch, Sequence) and len(frames_per_batch) != self.num_workers ): raise ValueError( "If `frames_per_batch` is provided as a sequence, it should contain exactly one value per worker." f"Got {len(frames_per_batch)} values for {self.num_workers} workers." ) self._frames_per_batch = frames_per_batch total_frames_per_batch = ( sum(frames_per_batch) if isinstance(frames_per_batch, Sequence) else frames_per_batch ) return create_env_fn, total_frames_per_batch def _setup_env_kwargs( self, create_env_kwargs: Sequence[dict] | dict | None ) -> list[dict]: """Set up environment kwargs for each worker.""" if isinstance(create_env_kwargs, Mapping): create_env_kwargs = [create_env_kwargs] * self.num_workers elif create_env_kwargs is None: create_env_kwargs = [{}] * self.num_workers elif isinstance(create_env_kwargs, (tuple, list)): create_env_kwargs = list(create_env_kwargs) if len(create_env_kwargs) != self.num_workers: raise ValueError( f"len(create_env_kwargs) must be equal to num_workers, got {len(create_env_kwargs)=} and {self.num_workers=}" ) return create_env_kwargs def _setup_multi_replay_buffer( self, local_init_rb: bool | None, replay_buffer: ReplayBuffer | None, replay_buffer_chunk: bool | None, extend_buffer: bool, ) -> None: """Set up replay buffer for multi-process collector.""" # Handle local_init_rb deprecation if local_init_rb is None: local_init_rb = False if replay_buffer is not None and not local_init_rb: warnings.warn( "local_init_rb=False is deprecated and will be removed in v0.12. 
" "The new storage-level initialization provides better performance.", FutureWarning, ) self.local_init_rb = local_init_rb self._check_replay_buffer_init() if replay_buffer_chunk is not None: if extend_buffer is None: replay_buffer_chunk = extend_buffer warnings.warn( "The replay_buffer_chunk is deprecated and replaced by extend_buffer. This argument will disappear in v0.10.", DeprecationWarning, ) elif extend_buffer != replay_buffer_chunk: raise ValueError( "conflicting values for replay_buffer_chunk and extend_buffer." ) self.extend_buffer = extend_buffer if ( replay_buffer is not None and hasattr(replay_buffer, "shared") and not replay_buffer.shared ): torchrl_logger.warning("Replay buffer is not shared. Sharing it.") replay_buffer.share() def _setup_policy_factory( self, policy_factory: Callable | list[Callable] | None ) -> list[Callable | None]: """Set up policy factory for each worker.""" if not isinstance(policy_factory, Sequence): policy_factory = [policy_factory] * self.num_workers return policy_factory def _setup_multi_policy_and_weights( self, policy: TensorDictModule | Callable | None, policy_factory: list[Callable | None], weight_updater: WeightUpdaterBase | Callable | None, weight_sync_schemes: dict[str, WeightSyncScheme] | None, ) -> None: """Set up policy for multi-process collector. With weight sync schemes: validates and stores policy without weight extraction. With weight updater: extracts weights and creates stateful policies. When both policy and policy_factory are provided (with weight_sync_schemes): - The policy is used ONLY for weight extraction via get_model() - The policy is NOT depopulated of its weights - The policy is NOT sent to workers - The policy_factory is used to create policies on workers """ if any(policy_factory) and policy is not None: if weight_sync_schemes is None: raise TypeError( "policy_factory and policy are mutually exclusive when not using weight_sync_schemes. 
" "When using weight_sync_schemes, policy can be provided alongside policy_factory " "for weight extraction purposes only (the policy will not be sent to workers)." ) # Store policy as fallback for weight extraction only # The policy keeps its weights and is NOT sent to workers self._fallback_policy = policy if weight_sync_schemes is not None: weight_sync_policy = weight_sync_schemes.get("policy") if weight_sync_policy is None: return # # If we only have a policy_factory (no policy instance), the scheme must # # be pre-initialized on the sender, since there is no policy on the # # collector to extract weights from. # if any(p is not None for p in policy_factory) and policy is None: # if not weight_sync_policy.initialized_on_sender: # raise RuntimeError( # "the weight sync scheme must be initialized on sender ahead of time " # "when passing a policy_factory without a policy instance on the collector. " # f"Got {policy_factory=}" # ) # # When a policy instance is provided alongside a policy_factory, the scheme # # can rely on the collector context (and its policy) to extract weights. # # Weight sync scheme initialization then happens in _run_processes where # # pipes and workers are available. else: # Using legacy weight updater - extract weights and create stateful policies self._setup_multi_policy_and_weights_legacy( policy, policy_factory, weight_updater, weight_sync_schemes ) def _setup_multi_policy_and_weights_legacy( self, policy: TensorDictModule | Callable | None, policy_factory: list[Callable | None], weight_updater: WeightUpdaterBase | Callable | None, weight_sync_schemes: dict[str, WeightSyncScheme] | None, ) -> None: """Set up policy and extract weights for each device. Creates stateful policies with weights extracted and placed in shared memory. Used with weight updater for in-place weight replacement. 
""" self._policy_weights_dict = {} self._fallback_policy = None # Policy to use for weight extraction fallback if not any(policy_factory): for policy_device, env_maker, env_maker_kwargs in _zip_strict( self.policy_device, self.create_env_fn, self.create_env_kwargs ): policy_new_device, get_weights_fn = self._get_policy_and_device( policy=policy, policy_device=policy_device, env_maker=env_maker, env_maker_kwargs=env_maker_kwargs, ) if type(policy_new_device) is not type(policy): policy = policy_new_device weights = ( TensorDict.from_module(policy_new_device) if isinstance(policy_new_device, nn.Module) else TensorDict() ) # For multi-process collectors, ensure weights are in shared memory if policy_device and policy_device.type == "cpu": weights = weights.share_memory_() self._policy_weights_dict[policy_device] = weights # Store the first policy instance for fallback weight extraction if self._fallback_policy is None: self._fallback_policy = policy_new_device self._get_weights_fn = get_weights_fn if weight_updater is None: # For multiprocessed collectors, use MultiProcessWeightSyncScheme by default if weight_sync_schemes is None: weight_sync_schemes = {"policy": MultiProcessWeightSyncScheme()} self._weight_sync_schemes = weight_sync_schemes elif weight_updater is None: warnings.warn( "weight_updater is None, but policy_factory is provided. This means that the server will " "not know how to send the weights to the workers. If the workers can handle their weight synchronization " "on their own (via some specialized worker type / constructor) this may well work, but make sure " "your weight synchronization strategy is properly set. To suppress this warning, you can use " "RemoteModuleWeightUpdater() which enforces explicit weight passing when calling update_policy_weights_(weights). " "This will work whenever your inference and training policies are nn.Module instances with similar structures." 
) def _setup_multi_weight_sync( self, weight_updater: WeightUpdaterBase | Callable | None, weight_sync_schemes: dict[str, WeightSyncScheme] | None, ) -> None: """Set up weight synchronization for multi-process collector.""" if weight_sync_schemes is not None: # Use weight sync schemes for weight distribution self._weight_sync_schemes = weight_sync_schemes # Senders will be created in _run_processes self.weight_updater = None else: # Use weight updater for weight distribution self.weight_updater = weight_updater self._weight_sync_schemes = None def _setup_multi_policy_version_tracking( self, track_policy_version: bool | PolicyVersion ) -> None: """Set up policy version tracking for multi-process collector.""" self.policy_version_tracker = track_policy_version if PolicyVersion is not None: if isinstance(track_policy_version, bool) and track_policy_version: self.policy_version_tracker = PolicyVersion() elif hasattr(track_policy_version, "increment_version"): self.policy_version_tracker = track_policy_version else: self.policy_version_tracker = None else: if track_policy_version: raise ImportError( "PolicyVersion is not available. Please install the LLM dependencies or set track_policy_version=False." 
) self.policy_version_tracker = None # TODO: Remove this def _setup_fallback_policy( self, policy: TensorDictModule | Callable | None, policy_factory: list[Callable | None], weight_sync_schemes: dict[str, WeightSyncScheme] | None, ) -> None: """Set up fallback policy for weight extraction when using policy_factory.""" # _fallback_policy is already set in _setup_multi_policy_and_weights if a policy was provided # If policy_factory was used, create a policy instance to use as fallback if policy is None and any(policy_factory) and weight_sync_schemes is not None: if not hasattr(self, "_fallback_policy") or self._fallback_policy is None: first_factory = ( policy_factory[0] if isinstance(policy_factory, list) else policy_factory ) if first_factory is not None: # Create a policy instance for weight extraction # This will be a reference to a policy with the same structure # For shared memory, modifications to any policy will be visible here self._fallback_policy = first_factory() def _setup_multi_total_frames( self, total_frames: int, total_frames_per_batch: int, frames_per_batch: int | Sequence[int], ) -> None: """Validate and set total frames for multi-process collector.""" if total_frames is None or total_frames < 0: total_frames = float("inf") else: remainder = total_frames % total_frames_per_batch if remainder != 0 and RL_WARNINGS: warnings.warn( f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({total_frames_per_batch}). " f"This means {total_frames_per_batch - remainder} additional frames will be collected. " "To silence this message, set the environment variable RL_WARNINGS to False." 
) self.total_frames = ( int(total_frames) if total_frames != float("inf") else total_frames ) def _setup_split_trajs( self, split_trajs: bool | None, reset_when_done: bool ) -> None: """Set up split trajectories option.""" if split_trajs is None: split_trajs = False elif not reset_when_done and split_trajs: raise RuntimeError( "Cannot split trajectories when reset_when_done is False." ) self.split_trajs = split_trajs def _setup_preemptive_threshold(self, preemptive_threshold: float | None) -> None: """Set up preemptive threshold for early stopping.""" if preemptive_threshold is not None: if _is_osx: raise NotImplementedError( "Cannot use preemption on OSX due to Queue.qsize() not being implemented on this platform." ) self.preemptive_threshold = np.clip(preemptive_threshold, 0.0, 1.0) manager = _InterruptorManager() manager.start() self.interruptor = manager._Interruptor() else: self.preemptive_threshold = 1.0 self.interruptor = None def _validate_cat_results(self, cat_results: str | int | None) -> None: """Validate cat_results parameter.""" if cat_results is not None and ( not isinstance(cat_results, (int, str)) or (isinstance(cat_results, str) and cat_results != "stack") ): raise ValueError( "cat_results must be a string ('stack') " f"or an integer representing the cat dimension. Got {cat_results}." ) # Lazy import to avoid circular dependency from torchrl.collectors._multi_sync import MultiSyncCollector if not isinstance(self, MultiSyncCollector) and cat_results not in ( "stack", None, ): raise ValueError( "cat_results can only be used with ``MultiSyncCollector``." 
) self.cat_results = cat_results def _check_replay_buffer_init(self): if self.replay_buffer is None: return is_init = hasattr(self.replay_buffer, "_storage") and getattr( self.replay_buffer._storage, "initialized", True ) if not is_init: if self.local_init_rb: # New behavior: storage handles all coordination itself # Nothing to do here - the storage will coordinate during first write self.replay_buffer.share() return # Legacy behavior: fake tensordict initialization if isinstance(self.create_env_fn[0], EnvCreator): fake_td = self.create_env_fn[0].meta_data.tensordict elif isinstance(self.create_env_fn[0], EnvBase): fake_td = self.create_env_fn[0].fake_tensordict() else: fake_td = self.create_env_fn[0]( **self.create_env_kwargs[0] ).fake_tensordict() fake_td["collector", "traj_ids"] = torch.zeros( fake_td.shape, dtype=torch.long ) # Use extend to avoid time-related transforms to fail self.replay_buffer.extend(fake_td.unsqueeze(-1)) self.replay_buffer.empty() @classmethod def _total_workers_from_env(cls, env_creators): if isinstance(env_creators, (tuple, list)): return sum( cls._total_workers_from_env(env_creator) for env_creator in env_creators ) from torchrl.envs import ParallelEnv if isinstance(env_creators, ParallelEnv): return env_creators.num_workers return 1 def _get_devices( self, *, storing_device: torch.device, policy_device: torch.device, env_device: torch.device, device: torch.device, ): # convert all devices to lists if not isinstance(storing_device, (list, tuple)): storing_device = [ storing_device, ] * self.num_workers if not isinstance(policy_device, (list, tuple)): policy_device = [ policy_device, ] * self.num_workers if not isinstance(env_device, (list, tuple)): env_device = [ env_device, ] * self.num_workers if not isinstance(device, (list, tuple)): device = [ device, ] * self.num_workers if not ( len(device) == len(storing_device) == len(policy_device) == len(env_device) == self.num_workers ): raise RuntimeError( f"THe length of the devices does 
not match the number of workers: {self.num_workers}." ) storing_device, policy_device, env_device = zip( *[ Collector._get_devices( storing_device=storing_device, policy_device=policy_device, env_device=env_device, device=device, ) for (storing_device, policy_device, env_device, device) in zip( storing_device, policy_device, env_device, device ) ] ) return storing_device, policy_device, env_device def frames_per_batch_worker(self, *, worker_idx: int | None = None) -> int: raise NotImplementedError @property def _queue_len(self) -> int: raise NotImplementedError def _run_processes(self) -> None: if self.num_threads is None: total_workers = self._total_workers_from_env(self.create_env_fn) self.num_threads = max( 1, torch.get_num_threads() - total_workers ) # 1 more thread for this proc # Set up for worker processes torch.set_num_threads(self.num_threads) ctx = _get_mp_ctx() # Best-effort global init (only if unset) to keep other mp users consistent. _set_mp_start_method_if_unset(ctx.get_start_method()) queue_out = ctx.Queue(self._queue_len) # sends data from proc to main self.procs = [] self._traj_pool = _TrajectoryPool(ctx=ctx, lock=True) # Create all pipes upfront (needed for weight sync scheme initialization) # Store as list of (parent, child) tuples for use in worker creation pipe_pairs = [ctx.Pipe() for _ in range(self.num_workers)] # Extract parent pipes for external use (e.g., polling, receiving messages) self.pipes = [pipe_parent for pipe_parent, _ in pipe_pairs] # Initialize all weight sync schemes now that pipes are available # Both SharedMemWeightSyncScheme (uses queues) and MultiProcessWeightSyncScheme (uses pipes) # can be initialized here since all required resources exist if self._weight_sync_schemes: for model_id, scheme in self._weight_sync_schemes.items(): if not scheme.initialized_on_sender: torchrl_logger.debug( f"Init scheme {type(scheme)} on sender side of {type(self)} with {model_id=} and model {_resolve_model(self, model_id)}." 
) scheme.init_on_sender(model_id=model_id, context=self) # Create a policy on the right device policy_factory = self.policy_factory has_policy_factory = any(policy_factory) if has_policy_factory: policy_factory = [ CloudpickleWrapper(_policy_factory) for _policy_factory in policy_factory ] for i, (env_fun, env_fun_kwargs) in enumerate( zip(self.create_env_fn, self.create_env_kwargs) ): pipe_parent, pipe_child = pipe_pairs[i] # use pre-created pipes if env_fun.__class__.__name__ != "EnvCreator" and not isinstance( env_fun, EnvBase ): # to avoid circular imports env_fun = CloudpickleWrapper(env_fun) policy_device = self.policy_device[i] storing_device = self.storing_device[i] env_device = self.env_device[i] # Prepare policy for worker based on weight synchronization method. # IMPORTANT: when a policy_factory is provided, the policy instance # is used ONLY on the main process (for weight extraction etc.) and # is NOT sent to workers. policy = self.policy if self._weight_sync_schemes: # With weight sync schemes, send stateless policies. # Schemes handle weight distribution on worker side. if has_policy_factory: # Factory will create policy in worker; don't send policy. policy_to_send = None cm = contextlib.nullcontext() elif policy is not None: # Send policy with meta-device parameters (empty structure) - schemes apply weights policy_to_send = policy cm = _make_meta_policy(policy) else: policy_to_send = None cm = contextlib.nullcontext() elif hasattr(self, "_policy_weights_dict"): # LEGACY: # With weight updater, use in-place weight replacement. # Take the weights and locally dispatch them to the policy before sending. # This ensures a given set of shared weights for a device are shared # for all policies that rely on that device. policy_weights = self._policy_weights_dict.get(policy_device) if has_policy_factory: # Even in legacy mode, when a policy_factory is present, do not # send the stateful policy down to workers. 
policy_to_send = None cm = contextlib.nullcontext() else: policy_to_send = policy if policy is not None and policy_weights is not None: cm = policy_weights.to_module(policy) else: cm = contextlib.nullcontext() else: # Parameter-less policy. cm = contextlib.nullcontext() # When a policy_factory exists, never send the policy instance. policy_to_send = None if has_policy_factory else policy with cm: kwargs = { "policy_factory": policy_factory[i], "pipe_parent": pipe_parent, "pipe_child": pipe_child, "queue_out": queue_out, "create_env_fn": env_fun, "create_env_kwargs": env_fun_kwargs, "policy": policy_to_send, "max_frames_per_traj": self.max_frames_per_traj, "frames_per_batch": self.frames_per_batch_worker(worker_idx=i), "reset_at_each_iter": self.reset_at_each_iter, "policy_device": policy_device, "storing_device": storing_device, "env_device": env_device, "exploration_type": self.exploration_type, "reset_when_done": self.reset_when_done, "idx": i, "interruptor": self.interruptor, "set_truncated": self.set_truncated, "use_buffers": self._use_buffers, "replay_buffer": self.replay_buffer, "extend_buffer": self.extend_buffer, "traj_pool": self._traj_pool, "trust_policy": self.trust_policy, "compile_policy": self.compiled_policy_kwargs if self.compiled_policy else False, "cudagraph_policy": self.cudagraphed_policy_kwargs if self.cudagraphed_policy else False, "no_cuda_sync": self.no_cuda_sync, "collector_class": self.collector_class, "postproc": self.postprocs if self.replay_buffer is not None else None, "weight_sync_schemes": self._weight_sync_schemes, "worker_idx": i, # Worker index for queue-based weight distribution } proc = _ProcessNoWarn( target=_main_async_collector, num_threads=self.num_sub_threads, _start_method=ctx.get_start_method(), kwargs=kwargs, ) # proc.daemon can't be set as daemonic processes may be launched by the process itself try: proc.start() except TypeError as err: if "cannot pickle" in str(err): raise RuntimeError( "A non-serializable object was 
passed to the collector workers." ) from err except RuntimeError as err: if "Cowardly refusing to serialize non-leaf tensor" in str(err): raise RuntimeError( "At least one of the tensors in the policy, replay buffer, environment constructor or postprocessor requires gradients. " "This is not supported in multiprocessed data collectors.\n- For ReplayBuffer transforms, use a `transform_factory` instead with `delayed_init=True`.\n" "- Make sure your environment constructor does not reference tensors already instantiated on the main process.\n" "- Since no gradient can be propagated through the Collector pipes, the backward graph is never needed. Consider using detached tensors instead." ) from err elif "_share_fd_: only available on CPU" in str( err ) or "_share_filename_: only available on CPU" in str(err): # This is a common failure mode on older PyTorch versions when using the # "spawn" multiprocessing start method: the process object contains a # CUDA/MPS tensor (or a module/buffer on a non-CPU device), which must be # pickled when spawning workers. # # See: https://github.com/pytorch/pytorch/issues/87688#issuecomment-1968901877 start_method = None try: start_method = mp.get_start_method(allow_none=True) except Exception: # Best effort: some environments may disallow querying here. start_method = None raise RuntimeError( "Failed to start a collector worker process because a non-CPU tensor " "was captured in the worker process arguments and had to be serialized " "(pickled) at process start.\n\n" f"Detected multiprocessing start method: {start_method!r}.\n\n" "Workarounds:\n" "- Keep any tensors/modules referenced by your collector constructor " "(policy, replay buffer, postprocs, env factory captures, etc.) on CPU " "when using a spawning start method (common on macOS/Windows).\n" "- Or set the multiprocessing start method to 'fork' *before* creating " "the collector (Unix only). 
Example:\n\n" " import torch.multiprocessing as mp\n" " if __name__ == '__main__':\n" " mp.set_start_method('fork', force=True)\n\n" "Upstream context: https://github.com/pytorch/pytorch/issues/87688#issuecomment-1968901877" ) from err else: raise err except _pickle.PicklingError as err: if "<lambda>" in str(err): raise RuntimeError( """Can't open a process with doubly cloud-pickled lambda function. This error is likely due to an attempt to use a ParallelEnv in a multiprocessed data collector. To do this, consider wrapping your lambda function in an `torchrl.envs.EnvCreator` wrapper as follows: `env = ParallelEnv(N, EnvCreator(my_lambda_function))`. This will not only ensure that your lambda function is cloud-pickled once, but also that the state dict is synchronised across processes if needed.""" ) from err pipe_child.close() self.procs.append(proc) # Synchronize initial weights with workers AFTER starting processes but BEFORE waiting for "instantiated" # This must happen after proc.start() but before workers send "instantiated" to avoid deadlock: # Workers will call receiver.collect() during init and may block waiting for data if self._weight_sync_schemes: # start with policy policy_scheme = self._weight_sync_schemes.get("policy") if policy_scheme is not None: policy_scheme.connect() for key, scheme in self._weight_sync_schemes.items(): if key == "policy": continue scheme.connect() # Wait for workers to be ready for i, pipe_parent in enumerate(self.pipes): pipe_parent.poll(timeout=INSTANTIATE_TIMEOUT) try: msg = pipe_parent.recv() except EOFError as e: raise RuntimeError( f"Worker {i} failed to initialize and closed the connection before sending status. " f"This typically indicates that the worker process crashed during initialization. " f"Check the worker process logs for the actual error." 
) from e if msg != "instantiated": # Check if it's an error dict from worker if isinstance(msg, dict) and msg.get("error"): # Reconstruct the exception from the worker exc_type_name = msg["exception_type"] exc_msg = msg["exception_msg"] traceback_str = msg["traceback"] # Try to get the actual exception class exc_class = None exc_module = msg["exception_module"] if exc_module == "builtins": # Get from builtins import builtins exc_class = getattr(builtins, exc_type_name, None) else: # Try to import from the module try: import importlib mod = importlib.import_module(exc_module) exc_class = getattr(mod, exc_type_name, None) except Exception: pass # Re-raise with original exception type if possible if exc_class is not None: raise exc_class( f"{exc_msg}\n\nWorker traceback:\n{traceback_str}" ) else: # Fall back to RuntimeError if we can't get the original type raise RuntimeError( f"Worker {i} raised {exc_type_name}: {exc_msg}\n\nWorker traceback:\n{traceback_str}" ) else: # Legacy string error message raise RuntimeError(msg) self.queue_out = queue_out self.closed = False _running_free = False
def start(self):
    """Launch free-running, asynchronous data collection on all workers.

    A ``"run_free"`` command is sent through every worker pipe and the
    collector is flagged as running free. The collected data is stored in
    the replay buffer given to the collector, which decouples data
    collection from training.

    Raises:
        RuntimeError: If no replay buffer was defined for the collector, or
            if the collector was configured with a strictly positive
            ``init_random_frames``.

    Example:
        >>> from torchrl.modules import RandomPolicy
        >>>
        >>> import time
        >>> from functools import partial
        >>>
        >>> import tqdm
        >>>
        >>> from torchrl.collectors import MultiAsyncCollector
        >>> from torchrl.data import LazyTensorStorage, ReplayBuffer
        >>> from torchrl.envs import GymEnv, set_gym_backend
        >>> import ale_py
        >>>
        >>> # Set the gym backend to gymnasium
        >>> set_gym_backend("gymnasium").set()
        >>>
        >>> if __name__ == "__main__":
        ...     # Create a random policy for the Pong environment
        ...     env_fn = partial(GymEnv, "ALE/Pong-v5")
        ...     policy = RandomPolicy(env_fn().action_spec)
        ...
        ...     # Initialize a shared replay buffer
        ...     rb = ReplayBuffer(storage=LazyTensorStorage(10000), shared=True)
        ...
        ...     # Create a multi-async data collector with 16 environments
        ...     num_envs = 16
        ...     collector = MultiAsyncCollector(
        ...         [env_fn] * num_envs,
        ...         policy=policy,
        ...         replay_buffer=rb,
        ...         frames_per_batch=num_envs * 16,
        ...         total_frames=-1,
        ...     )
        ...
        ...     # Progress bar to track the number of collected frames
        ...     pbar = tqdm.tqdm(total=100_000)
        ...
        ...     # Start the collector asynchronously
        ...     collector.start()
        ...
        ...     # Track the write count of the replay buffer
        ...     prec_wc = 0
        ...     while True:
        ...         wc = rb.write_count
        ...         c = wc - prec_wc
        ...         prec_wc = wc
        ...
        ...         # Update the progress bar
        ...         pbar.update(c)
        ...         pbar.set_description(f"Write Count: {rb.write_count}")
        ...
        ...         # Check the write count every 0.5 seconds
        ...         time.sleep(0.5)
        ...
        ...         # Stop when the desired number of frames is reached
        ...         if rb.write_count > 100_000:
        ...             break
        ...
        ...     # Shut down the collector
        ...     collector.async_shutdown()
    """
    if self.replay_buffer is None:
        raise RuntimeError("Replay buffer must be defined for execution.")
    random_frames = self.init_random_frames
    if random_frames is not None and random_frames > 0:
        raise RuntimeError(
            "Cannot currently start() a collector that requires random frames. Please submit a feature request on github."
        )
    # The flag is raised before the commands go out, as pause() relies on it.
    self._running_free = True
    run_free_command = (None, "run_free")
    for worker_pipe in self.pipes:
        worker_pipe.send(run_free_command)
@contextlib.contextmanager
def pause(self):
    """Context manager that pauses the collector if it is running free.

    On entry a ``"pause"`` command is sent to every worker and one
    acknowledgment per worker is read from ``queue_out``. On exit a
    ``"restart"`` command is sent to every worker and the collector is
    flagged as running free again. The restart also happens when the body
    of the ``with`` block raises, so workers are never left paused by an
    exception in user code.

    Raises:
        RuntimeError: If the collector is not currently running free.
        ValueError: If a worker acknowledges with anything but ``"paused"``.
    """
    if not self._running_free:
        raise RuntimeError("Collector cannot be paused.")
    for pipe in self.pipes:
        pipe.send((None, "pause"))
    # Make sure all workers are paused
    for _ in self.pipes:
        idx, msg = self.queue_out.get()
        if msg != "paused":
            raise ValueError(f"Expected paused, but got {msg=}.")
        torchrl_logger.debug(f"Worker {idx} is paused.")
    self._running_free = False
    try:
        yield None
    finally:
        # Always resume the workers, even if the with-body raised.
        for pipe in self.pipes:
            pipe.send((None, "restart"))
        self._running_free = True
def __del__(self):
    """Attempt to shut the collector down on deletion, ignoring any failure."""
    try:
        self.shutdown()
    except Exception:
        # an AttributeError will typically be raised if the collector is deleted when the program ends.
        # In the future, insignificant changes to the close method may change the error type.
        # We explicitly assume that any error raised during closure in
        # __del__ will not affect the program.
        pass
def shutdown(
    self,
    timeout: float | None = None,
    close_env: bool = True,
    raise_on_error: bool = True,
) -> None:
    """Shuts down all processes. This operation is irreversible.

    The actual work is delegated to ``_shutdown_main``.

    Args:
        timeout (float, optional): The timeout for closing pipes between workers.
        close_env (bool, optional): Whether to close the environment. Must be
            ``True`` for this collector; ``False`` is rejected.
            Defaults to `True`.
        raise_on_error (bool, optional): Whether to raise an error if the
            shutdown fails. When ``False``, any ``Exception`` raised during
            shutdown is discarded. Defaults to `True`.

    Raises:
        RuntimeError: If ``close_env`` is ``False``.
    """
    if not close_env:
        raise RuntimeError(
            f"Cannot shutdown {type(self).__name__} collector without environment being closed."
        )
    try:
        self._shutdown_main(timeout)
    except Exception:
        if not raise_on_error:
            return
        raise
def _shutdown_main(self, timeout: float | None = None) -> None:
    """Ask every live worker to close, wait for acknowledgments, then clean up.

    A ``"close"`` command is sent to each worker process that is still
    alive. The pipes are then polled, for at most 1000 rounds, until every
    worker has either answered ``"closed"``, died, or broken its pipe.
    Afterwards the output queue and pipes are closed and the processes
    joined.

    Whatever happens (including an early return when the collector is
    already closed), the ``finally`` clause resets the torch thread count
    and terminates any process that is still alive.

    Args:
        timeout (float, optional): Overall polling budget; each individual
            poll waits ``timeout / 1000 / num_workers``. Defaults to 10 when
            ``None``.

    Raises:
        RuntimeError: If a worker answers with anything other than
            ``"closed"``.
    """
    if timeout is None:
        timeout = 10
    try:
        if self.closed:
            return
        _check_for_faulty_process(self.procs)
        all_closed = [False] * self.num_workers
        rep = 0
        # Phase 1: send the close command to every worker that is still alive.
        for idx in range(self.num_workers):
            if all_closed[idx]:
                continue
            if not self.procs[idx].is_alive():
                continue
            self.pipes[idx].send((None, "close"))

        # Phase 2: collect acknowledgments, bounded by 1000 polling rounds.
        while not all(all_closed) and rep < 1000:
            rep += 1
            for idx in range(self.num_workers):
                if all_closed[idx]:
                    continue
                if not self.procs[idx].is_alive():
                    # A dead worker cannot answer: count it as closed.
                    all_closed[idx] = True
                    continue
                try:
                    if self.pipes[idx].poll(timeout / 1000 / self.num_workers):
                        msg = self.pipes[idx].recv()
                        if msg != "closed":
                            raise RuntimeError(f"got {msg} but expected 'closed'")
                        all_closed[idx] = True
                    else:
                        continue
                except BrokenPipeError:
                    # The other end is gone: nothing more to wait for.
                    all_closed[idx] = True
                    continue
        self.closed = True

        self.queue_out.close()
        for pipe in self.pipes:
            pipe.close()
        for proc in self.procs:
            proc.join(1.0)
    finally:
        import torchrl

        # Add the workers' share back to this process' thread count, capped
        # at torchrl._THREAD_POOL_INIT.
        num_threads = min(
            torchrl._THREAD_POOL_INIT,
            torch.get_num_threads()
            + self._total_workers_from_env(self.create_env_fn),
        )
        torch.set_num_threads(num_threads)

        # Last resort for workers that did not exit on their own.
        for proc in self.procs:
            if proc.is_alive():
                proc.terminate()
def async_shutdown(self, timeout: float | None = None):
    """Shut the collector down by forwarding to :meth:`shutdown`.

    Args:
        timeout (float, optional): Passed on as the ``timeout`` keyword of
            :meth:`shutdown`.

    Returns:
        Whatever :meth:`shutdown` returns.
    """
    outcome = self.shutdown(timeout=timeout)
    return outcome
def set_seed(self, seed: int, static_seed: bool = False) -> int:
    """Sets the seeds of the environments stored in the DataCollector.

    Workers are seeded one after the other: the seed returned by a worker is
    the one sent to the next worker. Once all workers have answered, the
    collector is reset.

    Args:
        seed: integer representing the seed to be used for the environment.
        static_seed (bool, optional): if ``True``, the seed is not incremented.
            Defaults to False

    Returns:
        Output seed. This is useful when more than one environment is
        contained in the DataCollector, as the seed will be incremented for
        each of these. The resulting seed is the seed of the last
        environment.

    Raises:
        RuntimeError: If a worker answers with anything but ``"seeded"``.

    Examples:
        >>> from torchrl.envs import ParallelEnv
        >>> from torchrl.envs.libs.gym import GymEnv
        >>> from tensordict.nn import TensorDictModule
        >>> from torch import nn
        >>> env_fn = lambda: GymEnv("Pendulum-v1")
        >>> env_fn_parallel = lambda: ParallelEnv(6, env_fn)
        >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
        >>> collector = Collector(env_fn_parallel, policy, frames_per_batch=100, total_frames=300)
        >>> out_seed = collector.set_seed(1)  # out_seed = 6

    """
    _check_for_faulty_process(self.procs)
    for worker in range(self.num_workers):
        conn = self.pipes[worker]
        conn.send(((seed, static_seed), "seed"))
        next_seed, reply = conn.recv()
        if reply != "seeded":
            raise RuntimeError(f"Expected msg='seeded', got {reply}")
        # Chain the seed: the next worker starts from this worker's output.
        seed = next_seed
    self.reset()
    return seed
def reset(self, reset_idx: Sequence[bool] | None = None) -> None:
    """Resets the environments to a new initial state.

    The ``"reset"`` command is first sent to every selected worker, and only
    then are the acknowledgments read back, in worker order.

    Args:
        reset_idx: Optional. Sequence indicating which environments have
            to be reset. If None, all environments are reset.

    Raises:
        RuntimeError: If a worker answers with anything but ``"reset"``.
    """
    _check_for_faulty_process(self.procs)
    if reset_idx is None:
        reset_idx = [True] * self.num_workers
    # Pass 1: issue the command to each selected worker.
    for worker in range(self.num_workers):
        if not reset_idx[worker]:
            continue
        self.pipes[worker].send((None, "reset"))
    # Pass 2: wait for each selected worker to acknowledge.
    for worker in range(self.num_workers):
        if not reset_idx[worker]:
            continue
        _, reply = self.pipes[worker].recv()
        if reply != "reset":
            raise RuntimeError(f"Expected msg='reset', got {reply}")
def state_dict(self) -> OrderedDict:
    """Returns the state_dict of the data collector.

    Each ``"worker{idx}"`` field holds the state_dict sent back by that
    worker. The ``"frames"`` and ``"iter"`` fields hold the collector's own
    counters.

    Raises:
        RuntimeError: If a worker answers with anything but ``"state_dict"``.
    """
    # Query every worker first, then gather the answers in worker order.
    for idx in range(self.num_workers):
        self.pipes[idx].send((None, "state_dict"))
    collected = OrderedDict()
    for idx in range(self.num_workers):
        worker_state, reply = self.pipes[idx].recv()
        if reply != "state_dict":
            raise RuntimeError(f"Expected msg='state_dict', got {reply}")
        collected[f"worker{idx}"] = worker_state
    frames, iteration = self._frames, self._iter
    collected["frames"] = frames
    collected["iter"] = iteration
    return collected
def load_state_dict(self, state_dict: OrderedDict) -> None:
    """Loads the state_dict on the workers.

    Each worker receives its own ``"worker{idx}"`` entry. Once every worker
    has acknowledged, the collector's frame and iteration counters are
    restored from the ``"frames"`` and ``"iter"`` entries.

    Args:
        state_dict (OrderedDict): state_dict of the form
            ``{"worker0": state_dict0, "worker1": state_dict1}``.

    Raises:
        RuntimeError: If a worker answers with anything but ``"loaded"``.
    """
    worker_ids = range(self.num_workers)
    for idx in worker_ids:
        payload = state_dict[f"worker{idx}"]
        self.pipes[idx].send((payload, "load_state_dict"))
    for idx in worker_ids:
        _, reply = self.pipes[idx].recv()
        if reply != "loaded":
            raise RuntimeError(f"Expected msg='loaded', got {reply}")
    self._frames = state_dict["frames"]
    self._iter = state_dict["iter"]
def increment_version(self):
    """Increment the policy version.

    Does nothing when no policy version tracker is set.

    Raises:
        RuntimeError: If a tracker is set but has no ``increment_version``
            attribute.
    """
    tracker = self.policy_version_tracker
    if tracker is None:
        return
    if not hasattr(tracker, "increment_version"):
        raise RuntimeError(
            "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector."
        )
    tracker.increment_version()
@property
def policy_version(self) -> str | int | None:
    """The current policy version.

    ``None`` when the policy version tracker exposes no ``version``
    attribute.
    """
    return getattr(self.policy_version_tracker, "version", None)
def get_policy_version(self) -> str | int | None:
    """Get the current policy version.

    Method counterpart of the ``policy_version`` property. It exists to
    support remote calls in Ray actors, since properties cannot be accessed
    directly through Ray's RPC mechanism.

    Returns:
        The current version number (int) or UUID (str), or None if version
        tracking is disabled.
    """
    current_version = self.policy_version
    return current_version
def getattr_policy(self, attr):
    """Get an attribute from the policy of the first worker.

    Args:
        attr (str): The attribute name to retrieve from the policy.

    Returns:
        The attribute value from the policy of the first worker.

    Raises:
        AttributeError: If the attribute doesn't exist on the policy.
        RuntimeError: If the worker answers with an unexpected message tag.
    """
    _check_for_faulty_process(self.procs)

    # Only the first worker (index 0) is queried.
    first_worker = self.pipes[0]
    first_worker.send((attr, "getattr_policy"))
    value, reply = first_worker.recv()
    if reply != "getattr_policy":
        raise RuntimeError(f"Expected msg='getattr_policy', got {reply}")

    # The worker ships an AttributeError back as a value; surface it here.
    if isinstance(value, AttributeError):
        raise value
    return value
def getattr_env(self, attr):
    """Get an attribute from the environment of the first worker.

    Args:
        attr (str): The attribute name to retrieve from the environment.

    Returns:
        The attribute value from the environment of the first worker.

    Raises:
        AttributeError: If the attribute doesn't exist on the environment.
        RuntimeError: If the worker answers with an unexpected message tag.
    """
    _check_for_faulty_process(self.procs)

    # Only the first worker (index 0) is queried.
    first_worker = self.pipes[0]
    first_worker.send((attr, "getattr_env"))
    value, reply = first_worker.recv()
    if reply != "getattr_env":
        raise RuntimeError(f"Expected msg='getattr_env', got {reply}")

    # The worker ships an AttributeError back as a value; surface it here.
    if isinstance(value, AttributeError):
        raise value
    return value
def getattr_rb(self, attr):
    """Get an attribute from the replay buffer.

    Args:
        attr (str): The attribute name to look up on ``self.replay_buffer``.

    Returns:
        The attribute value.

    Raises:
        AttributeError: If the replay buffer has no such attribute.
    """
    buffer = self.replay_buffer
    return getattr(buffer, attr)
def get_model(self, model_id: str):
    """Get model instance by ID (for weight sync schemes).

    For ``"policy"``, the fallback policy takes precedence over
    ``self.policy``. Any other identifier is resolved through
    ``_resolve_model``.

    Args:
        model_id: Model identifier (e.g., "policy", "value_net")

    Returns:
        The model instance

    Raises:
        ValueError: If ``model_id`` is ``"policy"`` and neither a fallback
            policy nor a policy is available.
    """
    if model_id != "policy":
        # Try to resolve via attribute access
        return _resolve_model(self, model_id)

    fallback = getattr(self, "_fallback_policy", None)
    if fallback is not None:
        return fallback

    policy = getattr(self, "policy", None)
    if policy is not None:
        return policy

    raise ValueError(f"No policy found for model_id '{model_id}'")
def get_cached_weights(self, model_id: str):
    """Get cached shared memory weights if available (for weight sync schemes).

    Only the ``"policy"`` model has cached weights. They are looked up in
    ``_policy_weights_dict`` under the policy device, using the first device
    when ``policy_device`` is a list or tuple.

    Args:
        model_id: Model identifier

    Returns:
        Cached TensorDict weights or None if not available
    """
    if model_id != "policy" or not hasattr(self, "_policy_weights_dict"):
        return None

    device = self.policy_device
    if isinstance(device, (list, tuple)):
        # An empty sequence falls back to the ``None`` key.
        device = device[0] if len(device) > 0 else None
    return self._policy_weights_dict.get(device)
def _weight_update_impl( self, policy_or_weights: TensorDictBase | nn.Module | dict | None = None, *, worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, model_id: str | None = None, weights_dict: dict[str, Any] | None = None, **kwargs, ) -> None: """Update weights on workers. Weight sync schemes now use background threads on the receiver side. The scheme's send() method: 1. Puts weights in the queue (or updates shared memory) 2. Sends a "receive" instruction to the worker's background thread 3. Waits for acknowledgment (if sync=True) No pipe signaling is needed - the scheme handles everything internally. """ # Call parent implementation which calls scheme.send() # The scheme handles instruction delivery and acknowledgments super()._weight_update_impl( policy_or_weights=policy_or_weights, worker_ids=worker_ids, model_id=model_id, weights_dict=weights_dict, **kwargs, ) # for RPC
def receive_weights(self, policy_or_weights: TensorDictBase | None = None):
    """Forward ``policy_or_weights`` to the parent class' ``receive_weights`` and return its result."""
    return super().receive_weights(policy_or_weights)
# for RPC def _receive_weights_scheme(self): return super()._receive_weights_scheme()
# Backward-compatible alias (deprecated, use MultiCollector instead) MultiCollector = MultiCollector

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources