Shortcuts

Source code for torchrl.collectors._single

from __future__ import annotations

import contextlib
import threading
import warnings
import weakref
from collections import OrderedDict
from collections.abc import Callable, Iterator, Sequence
from textwrap import indent
from typing import Any

import torch

from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase
from tensordict.nn import CudaGraphModule, TensorDictModule, TensorDictModuleBase
from torch import nn
from torchrl import compile_with_warmup, logger as torchrl_logger
from torchrl._utils import (
    _ends_with,
    _make_ordinal_device,
    _replace_last,
    accept_remote_rref_udf_invocation,
    prod,
    RL_WARNINGS,
)
from torchrl.collectors._base import _LegacyCollectorMeta, BaseCollector
from torchrl.collectors._constants import (
    cudagraph_mark_step_begin,
    DEFAULT_EXPLORATION_TYPE,
    ExplorationType,
)
from torchrl.collectors.utils import _TrajectoryPool, split_trajectories
from torchrl.collectors.weight_update import WeightUpdaterBase
from torchrl.data import ReplayBuffer
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs import EnvBase, EnvCreator, StepCounter, TransformedEnv
from torchrl.envs.common import _do_nothing
from torchrl.envs.llm.transforms import PolicyVersion
from torchrl.envs.utils import (
    _aggregate_end_of_traj,
    _make_compatible_policy,
    set_exploration_type,
)
from torchrl.modules import RandomPolicy
from torchrl.weight_update import WeightSyncScheme
from torchrl.weight_update.utils import _resolve_model


[docs]@accept_remote_rref_udf_invocation class Collector(BaseCollector): """Generic data collector for RL problems. Requires an environment constructor and a policy. Args: create_env_fn (Callable or EnvBase): a callable that returns an instance of :class:`~torchrl.envs.EnvBase` class, or the env itself. policy (Callable): Policy to be executed in the environment. Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. If ``None`` is provided, the policy used will be a :class:`~torchrl.collectors.RandomPolicy` instance with the environment ``action_spec``. Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. This is the recommended usage of the collector. Other callables are accepted too: If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` instances) it will be wrapped in a `nn.Module` first. Then, the collector will try to assess if these modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. - If the policy forward signature matches any of ``forward(self, tensordict)``, ``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or any typing with a single argument typed as a subclass of ``TensorDictBase``) then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / pickled directly), the ``policy_factory`` should be used instead. Keyword Args: policy_factory (Callable[[], Callable], optional): a callable that returns a policy instance. This is exclusive with the `policy` argument. .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. frames_per_batch (int): A keyword-only argument representing the total number of elements in a batch. 
total_frames (int): A keyword-only argument representing the total number of frames returned by the collector during its lifespan. If ``total_frames`` is not divisible by ``frames_per_batch``, a warning is emitted (unless ``RL_WARNINGS`` is disabled) and the last batch is still collected in full, so slightly more frames than requested are returned. Endless collectors can be created by passing ``total_frames=-1`` (any negative value or ``None`` has the same effect). Defaults to ``-1`` (endless collector). device (int, str or torch.device, optional): The generic device of the collector. The ``device`` argument fills any non-specified device: if ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or ``env_device`` is not specified, its value will be set to ``device``. Defaults to ``None`` (no default device). storing_device (int, str or torch.device, optional): The device on which the output :class:`~tensordict.TensorDict` will be stored. If ``device`` is passed and ``storing_device`` is ``None``, it will default to the value indicated by ``device``. For long trajectories, it may be necessary to store the data on a different device than the one where the policy and env are executed. Defaults to ``None`` (the output tensordict isn't on a specific device, leaf tensors sit on the device where they were created). env_device (int, str or torch.device, optional): The device on which the environment should be cast (or executed if that functionality is supported). If not specified and the env has a non-``None`` device, ``env_device`` will default to that value. If ``device`` is passed and ``env_device=None``, it will default to ``device``. If the resulting ``env_device`` differs from ``policy_device`` and one of them is not ``None``, the data will be cast to ``env_device`` before being passed to the env (i.e., passing different devices to policy and env is supported). Defaults to ``None``. policy_device (int, str or torch.device, optional): The device on which the policy should be cast. If ``device`` is passed and ``policy_device=None``, it will default to ``device``.
If the resulting ``policy_device`` differs from ``env_device`` and one of them is not ``None``, the data will be cast to ``policy_device`` before being passed to the policy (i.e., passing different devices to policy and env is supported). Defaults to ``None``. create_env_kwargs (dict, optional): Dictionary of kwargs for ``create_env_fn``. If ``create_env_fn`` is already an environment instance, these kwargs can only be applied to batched environments (through ``update_kwargs``); for any other environment instance a ``RuntimeError`` is raised. max_frames_per_traj (int, optional): Maximum steps per trajectory. Note that a trajectory can span across multiple batches (unless ``reset_at_each_iter`` is set to ``True``, see below). Once a trajectory reaches ``max_frames_per_traj`` steps, the environment is reset. This limit is enforced by appending a :class:`~torchrl.envs.StepCounter` transform to the environment, so a ``ValueError`` is raised if the environment already exposes a ``"step_count"`` key. If the environment wraps multiple environments together, the number of steps is tracked for each environment independently. Zero or negative values are allowed, in which case this argument is ignored. Defaults to ``None`` (i.e., no maximum number of steps). init_random_frames (int, optional): Number of frames for which the policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. If provided, it will be rounded up to the closest multiple of frames_per_batch. Defaults to ``None`` (i.e. no random frames). reset_at_each_iter (bool, optional): Whether environments should be reset at the beginning of a batch collection. Defaults to ``False``. postproc (Callable, optional): A post-processing transform, such as a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep` instance. If the transform has a ``to`` method and ``storing_device`` is set, it is moved to ``storing_device``. .. warning:: Combining ``postproc`` with a replay buffer is only supported with `extend_buffer=True` (the default): with `extend_buffer=False` a ``TypeError`` is raised, since the whole batch needs to be observed to apply the transform. Defaults to ``None``. split_trajs (bool, optional): Boolean indicating whether the resulting TensorDict should be split according to the trajectories. See :func:`~torchrl.collectors.utils.split_trajectories` for more information. Defaults to ``False``.
exploration_type (ExplorationType, optional): interaction mode to be used when collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``, ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``. Defaults to ``DEFAULT_EXPLORATION_TYPE``. return_same_td (bool, optional): if ``True``, the same TensorDict will be returned at each iteration, with its values updated. This feature should be used cautiously: if the same tensordict is added to a replay buffer for instance, the whole content of the buffer will be identical. Cannot be combined with ``replay_buffer`` (a ``TypeError`` is raised). Defaults to ``False``. interruptor (_Interruptor, optional): An ``_Interruptor`` object that can be used from outside the class to control rollout collection. The ``_Interruptor`` class has methods ``start_collection`` and ``stop_collection``, which allow implementing strategies such as preemptively stopping rollout collection. Defaults to ``None``. set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of a rollout is reached. If no ``"truncated"`` key is found, an exception is raised. Truncated keys can be set through ``env.add_truncated_keys``. Defaults to ``False``. use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data. This isn't compatible with environments with dynamic specs, nor with ``replay_buffer`` (a ``TypeError`` is raised if both are passed). Defaults to ``True`` for envs without dynamic specs when no ``replay_buffer`` is passed, ``False`` otherwise. replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts but populate the buffer instead. Defaults to ``None``. .. seealso:: By default (``extend_buffer=True``), the buffer is extended with entire rollouts. If the buffer needs to be populated with individual frames as they are collected, set ``extend_buffer=False`` (deprecated). ..
warning:: Using a replay buffer with a `postproc` or `split_trajs=True` requires `extend_buffer=True`, as the whole batch needs to be observed to apply these transforms. extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not with single steps. Defaults to `True`. .. note:: Setting this to `False` is deprecated and will be removed in a future version. Extending the buffer with entire rollouts is the recommended approach for better compatibility with postprocessing and trajectory splitting. trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy is trusted to be compatible with the collector and is used as-is, without being wrapped in a :class:`~tensordict.nn.TensorDictModule`. Defaults to ``True`` for :class:`~torchrl.modules.RandomPolicy` and :class:`~tensordict.nn.CudaGraphModule` instances and ``False`` otherwise. compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled using :func:`~torch.compile` default behaviour. If a dictionary of kwargs is passed, it will be used to compile the policy. cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped in :class:`~tensordict.nn.CudaGraphModule` with default kwargs. If a dictionary of kwargs is passed, it will be used to wrap the policy. no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed. For environments running directly on CUDA (`IsaacLab <https://github.com/isaac-sim/IsaacLab/>`_ or `ManiSkill <https://github.com/haosulab/ManiSkill/>`_) CUDA synchronization may cause unexpected crashes. Defaults to ``False``. weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase` or its subclass, responsible for updating the policy weights on remote inference workers. This is typically not used in :class:`~torchrl.collectors.Collector` as it operates in a single-process environment. Consider using a constructor if the updater needs to be serialized. This argument is deprecated in favor of ``weight_sync_schemes`` and passing it emits a ``DeprecationWarning``. It is ignored when ``weight_sync_schemes`` is provided.
weight_sync_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for SENDING weights from this collector. Since Collector is a leaf collector, it has no sub-collectors to send weights to. These schemes are meant for components living within the environment (e.g., a ``RayModuleTransform``). Schemes that are not yet initialized or connected on the sender side are initialized and connected when the collector is built. When provided, ``weight_updater`` is ignored. Use ``weight_recv_schemes`` if you need to receive weights from a parent collector. Defaults to ``None``. weight_recv_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for RECEIVING weights from parent collectors. Keys are model identifiers (e.g., "policy") and values are WeightSyncScheme instances configured to receive weights. This enables cascading weight updates in hierarchies like: RPCDataCollector -> MultiSyncCollector -> Collector. Defaults to ``None``. track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be appended to the environment and used to track the policy version. Passing ``True`` is not supported for batched environments (``BatchedEnvBase``) and raises a ``RuntimeError``. In that case, add the transform to the environment manually and pass that transform to the collector. Defaults to `False`. Examples: >>> from torchrl.envs.libs.gym import GymEnv >>> from tensordict.nn import TensorDictModule >>> from torch import nn >>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) >>> collector = Collector( ... create_env_fn=env_maker, ... policy=policy, ... total_frames=2000, ... max_frames_per_traj=50, ... frames_per_batch=200, ... init_random_frames=-1, ... reset_at_each_iter=False, ... device="cpu", ... storing_device="cpu", ... ) >>> for i, data in enumerate(collector): ... if i == 2: ... print(data) ...
break TensorDict( fields={ action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), collector: TensorDict( fields={ traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([200]), device=cpu, is_shared=False), done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([200]), device=cpu, is_shared=False), observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([200]), device=cpu, is_shared=False) >>> del collector The collector delivers batches of data that are marked with a ``"time"`` dimension. 
Examples: >>> assert data.names[-1] == "time" """ _ignore_rb: bool = False def __init__( self, create_env_fn: ( EnvBase | EnvCreator | Sequence[Callable[[], EnvBase]] # noqa: F821 ), # noqa: F821 policy: None | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, *, policy_factory: Callable[[], Callable] | None = None, frames_per_batch: int, total_frames: int = -1, device: DEVICE_TYPING | None = None, storing_device: DEVICE_TYPING | None = None, policy_device: DEVICE_TYPING | None = None, env_device: DEVICE_TYPING | None = None, create_env_kwargs: dict[str, Any] | None = None, max_frames_per_traj: int | None = None, init_random_frames: int | None = None, reset_at_each_iter: bool = False, postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, split_trajs: bool | None = None, exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, return_same_td: bool = False, reset_when_done: bool = True, interruptor=None, set_truncated: bool = False, use_buffers: bool | None = None, replay_buffer: ReplayBuffer | None = None, extend_buffer: bool = True, local_init_rb: bool | None = None, trust_policy: bool | None = None, compile_policy: bool | dict[str, Any] | None = None, cudagraph_policy: bool | dict[str, Any] | None = None, no_cuda_sync: bool = False, weight_updater: WeightUpdaterBase | Callable[[], WeightUpdaterBase] | None = None, weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, weight_recv_schemes: dict[str, WeightSyncScheme] | None = None, track_policy_version: bool = False, worker_idx: int | None = None, **kwargs, ): self.closed = True self.worker_idx = worker_idx # Note: weight_sync_schemes can be used to send weights to components # within the environment (e.g., RayModuleTransform), not just sub-collectors # Initialize environment env = self._init_env(create_env_fn, create_env_kwargs) # Initialize policy policy = self._init_policy(policy, policy_factory, env, trust_policy) self._read_compile_kwargs(compile_policy, 
cudagraph_policy) # Handle trajectory pool and validate kwargs self._traj_pool_val = kwargs.pop("traj_pool", None) if kwargs: raise TypeError( f"Keys {list(kwargs.keys())} are unknown to {type(self).__name__}." ) # Set up devices and synchronization self._setup_devices( device=device, storing_device=storing_device, policy_device=policy_device, env_device=env_device, no_cuda_sync=no_cuda_sync, ) self.env: EnvBase = env del env # Set up policy version tracking self._setup_policy_version_tracking(track_policy_version) # Set up replay buffer self._setup_replay_buffer( replay_buffer=replay_buffer, extend_buffer=extend_buffer, local_init_rb=local_init_rb, postproc=postproc, split_trajs=split_trajs, return_same_td=return_same_td, use_buffers=use_buffers, ) self.closed = False # Validate reset_when_done if not reset_when_done: raise ValueError("reset_when_done is deprecated.") self.reset_when_done = reset_when_done self.n_env = self.env.batch_size.numel() # Register collector with policy and env if hasattr(policy, "register_collector"): policy.register_collector(self) if hasattr(self.env, "register_collector"): self.env.register_collector(self) # Set up policy and weights self._setup_policy_and_weights(policy) # Apply environment device self._apply_env_device() # Set up max frames per trajectory self._setup_max_frames_per_traj(max_frames_per_traj) # Validate and set total frames self.reset_at_each_iter = reset_at_each_iter self._setup_total_frames(total_frames, frames_per_batch) # Set up init random frames self._setup_init_random_frames(init_random_frames, frames_per_batch) # Set up postproc self._setup_postproc(postproc) # Calculate frames per batch self._setup_frames_per_batch(frames_per_batch) # Set exploration and other options self.exploration_type = ( exploration_type if exploration_type else DEFAULT_EXPLORATION_TYPE ) self.return_same_td = return_same_td self.set_truncated = set_truncated # Create shuttle and rollout buffers self._make_shuttle() 
self._maybe_make_final_rollout(make_rollout=self._use_buffers) self._set_truncated_keys() # Set split trajectories option if split_trajs is None: split_trajs = False self.split_trajs = split_trajs self._exclude_private_keys = True # Set up interruptor and frame tracking self.interruptor = interruptor self._frames = 0 self._iter = -1 # Set up weight synchronization self._setup_weight_sync(weight_updater, weight_sync_schemes) # Set up weight receivers if provided if weight_recv_schemes is not None: self.register_scheme_receiver(weight_recv_schemes) def _init_env( self, create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase], create_env_kwargs: dict[str, Any] | None, ) -> EnvBase: """Initialize and configure the environment.""" from torchrl.envs.batched_envs import BatchedEnvBase if create_env_kwargs is None: create_env_kwargs = {} if not isinstance(create_env_fn, EnvBase): env = create_env_fn(**create_env_kwargs) else: env = create_env_fn if create_env_kwargs: if not isinstance(env, BatchedEnvBase): raise RuntimeError( "kwargs were passed to Collector but they can't be set " f"on environment of type {type(create_env_fn)}." 
) env.update_kwargs(create_env_kwargs) return env def _init_policy( self, policy: TensorDictModule | Callable | None, policy_factory: Callable[[], Callable] | None, env: EnvBase, trust_policy: bool | None, ) -> TensorDictModule | Callable: """Initialize and configure the policy before device placement / wrapping.""" if policy is None: if policy_factory is not None: policy = policy_factory() else: policy = RandomPolicy(env.full_action_spec) elif policy_factory is not None: raise TypeError("policy_factory cannot be used with policy argument.") if trust_policy is None: trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule)) self.trust_policy = trust_policy return policy def _setup_devices( self, device: DEVICE_TYPING | None, storing_device: DEVICE_TYPING | None, policy_device: DEVICE_TYPING | None, env_device: DEVICE_TYPING | None, no_cuda_sync: bool, ) -> None: """Set up devices and synchronization functions.""" storing_device, policy_device, env_device = self._get_devices( storing_device=storing_device, policy_device=policy_device, env_device=env_device, device=device, ) self.storing_device = storing_device self._sync_storage = self._get_sync_fn(storing_device) self.env_device = env_device self._sync_env = self._get_sync_fn(env_device) self.policy_device = policy_device self._sync_policy = self._get_sync_fn(policy_device) self.device = device self.no_cuda_sync = no_cuda_sync self._cast_to_policy_device = self.policy_device != self.env_device def _get_sync_fn(self, device: torch.device | None) -> Callable: """Get the appropriate synchronization function for a device.""" if device is not None and device.type != "cuda": # Cuda handles sync if torch.cuda.is_available(): return torch.cuda.synchronize elif torch.backends.mps.is_available() and hasattr(torch, "mps"): return torch.mps.synchronize elif hasattr(torch, "npu") and torch.npu.is_available(): return torch.npu.synchronize elif device.type == "cpu": return _do_nothing else: raise RuntimeError("Non 
supported device") else: return _do_nothing def _setup_policy_version_tracking( self, track_policy_version: bool | PolicyVersion ) -> None: """Set up policy version tracking if requested.""" self.policy_version_tracker = track_policy_version if isinstance(track_policy_version, bool) and track_policy_version: from torchrl.envs.batched_envs import BatchedEnvBase if isinstance(self.env, BatchedEnvBase): raise RuntimeError( "BatchedEnvBase is not supported for policy version tracking. Please add the PolicyVersion transform to the environment manually, " "and pass that transform to the collector." ) self.policy_version_tracker = PolicyVersion() self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore elif hasattr(track_policy_version, "increment_version"): self.policy_version_tracker = track_policy_version self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore else: self.policy_version_tracker = None def _setup_replay_buffer( self, replay_buffer: ReplayBuffer | None, extend_buffer: bool, local_init_rb: bool | None, postproc: Callable | None, split_trajs: bool | None, return_same_td: bool, use_buffers: bool | None, ) -> None: """Set up replay buffer configuration and validate compatibility.""" self.replay_buffer = replay_buffer self.extend_buffer = extend_buffer # Handle local_init_rb deprecation if local_init_rb is None: local_init_rb = False if replay_buffer is not None and not local_init_rb: warnings.warn( "local_init_rb=False is deprecated and will be removed in v0.12. " "The new storage-level initialization provides better performance.", FutureWarning, ) self.local_init_rb = local_init_rb # Validate replay buffer compatibility if self.replay_buffer is not None and not self._ignore_rb: if postproc is not None and not self.extend_buffer: raise TypeError( "postproc must be None when a replay buffer is passed, or extend_buffer must be set to True." 
) if split_trajs not in (None, False) and not self.extend_buffer: raise TypeError( "split_trajs must be None/False when a replay buffer is passed, or extend_buffer must be set to True." ) if return_same_td: raise TypeError( "return_same_td must be False when a replay buffer is passed, or extend_buffer must be set to True." ) if use_buffers: raise TypeError("replay_buffer is exclusive with use_buffers.") if use_buffers is None: use_buffers = not self.env._has_dynamic_specs and self.replay_buffer is None self._use_buffers = use_buffers def _setup_policy_and_weights(self, policy: TensorDictModule | Callable) -> None: """Set up policy, wrapped policy, and extract weights.""" # Store weak reference to original policy before any transformations # This allows update_policy_weights_ to sync from the original when no scheme is configured if isinstance(policy, nn.Module): self._orig_policy_ref = weakref.ref(policy) else: self._orig_policy_ref = None # Check if policy has meta-device parameters (sent from weight sync schemes) # In that case, skip device placement - weights will come from the receiver has_meta_params = False if isinstance(policy, nn.Module): for p in policy.parameters(): if p.device.type == "meta": has_meta_params = True break if has_meta_params: # Policy has meta params - sent from weight sync schemes # Skip device placement, weights will come from receiver # Keep policy on meta device until weights are loaded if not self.trust_policy: self.policy = policy env = getattr(self, "env", None) try: wrapped_policy = _make_compatible_policy( policy=policy, observation_spec=getattr(env, "observation_spec", None), env=self.env, ) except (TypeError, AttributeError, ValueError) as err: raise TypeError( "Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True. Scroll up for more details." 
) from err self._wrapped_policy = wrapped_policy else: self.policy = self._wrapped_policy = policy # For meta-parameter policies, keep the internal (worker-side) policy # as the reference for collector state_dict / load_state_dict. if isinstance(self.policy, nn.Module): self._policy_w_state_dict = self.policy # Don't extract weights yet - they're on meta device (empty) self.policy_weights = TensorDict() self.get_weights_fn = None else: # Normal path: move policy to correct device policy, self.get_weights_fn = self._get_policy_and_device(policy=policy) if not self.trust_policy: self.policy = policy env = getattr(self, "env", None) try: wrapped_policy = _make_compatible_policy( policy=policy, observation_spec=getattr(env, "observation_spec", None), env=self.env, ) except (TypeError, AttributeError, ValueError) as err: raise TypeError( "Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True. Scroll up for more details." ) from err self._wrapped_policy = wrapped_policy else: self.policy = self._wrapped_policy = policy # Use the internal, unwrapped policy (cast to the correct device) as the # reference for state_dict / load_state_dict and legacy weight extractors. if isinstance(self.policy, nn.Module): self._policy_w_state_dict = self.policy # Extract policy weights from the uncompiled wrapped policy # Access _wrapped_policy_uncompiled directly to avoid triggering compilation. 
if isinstance(self._wrapped_policy_uncompiled, nn.Module): self.policy_weights = TensorDict.from_module( self._wrapped_policy_uncompiled, as_module=True ).data else: self.policy_weights = TensorDict() # If policy doesn't have meta params, compile immediately # Otherwise, defer until first use (after weights are loaded) if not has_meta_params and (self.compiled_policy or self.cudagraphed_policy): self._wrapped_policy_maybe_compiled = self._compile_wrapped_policy( self._wrapped_policy_uncompiled ) def _compile_wrapped_policy(self, policy): """Apply compilation and/or cudagraph to a policy.""" if self.compiled_policy: policy = compile_with_warmup(policy, **self.compiled_policy_kwargs) if self.cudagraphed_policy: policy = CudaGraphModule( policy, in_keys=[], out_keys=[], device=self.policy_device, **self.cudagraphed_policy_kwargs, ) return policy @property def _wrapped_policy(self): """Returns the compiled policy, compiling it lazily if needed.""" if (policy := self._wrapped_policy_maybe_compiled) is None: if self.compiled_policy or self.cudagraphed_policy: policy = ( self._wrapped_policy_maybe_compiled ) = self._compile_wrapped_policy(self._wrapped_policy_uncompiled) else: policy = ( self._wrapped_policy_maybe_compiled ) = self._wrapped_policy_uncompiled return policy @property def _orig_policy(self): """Returns the original policy passed to the collector, if still alive.""" if self._orig_policy_ref is not None: return self._orig_policy_ref() return None @_wrapped_policy.setter def _wrapped_policy(self, value): """Allow setting the wrapped policy during initialization.""" self._wrapped_policy_uncompiled = value self._wrapped_policy_maybe_compiled = None def _apply_env_device(self) -> None: """Apply device to environment if specified.""" if self.env_device: self.env: EnvBase = self.env.to(self.env_device) elif self.env.device is not None: # Use the device of the env if none was provided self.env_device = self.env.device # Check if we need to cast to env device 
self._cast_to_env_device = self._cast_to_policy_device or ( self.env.device != self.storing_device ) def _setup_max_frames_per_traj(self, max_frames_per_traj: int | None) -> None: """Set up maximum frames per trajectory and add StepCounter if needed.""" self.max_frames_per_traj = ( int(max_frames_per_traj) if max_frames_per_traj is not None else 0 ) if self.max_frames_per_traj is not None and self.max_frames_per_traj > 0: # Check that there is no StepCounter yet for key in self.env.output_spec.keys(True, True): if isinstance(key, str): key = (key,) if "step_count" in key: raise ValueError( "A 'step_count' key is already present in the environment " "and the 'max_frames_per_traj' argument may conflict with " "a 'StepCounter' that has already been set. " "Possible solutions: Set max_frames_per_traj to 0 or " "remove the StepCounter limit from the environment transforms." ) self.env = TransformedEnv( self.env, StepCounter(max_steps=self.max_frames_per_traj) ) def _setup_total_frames(self, total_frames: int, frames_per_batch: int) -> None: """Validate and set total frames.""" if total_frames is None or total_frames < 0: total_frames = float("inf") else: remainder = total_frames % frames_per_batch if remainder != 0 and RL_WARNINGS: warnings.warn( f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). " f"This means {frames_per_batch - remainder} additional frames will be collected." "To silence this message, set the environment variable RL_WARNINGS to False." 
)
    # Tail of the preceding total-frames setup method (its start is outside
    # this view). ``inf`` is kept as a float because ``int(float("inf"))``
    # would raise; any finite value is coerced to ``int``.
    self.total_frames = (
        int(total_frames) if total_frames != float("inf") else total_frames
    )

def _setup_init_random_frames(
    self, init_random_frames: int | None, frames_per_batch: int
) -> None:
    """Set up initial random frames.

    ``None`` and ``-1`` are both normalised to ``0`` (no random frames).
    Warns when ``init_random_frames`` is not a multiple of
    ``frames_per_batch``. In that case the warning reports the count
    rounded up to the next multiple of ``frames_per_batch``.
    """
    self.init_random_frames = (
        int(init_random_frames) if init_random_frames not in (None, -1) else 0
    )
    if (
        init_random_frames not in (-1, None, 0)
        and init_random_frames % frames_per_batch != 0
        and RL_WARNINGS
    ):
        # -(-a // b) * b is ceil(a / b) * b computed with integer arithmetic.
        warnings.warn(
            f"init_random_frames ({init_random_frames}) is not exactly a multiple of frames_per_batch ({frames_per_batch}), "
            f" this results in more init_random_frames than requested"
            f" ({-(-init_random_frames // frames_per_batch) * frames_per_batch})."
            "To silence this message, set the environment variable RL_WARNINGS to False."
        )

def _setup_postproc(self, postproc: Callable | None) -> None:
    """Set up post-processing transform.

    Stores ``postproc``. If it exposes a ``.to`` method and a storing device
    is configured, it is moved to ``self.storing_device``. The moved object
    replaces the original only if ``.to`` returned a different, non-``None``
    object, so in-place ``.to`` implementations that return ``None`` are
    tolerated.
    """
    self.postproc = postproc
    if (
        self.postproc is not None
        and hasattr(self.postproc, "to")
        and self.storing_device
    ):
        postproc = self.postproc.to(self.storing_device)
        if postproc is not self.postproc and postproc is not None:
            self.postproc = postproc

def _setup_frames_per_batch(self, frames_per_batch: int) -> None:
    """Calculate and validate frames per batch.

    Converts the user-facing total ``frames_per_batch`` into a per-env step
    count, rounding up: ``self.frames_per_batch = ceil(frames_per_batch / n_env)``.
    ``self.requested_frames_per_batch`` is the resulting total across all
    ``n_env`` environments. It can exceed the requested value, and a warning
    is emitted in that case.
    """
    if frames_per_batch % self.n_env != 0 and RL_WARNINGS:
        warnings.warn(
            f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), "
            f" this results in more frames_per_batch per iteration that requested"
            f" ({-(-frames_per_batch // self.n_env) * self.n_env}). "
            "To silence this message, set the environment variable RL_WARNINGS to False."
        )
    # Ceil division via negated floor division.
    self.frames_per_batch = -(-frames_per_batch // self.n_env)
    self.requested_frames_per_batch = self.frames_per_batch * self.n_env

def _setup_weight_sync(
    self,
    weight_updater: WeightUpdaterBase | Callable | None,
    weight_sync_schemes: dict[str, WeightSyncScheme] | None,
) -> None:
    """Set up weight synchronization system.

    Exactly one of three modes is configured:

    * ``weight_sync_schemes`` given: each scheme is initialised and connected
      on the sender side if not already, and the legacy updater is disabled.
    * ``weight_updater`` given: a callable that is not already a
      ``WeightUpdaterBase`` is called to build one; anything else raises
      ``TypeError``. A ``DeprecationWarning`` is always emitted in this mode.
    * neither: no weight synchronization is configured.

    Raises:
        TypeError: if ``weight_updater`` is neither a ``WeightUpdaterBase``
            nor callable.
    """
    if weight_sync_schemes is not None:
        # Use new simplified weight synchronization system
        self._weight_sync_schemes = weight_sync_schemes
        # Initialize and synchronize schemes that need sender-side setup
        # (e.g., RayModuleTransformScheme for updating transforms in the env)
        for model_id, scheme in weight_sync_schemes.items():
            if not scheme.initialized_on_sender:
                scheme.init_on_sender(model_id=model_id, context=self)
            if not scheme.synchronized_on_sender:
                scheme.connect()
        self.weight_updater = None  # Don't use legacy system
    elif weight_updater is not None:
        # Use legacy weight updater system if explicitly provided
        if not isinstance(weight_updater, WeightUpdaterBase):
            if callable(weight_updater):
                # Treated as a factory returning the updater instance.
                weight_updater = weight_updater()
            else:
                raise TypeError(
                    f"weight_updater must be a subclass of WeightUpdaterBase. Got {type(weight_updater)} instead."
                )
        warnings.warn(
            "Using WeightUpdaterBase is deprecated. Please use weight_sync_schemes instead. "
            "This will be removed in a future version.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.weight_updater = weight_updater
        self._weight_sync_schemes = None
    else:
        # No weight sync needed for single-process collectors
        self.weight_updater = None
        self._weight_sync_schemes = None

@property
def _traj_pool(self):
    """Lazily created ``_TrajectoryPool`` used to allocate trajectory ids."""
    pool = getattr(self, "_traj_pool_val", None)
    if pool is None:
        pool = self._traj_pool_val = _TrajectoryPool()
    return pool

def _make_shuttle(self):
    """Create the carrier tensordict from an initial ``env.reset()``.

    The carrier (historically called the "shuttle") ferries data from the
    env to the policy and back. Its device is cleared when the policy and env
    devices differ or the env device is unknown. Fresh trajectory ids, one
    per env, are stored under ``("collector", "traj_ids")``.
    """
    # Shuttle is a deviceless tensordict that just carried data from env to policy and policy to env
    with torch.no_grad():
        self._carrier = self.env.reset()
    if self.policy_device != self.env_device or self.env_device is None:
        self._shuttle_has_no_device = True
        self._carrier.clear_device_()
    else:
        self._shuttle_has_no_device = False

    traj_ids = self._traj_pool.get_traj_and_increment(
        self.n_env, device=self.storing_device
    ).view(self.env.batch_size)
    self._carrier.set(
        ("collector", "traj_ids"),
        traj_ids,
    )

def _maybe_make_final_rollout(self, make_rollout: bool):
    """Determine the policy output keys and optionally pre-allocate the batch buffer.

    The policy output keys (``self._policy_output_keys``) are found by the
    first strategy that applies:

    1. ``make_rollout`` is set and the policy has a fully specified ``spec``:
       keys are taken from that spec, and missing entries are zero-filled in
       the buffer.
    2. ``make_rollout`` is not set and the policy has non-empty ``out_keys``:
       those keys are used.
    3. The policy has meta-device parameters: key discovery is deferred, and
       ``_final_rollout_needs_init`` is set when ``make_rollout`` is set.
    4. Otherwise the policy is run once on a copy of the carrier, and the
       entries it produced or modified are kept.

    Also records ``self._env_output_keys``. When ``make_rollout`` is set,
    ``self._final_rollout`` becomes a zeroed buffer of batch shape
    ``(*env.batch_size, frames_per_batch)``. Its last dim is named "time",
    and it holds a ``("collector", "traj_ids")`` int64 entry.
    """
    if make_rollout:
        with torch.no_grad():
            self._final_rollout = self.env.fake_tensordict()

        # If storing device is not None, we use this to cast the storage.
        # If it is None and the env and policy are on the same device,
        # the storing device is already the same as those, so we don't need
        # to consider this use case.
        # In all other cases, we can't really put a device on the storage,
        # since at least one data source has a device that is not clear.
        if self.storing_device:
            self._final_rollout = self._final_rollout.to(
                self.storing_device, non_blocking=True
            )
        else:
            # erase all devices
            self._final_rollout.clear_device_()

    # Check if policy has meta-device parameters (not yet initialized)
    has_meta_params = False
    if hasattr(self, "_wrapped_policy_uncompiled") and isinstance(
        self._wrapped_policy_uncompiled, nn.Module
    ):
        for p in self._wrapped_policy_uncompiled.parameters():
            if p.device.type == "meta":
                has_meta_params = True
                break

    # If the policy has a valid spec, we use it
    # (the uncompiled policy is inspected when meta params are present)
    self._policy_output_keys = set()
    if (
        make_rollout
        and hasattr(
            self._wrapped_policy_uncompiled
            if has_meta_params
            else self._wrapped_policy,
            "spec",
        )
        and (
            self._wrapped_policy_uncompiled
            if has_meta_params
            else self._wrapped_policy
        ).spec
        is not None
        and all(
            v is not None
            for v in (
                self._wrapped_policy_uncompiled
                if has_meta_params
                else self._wrapped_policy
            ).spec.values(True, True)
        )
    ):
        # NOTE(review): if every spec key is already in the env's fake
        # tensordict, this ``any`` is False and ``_policy_output_keys`` stays
        # empty. Confirm that is intended, since ``rollout`` passes these keys
        # as ``keys_to_update`` when copying policy outputs into the carrier.
        if any(
            key not in self._final_rollout.keys(isinstance(key, tuple))
            for key in (
                self._wrapped_policy_uncompiled
                if has_meta_params
                else self._wrapped_policy
            ).spec.keys(True, True)
        ):
            # if policy spec is non-empty, all the values are not None and the keys
            # match the out_keys we assume the user has given all relevant information
            # the policy could have more keys than the env:
            policy_spec = (
                self._wrapped_policy_uncompiled
                if has_meta_params
                else self._wrapped_policy
            ).spec
            if policy_spec.ndim < self._final_rollout.ndim:
                policy_spec = policy_spec.expand(self._final_rollout.shape)
            for key, spec in policy_spec.items(True, True):
                self._policy_output_keys.add(key)
                if key in self._final_rollout.keys(True):
                    continue
                self._final_rollout.set(key, spec.zero())
    elif (
        not make_rollout
        and hasattr(
            self._wrapped_policy_uncompiled
            if has_meta_params
            else self._wrapped_policy,
            "out_keys",
        )
        and (
            self._wrapped_policy_uncompiled
            if has_meta_params
            else self._wrapped_policy
        ).out_keys
    ):
        self._policy_output_keys = list(
            (
                self._wrapped_policy_uncompiled
                if has_meta_params
                else self._wrapped_policy
            ).out_keys
        )
    elif has_meta_params:
        # Policy has meta params and no spec/out_keys - defer initialization
        # Mark that we need to initialize later when weights are loaded
        self._policy_output_keys = set()
        if make_rollout:
            # We'll populate keys on first actual rollout after weights are loaded
            self._final_rollout_needs_init = True
    else:
        if make_rollout:
            # otherwise, we perform a small number of steps with the policy to
            # determine the relevant keys with which to pre-populate _final_rollout.
            # This is the safest thing to do if the spec has None fields or if there is
            # no spec at all.
            # See #505 for additional context.
            self._final_rollout.update(self._carrier.copy())
        with torch.no_grad():
            policy_input = self._carrier.copy()
            if self.policy_device:
                policy_input = policy_input.to(self.policy_device)
            # we cast to policy device, we'll deal with the device later
            policy_input_copy = policy_input.copy()
            policy_input_clone = (
                policy_input.clone()
            )  # to test if values have changed in-place
            if self.compiled_policy:
                cudagraph_mark_step_begin()
            policy_output = self._wrapped_policy(policy_input)

            # check that we don't have exclusive keys, because they don't appear in keys
            def check_exclusive(val):
                if (
                    isinstance(val, LazyStackedTensorDict)
                    and val._has_exclusive_keys
                ):
                    raise RuntimeError(
                        "LazyStackedTensorDict with exclusive keys are not permitted in collectors. "
                        "Consider using a placeholder for missing keys."
                    )

            policy_output._fast_apply(
                check_exclusive, call_on_nested=True, filter_empty=True
            )

            # Use apply, because it works well with lazy stacks
            # Edge-case of this approach: the policy may change the values in-place and only by a tiny bit
            # or occasionally. In these cases, the keys will be missed (we can't detect if the policy has
            # changed them here).
            # This will cause a failure to update entries when policy and env device mismatch and
            # casting is necessary.
            def filter_policy(name, value_output, value_input, value_input_clone):
                # Keep an entry if it is new (absent from the input) or looks modified.
                # NOTE(review): ``~torch.isclose(a, b).any()`` parses as
                # ``~(torch.isclose(a, b).any())``, i.e. "no element is close",
                # not "some element changed" (``(~torch.isclose(a, b)).any()``).
                # Confirm which was intended.
                if (value_input is None) or (
                    (value_output is not value_input)
                    and (
                        value_output.device != value_input_clone.device
                        or ~torch.isclose(value_output, value_input_clone).any()
                    )
                ):
                    return value_output

            filtered_policy_output = policy_output.apply(
                filter_policy,
                policy_input_copy,
                policy_input_clone,
                default=None,
                filter_empty=True,
                named=True,
            )
            self._policy_output_keys = list(
                self._policy_output_keys.union(
                    set(filtered_policy_output.keys(True, True))
                )
            )
            if make_rollout:
                self._final_rollout.update(
                    policy_output.select(*self._policy_output_keys)
                )
            del filtered_policy_output, policy_output, policy_input

    _env_output_keys = []
    for spec in ["full_observation_spec", "full_done_spec", "full_reward_spec"]:
        _env_output_keys += list(self.env.output_spec[spec].keys(True, True))
    self._env_output_keys = _env_output_keys
    if make_rollout:
        # Add a trailing time dim of length frames_per_batch and zero the buffer.
        self._final_rollout = (
            self._final_rollout.unsqueeze(-1)
            .expand(*self.env.batch_size, self.frames_per_batch)
            .clone()
            .zero_()
        )

        # in addition to outputs of the policy, we add traj_ids to
        # _final_rollout which will be collected during rollout
        self._final_rollout.set(
            ("collector", "traj_ids"),
            torch.zeros(
                *self._final_rollout.batch_size,
                dtype=torch.int64,
                device=self.storing_device,
            ),
        )
        self._final_rollout.refine_names(..., "time")

def _set_truncated_keys(self):
    """Collect the env's ``truncated`` done-keys when ``set_truncated`` is enabled.

    Raises:
        RuntimeError: if ``set_truncated`` is True but no done key ends
            with ``"truncated"``.
    """
    self._truncated_keys = []
    if self.set_truncated:
        if not any(_ends_with(key, "truncated") for key in self.env.done_keys):
            raise RuntimeError(
                "set_truncated was set to True but no truncated key could be found "
                "in the environment. Make sure the truncated keys are properly set using "
                "`env.add_truncated_keys()` before passing the env to the collector."
            )
        self._truncated_keys = [
            key for key in self.env.done_keys if _ends_with(key, "truncated")
        ]

@classmethod
def _get_devices(
    cls,
    *,
    storing_device: torch.device,
    policy_device: torch.device,
    env_device: torch.device,
    device: torch.device,
):
    """Resolve the storing, policy and env devices.

    Each unspecified (falsy) device falls back to ``device``. If the storing
    device is still ``None`` and the env and policy devices agree, the env
    device is also used for storage.

    Returns:
        ``(storing_device, policy_device, env_device)``.
    """
    device = _make_ordinal_device(torch.device(device) if device else device)
    storing_device = _make_ordinal_device(
        torch.device(storing_device) if storing_device else device
    )
    policy_device = _make_ordinal_device(
        torch.device(policy_device) if policy_device else device
    )
    env_device = _make_ordinal_device(
        torch.device(env_device) if env_device else device
    )
    if storing_device is None and (env_device == policy_device):
        storing_device = env_device
    return storing_device, policy_device, env_device

# for RPC
def next(self):
    """Explicit pass-through to the base-class ``next``."""
    return super().next()

# for RPC
def update_policy_weights_(
    self,
    policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
    *,
    worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
    **kwargs,
) -> None:
    """Update the collector's policy weights.

    Thin wrapper around the base-class implementation that also accepts the
    deprecated ``policy_weights`` keyword.

    Args:
        policy_or_weights: weights or policy module, forwarded unchanged to
            the base implementation.

    Keyword Args:
        worker_ids: forwarded unchanged to the base implementation.
        **kwargs: forwarded unchanged. A ``policy_weights`` entry is removed
            and used as ``policy_or_weights``. A ``DeprecationWarning`` is
            emitted in that case, and the entry overrides any
            ``policy_or_weights`` that was also given.
    """
    if "policy_weights" in kwargs:
        # stacklevel=2 makes the warning point at the caller's deprecated usage
        # rather than at this line (same convention as _setup_weight_sync).
        warnings.warn(
            "`policy_weights` is deprecated. Use `policy_or_weights` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        policy_or_weights = kwargs.pop("policy_weights")
    super().update_policy_weights_(
        policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
    )
def _maybe_fallback_update(
    self,
    policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
    *,
    model_id: str | None = None,
) -> None:
    """Copy weights from original policy to internal policy when no scheme configured.

    Only the ``"policy"`` model (or an unspecified ``model_id``) is handled.
    Other ids are ignored. Weights come from ``policy_or_weights`` when given,
    and otherwise from ``self._orig_policy``. Nothing happens when neither
    source, or no internal policy module, is available.
    """
    if model_id not in (None, "policy"):
        return

    # Pick the weight source: explicit argument first, original policy second.
    if policy_or_weights is not None:
        source_weights = self._extract_weights_if_needed(policy_or_weights, "policy")
    elif self._orig_policy is not None:
        source_weights = TensorDict.from_module(self._orig_policy)
    else:
        return

    target_module = getattr(self, "_policy_w_state_dict", None)
    if target_module is None:
        return
    # In-place update so the internal module keeps its parameter objects.
    TensorDict.from_module(target_module).data.update_(source_weights.data)
def set_seed(self, seed: int, static_seed: bool = False) -> int:
    """Sets the seeds of the environments stored in the DataCollector.

    Args:
        seed (int): integer representing the seed to be used for the environment.
        static_seed(bool, optional): if ``True``, the seed is not incremented.
            Defaults to False

    Returns:
        Output seed, as returned by the environment's ``set_seed``. This is
        useful when more than one environment is contained in the
        DataCollector, as the seed will be incremented for each of these.
        The resulting seed is the seed of the last environment.

    Examples:
        >>> from torchrl.envs import ParallelEnv
        >>> from torchrl.envs.libs.gym import GymEnv
        >>> from tensordict.nn import TensorDictModule
        >>> from torch import nn
        >>> env_fn = lambda: GymEnv("Pendulum-v1")
        >>> env_fn_parallel = ParallelEnv(6, env_fn)
        >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
        >>> collector = Collector(env_fn_parallel, policy, total_frames=300, frames_per_batch=100)
        >>> out_seed = collector.set_seed(1)  # out_seed = 6

    """
    # Seeding is fully delegated to the wrapped environment.
    return self.env.set_seed(seed, static_seed=static_seed)
def _increment_frames(self, numel):
    """Add ``numel`` collected frames and close the env once the budget is reached.

    Returns:
        Whether ``self._frames`` has reached ``self.total_frames``.
    """
    self._frames += numel
    reached_budget = self._frames >= self.total_frames
    if not reached_budget:
        return reached_budget
    self.env.close()
    return reached_budget
def iterator(self) -> Iterator[TensorDictBase]:
    """Iterates through the DataCollector.

    Rollouts are collected until ``self._frames`` reaches
    ``self.total_frames``. Unless ``no_cuda_sync`` is set, rollouts run inside
    dedicated CUDA streams: one for a CUDA storing device, or one per CUDA
    device found in the buffer or specs when there is no storing device.

    Yields:
        TensorDictBase objects containing (chunks of) trajectories. Depending
        on the configuration, this is either:

        * the rollout tensordict itself, when ``return_same_td`` is set;
        * ``None``, when data was written to the replay buffer instead;
        * a clone of the rollout tensordict otherwise, since the underlying
          buffer is updated in place on the next iteration.
    """
    import logging

    if (
        not self.no_cuda_sync
        and self.storing_device
        and self.storing_device.type == "cuda"
    ):
        stream = torch.cuda.Stream(self.storing_device, priority=-1)
        event = stream.record_event()
        streams = [stream]
        events = [event]
    elif not self.no_cuda_sync and self.storing_device is None:
        streams = []
        events = []
        # this way of checking cuda is robust to lazy stacks with mismatching shapes
        cuda_devices = set()

        def cuda_check(tensor: torch.Tensor):
            if tensor.is_cuda:
                cuda_devices.add(tensor.device)

        if not self._use_buffers:
            # This may be a bit dangerous as `torch.device("cuda")` may not have a precise
            # device associated, whereas `tensor.device` always has
            for spec in self.env.specs.values(True, True):
                if spec.device is not None and spec.device.type == "cuda":
                    if ":" not in str(spec.device):
                        raise RuntimeError(
                            "A cuda spec did not have a device associated. Make sure to "
                            "pass `'cuda:device_num'` to each spec device."
                        )
                    cuda_devices.add(spec.device)
        else:
            self._final_rollout.apply(cuda_check, filter_empty=True)
        for device in cuda_devices:
            streams.append(torch.cuda.Stream(device, priority=-1))
            events.append(streams[-1].record_event())
    else:
        streams = []
        events = []
    with contextlib.ExitStack() as stack:
        for stream in streams:
            stack.enter_context(torch.cuda.stream(stream))

        while self._frames < self.total_frames:
            self._iter += 1
            torchrl_logger.debug("Collector: rollout.")
            tensordict_out = self.rollout()
            if tensordict_out is None:
                # if a replay buffer is passed and self.extend_buffer=False, there is no tensordict_out
                # frames are updated within the rollout function
                torchrl_logger.debug("Collector: No tensordict_out. Yielding.")
                yield
                continue
            self._increment_frames(tensordict_out.numel())
            tensordict_out = self._postproc(tensordict_out)
            torchrl_logger.debug("Collector: postproc done.")
            if self.return_same_td:
                # This is used with multiprocessed collectors to use the buffers
                # stored in the tensordict.
                if events:
                    for event in events:
                        event.record()
                        event.synchronize()
                yield tensordict_out
            elif self.replay_buffer is not None and not self._ignore_rb:
                self.replay_buffer.extend(tensordict_out)
                # Both fragments must be f-strings, otherwise the write count is
                # logged as the literal "{self.replay_buffer.write_count}".
                # The message is only built when DEBUG is enabled: reading
                # ``write_count`` may not be a cheap attribute access for every
                # buffer implementation.
                if torchrl_logger.isEnabledFor(logging.DEBUG):
                    torchrl_logger.debug(
                        f"Collector: Added {tensordict_out.numel()} frames to replay buffer. "
                        f"Buffer write count: {self.replay_buffer.write_count}. Yielding."
                    )
                yield
            else:
                # we must clone the values, as the tensordict is updated in-place.
                # otherwise the following code may break:
                # >>> for i, data in enumerate(collector):
                # >>>      if i == 0:
                # >>>          data0 = data
                # >>>      elif i == 1:
                # >>>          data1 = data
                # >>>      else:
                # >>>          break
                # >>> assert data0["done"] is not data1["done"]
                yield tensordict_out.clone()
def start(self):
    """Starts the collector in a separate thread for asynchronous data collection.

    The collected data is stored in the provided replay buffer. This method is useful when you want to decouple data
    collection from training, allowing your training loop to run independently of the data collection process.

    The worker thread is a daemon thread, so it does not keep the program
    alive on exit. Calling ``start`` while a collection thread is already
    running has no effect.

    Raises:
        RuntimeError: If no replay buffer is defined during the collector's initialization.

    Example:
        >>> from torchrl.modules import RandomPolicy
        >>>
        >>> import time
        >>> from functools import partial
        >>>
        >>> import tqdm
        >>>
        >>> from torchrl.collectors import Collector
        >>> from torchrl.data import LazyTensorStorage, ReplayBuffer
        >>> from torchrl.envs import GymEnv, set_gym_backend
        >>> import ale_py
        >>>
        >>> # Set the gym backend to gymnasium
        >>> set_gym_backend("gymnasium").set()
        >>>
        >>> if __name__ == "__main__":
        ...     # Create a random policy for the Pong environment
        ...     env = GymEnv("ALE/Pong-v5")
        ...     policy = RandomPolicy(env.action_spec)
        ...
        ...     # Initialize a shared replay buffer
        ...     rb = ReplayBuffer(storage=LazyTensorStorage(1000), shared=True)
        ...
        ...     # Create a synchronous data collector
        ...     collector = Collector(
        ...         env,
        ...         policy=policy,
        ...         replay_buffer=rb,
        ...         frames_per_batch=256,
        ...         total_frames=-1,
        ...     )
        ...
        ...     # Progress bar to track the number of collected frames
        ...     pbar = tqdm.tqdm(total=100_000)
        ...
        ...     # Start the collector asynchronously
        ...     collector.start()
        ...
        ...     # Track the write count of the replay buffer
        ...     prec_wc = 0
        ...     while True:
        ...         wc = rb.write_count
        ...         c = wc - prec_wc
        ...         prec_wc = wc
        ...
        ...         # Update the progress bar
        ...         pbar.update(c)
        ...         pbar.set_description(f"Write Count: {rb.write_count}")
        ...
        ...         # Check the write count every 0.5 seconds
        ...         time.sleep(0.5)
        ...
        ...         # Stop when the desired number of frames is reached
        ...         if rb.write_count >= 100_000:
        ...             break
        ...
        ...     # Shut down the collector
        ...     collector.async_shutdown()
    """
    if self.replay_buffer is None:
        raise RuntimeError("Replay buffer must be defined for execution.")
    if not self.is_running():
        # Reset the stop flag in case a previous run was stopped via async_shutdown().
        self._stop = False
        self._thread = threading.Thread(target=self._run_iterator)
        self._thread.daemon = (
            True  # So that the thread dies when the main program exits
        )
        self._thread.start()
def _run_iterator(self):
    """Thread target: consume the collector until exhausted or asked to stop.

    ``self._stop`` is checked after each yielded item, so a stop request takes
    effect after the current iteration completes.
    """
    for _batch in iter(self):
        if self._stop:
            return

def is_running(self):
    """Whether a background collection thread exists and is still alive."""
    try:
        worker = self._thread
    except AttributeError:
        return False
    return worker.is_alive()
def async_shutdown(
    self, timeout: float | None = None, close_env: bool = True
) -> None:
    """Stops the background collection thread started by :meth:`start` and shuts down the collector.

    Sets the stop flag checked by the collection thread after each batch, and
    waits for the thread to exit (up to ``timeout`` seconds when given).
    Then calls :meth:`shutdown`.

    Args:
        timeout (float, optional): maximum time, in seconds, to wait for the
            collection thread to finish. ``None`` waits indefinitely.
        close_env (bool, optional): forwarded to :meth:`shutdown`. Whether to
            close the environment. Defaults to ``True``.
    """
    self._stop = True
    if hasattr(self, "_thread") and self._thread.is_alive():
        # NOTE: if the join times out, the thread may still be running when
        # shutdown() releases the env below.
        self._thread.join(timeout=timeout)
    self.shutdown(close_env=close_env)
def _postproc(self, tensordict_out):
    """Apply the collector's post-processing pipeline to a collected batch.

    In order: optional trajectory splitting (``split_trajs``), the
    user-provided ``postproc`` callable, and finally, when
    ``_exclude_private_keys`` is set, removal of every key that has a
    component starting with ``"_"``.
    """
    if self.split_trajs:
        tensordict_out = split_trajectories(tensordict_out, prefix="collector")
    if self.postproc is not None:
        tensordict_out = self.postproc(tensordict_out)
    if not self._exclude_private_keys:
        return tensordict_out

    def _is_private(key):
        # Plain string keys and nested (tuple) keys are inspected. Any
        # component starting with "_" makes the key private.
        if isinstance(key, str):
            return key.startswith("_")
        if isinstance(key, tuple):
            return any(part.startswith("_") for part in key)
        return False

    private_keys = list(filter(_is_private, tensordict_out.keys(True)))
    return tensordict_out.exclude(*private_keys, inplace=True)

def _update_traj_ids(self, env_output) -> None:
    """Assign fresh trajectory ids to envs whose trajectory just ended.

    End-of-trajectory flags are aggregated from the done entries under
    ``"next"`` of ``env_output``. Ended envs receive new ids from the
    trajectory pool, written back into the carrier under
    ``("collector", "traj_ids")``.
    """
    # we can't use the reset keys because they're gone
    ended = _aggregate_end_of_traj(
        env_output.get("next"), done_keys=self.env.done_keys
    )
    if not ended.any():
        return
    current_ids = self._carrier.get(("collector", "traj_ids"))
    target_device = self.storing_device
    if target_device is not None:
        current_ids = current_ids.to(target_device)
        ended = ended.to(target_device)
    elif ended.device != current_ids.device:
        ended = ended.to(current_ids.device)
    fresh_ids = self._traj_pool.get_traj_and_increment(
        ended.sum(), device=ended.device
    )
    self._carrier.set(
        ("collector", "traj_ids"), current_ids.masked_scatter(ended, fresh_ids)
    )
@torch.no_grad()
def rollout(self) -> TensorDictBase:
    """Computes a rollout in the environment using the provided policy.

    Steps the environment up to ``frames_per_batch`` times. While fewer than
    ``init_random_frames`` frames have been collected, actions come from
    ``env.rand_action``; afterwards the wrapped policy is queried. The
    carrier tensordict is moved to the policy, env and storing devices as
    configured.

    When a replay buffer is set, not ignored, and ``extend_buffer`` is False,
    each step is written with ``replay_buffer.add`` and nothing is returned.
    Otherwise steps are stacked along a trailing "time" dimension, either
    into the pre-allocated ``_final_rollout`` buffer or into a new
    tensordict. If the interruptor stops collection early, only the steps
    gathered so far are stacked.

    Returns:
        TensorDictBase containing the computed rollout, or ``None`` when the
        data was added directly to the replay buffer.
    """
    if self.reset_at_each_iter:
        self._carrier.update(self.env.reset())

    # self._shuttle.fill_(("collector", "step_count"), 0)
    if self._use_buffers:
        # Mark every slot as "no trajectory" so unfilled steps are identifiable.
        self._final_rollout.fill_(("collector", "traj_ids"), -1)
    else:
        pass
    tensordicts = []
    with set_exploration_type(self.exploration_type):
        for t in range(self.frames_per_batch):
            if (
                self.init_random_frames is not None
                and self._frames < self.init_random_frames
            ):
                self.env.rand_action(self._carrier)
                if (
                    self.policy_device is not None
                    and self.policy_device != self.env_device
                ):
                    # TODO: This may break with exclusive / ragged lazy stacks
                    self._carrier.apply(
                        lambda name, val: val.to(
                            device=self.policy_device, non_blocking=True
                        )
                        if name in self._policy_output_keys
                        else val,
                        out=self._carrier,
                        named=True,
                        nested_keys=True,
                    )
            else:
                if self._cast_to_policy_device:
                    if self.policy_device is not None:
                        # This is unsafe if the shuttle is in pin_memory -- otherwise cuda will be happy with non_blocking
                        non_blocking = (
                            not self.no_cuda_sync
                            or self.policy_device.type == "cuda"
                        )
                        policy_input = self._carrier.to(
                            self.policy_device,
                            non_blocking=non_blocking,
                        )
                        if not self.no_cuda_sync:
                            self._sync_policy()
                    elif self.policy_device is None:
                        # we know the tensordict has a device otherwise we would not be here
                        # we can pass this, clear_device_ must have been called earlier
                        # policy_input = self._shuttle.clear_device_()
                        policy_input = self._carrier
                else:
                    policy_input = self._carrier
                # we still do the assignment for security
                if self.compiled_policy:
                    cudagraph_mark_step_begin()
                policy_output = self._wrapped_policy(policy_input)
                if self.compiled_policy:
                    # Detach from buffers owned by the compiled/cudagraph policy.
                    policy_output = policy_output.clone()
                if self._carrier is not policy_output:
                    # ad-hoc update shuttle
                    self._carrier.update(
                        policy_output, keys_to_update=self._policy_output_keys
                    )

            if self._cast_to_env_device:
                if self.env_device is not None:
                    non_blocking = (
                        not self.no_cuda_sync or self.env_device.type == "cuda"
                    )
                    env_input = self._carrier.to(
                        self.env_device, non_blocking=non_blocking
                    )
                    if not self.no_cuda_sync:
                        self._sync_env()
                elif self.env_device is None:
                    # we know the tensordict has a device otherwise we would not be here
                    # we can pass this, clear_device_ must have been called earlier
                    # env_input = self._shuttle.clear_device_()
                    env_input = self._carrier
            else:
                env_input = self._carrier
            env_output, env_next_output = self.env.step_and_maybe_reset(env_input)

            if self._carrier is not env_output:
                # ad-hoc update shuttle
                next_data = env_output.get("next")
                if self._shuttle_has_no_device:
                    # Keep the carrier device-less by also clearing the device of "next".
                    next_data.clear_device_()
                self._carrier.set("next", next_data)

            if (
                self.replay_buffer is not None
                and not self._ignore_rb
                and not self.extend_buffer
            ):
                torchrl_logger.debug(
                    f"Collector: Adding {env_output.numel()} frames to replay buffer using add()."
                )
                self.replay_buffer.add(self._carrier)
                # Frames are counted here because iterator() receives no tensordict.
                if self._increment_frames(self._carrier.numel()):
                    return
            else:
                if self.storing_device is not None:
                    torchrl_logger.debug(
                        f"Collector: Moving to {self.storing_device} and adding to queue."
                    )
                    non_blocking = (
                        not self.no_cuda_sync or self.storing_device.type == "cuda"
                    )
                    tensordicts.append(
                        self._carrier.to(
                            self.storing_device, non_blocking=non_blocking
                        )
                    )
                    if not self.no_cuda_sync:
                        self._sync_storage()
                else:
                    tensordicts.append(self._carrier)

            # carry over collector data without messing up devices
            collector_data = self._carrier.get("collector").copy()
            self._carrier = env_next_output
            if self._shuttle_has_no_device:
                self._carrier.clear_device_()
            self._carrier.set("collector", collector_data)
            self._update_traj_ids(env_output)

            if (
                self.interruptor is not None
                and self.interruptor.collection_stopped()
            ):
                torchrl_logger.debug("Collector: Interruptor stopped.")
                if (
                    self.replay_buffer is not None
                    and not self._ignore_rb
                    and not self.extend_buffer
                ):
                    return
                if self._use_buffers:
                    # ``_final_rollout`` only exists when buffers are used, so it
                    # must not be read on the no-buffer path below.
                    result = self._final_rollout
                    try:
                        torch.stack(
                            tensordicts,
                            self._final_rollout.ndim - 1,
                            out=self._final_rollout[..., : t + 1],
                        )
                    except RuntimeError:
                        # Retry with the buffer unlocked (presumably it was locked).
                        with self._final_rollout.unlock_():
                            torch.stack(
                                tensordicts,
                                self._final_rollout.ndim - 1,
                                out=self._final_rollout[..., : t + 1],
                            )
                else:
                    result = TensorDict.maybe_dense_stack(tensordicts, dim=-1)
                break
        else:
            # The loop ran all frames_per_batch steps without interruption.
            if self._use_buffers:
                torchrl_logger.debug("Returning final rollout within buffer.")
                result = self._final_rollout
                try:
                    result = torch.stack(
                        tensordicts,
                        self._final_rollout.ndim - 1,
                        out=self._final_rollout,
                    )

                except RuntimeError:
                    with self._final_rollout.unlock_():
                        result = torch.stack(
                            tensordicts,
                            self._final_rollout.ndim - 1,
                            out=self._final_rollout,
                        )
            elif (
                self.replay_buffer is not None
                and not self._ignore_rb
                and not self.extend_buffer
            ):
                return
            else:
                torchrl_logger.debug(
                    "Returning final rollout with NO buffer (maybe_dense_stack)."
                )
                result = TensorDict.maybe_dense_stack(tensordicts, dim=-1)
                result.refine_names(..., "time")

    return self._maybe_set_truncated(result)
def _maybe_set_truncated(self, final_rollout):
    """Mark the last time step of every configured truncated key as truncated.

    For each key in ``self._truncated_keys``, the ``("next", key)`` entry is
    set to True at the last index of the final (time) dimension. The sibling
    ``done`` entry is OR-ed with it. ``final_rollout`` is modified and
    returned.
    """
    # Full slices over all leading dims, then the last element along the final dim.
    last_step = (*([slice(None)] * (final_rollout.ndim - 1)), -1)
    for truncated_key in self._truncated_keys:
        done_key = _replace_last(truncated_key, "done")
        truncated = final_rollout["next", truncated_key]
        truncated[last_step] = True
        final_rollout["next", truncated_key] = truncated
        done = final_rollout["next", done_key]
        final_rollout["next", done_key] = done | truncated
    return final_rollout
@torch.no_grad()
def reset(self, index=None, **kwargs) -> None:
    """Resets the environments to a new initial state.

    The ``"collector"`` metadata stored on the carrier is saved and put back
    after the reset. Its trajectory ids are shifted so that the smallest one
    becomes 0.

    Args:
        index (optional): when given, boolean ``_reset`` masks with ``True``
            at ``index`` are written to the carrier for every reset key.
            When ``None``, the carrier is zeroed instead.
        **kwargs: forwarded to ``self.env.reset``.

    Raises:
        RuntimeError: if ``index`` is given but the env batch size has no
            elements (``prod(batch_size) == 0``).
    """
    # metadata
    collector_metadata = self._carrier.get("collector").clone()
    if index is not None:
        # check that the env supports partial reset
        if prod(self.env.batch_size) == 0:
            raise RuntimeError("resetting unique env with index is not permitted.")
        for reset_key, done_keys in zip(
            self.env.reset_keys, self.env.done_keys_groups
        ):
            # One boolean mask per reset key, shaped like the group's first done entry.
            _reset = torch.zeros(
                self.env.full_done_spec[done_keys[0]].shape,
                dtype=torch.bool,
                device=self.env.device,
            )
            _reset[index] = 1
            self._carrier.set(reset_key, _reset)
    else:
        _reset = None
        self._carrier.zero_()

    # NOTE(review): ``env.reset`` is called with ``**kwargs`` only. The
    # ``_reset`` masks written to the carrier above are not passed to it, so
    # it is unclear whether a partial reset actually happens. Confirm intent.
    self._carrier.update(self.env.reset(**kwargs), inplace=True)
    collector_metadata["traj_ids"] = (
        collector_metadata["traj_ids"] - collector_metadata["traj_ids"].min()
    )
    self._carrier["collector"] = collector_metadata
def shutdown(
    self,
    timeout: float | None = None,
    close_env: bool = True,
    raise_on_error: bool = True,
) -> None:
    """Shuts down all workers and/or closes the local environment.

    Only the first call has an effect: ``self.closed`` is set before any
    resource is released, so later calls are no-ops.

    Args:
        timeout (float, optional): The timeout for closing pipes between workers.
            No effect for this class.
        close_env (bool, optional): Whether to close the environment. Defaults to `True`.
        raise_on_error (bool, optional): Whether to raise an error if the shutdown fails. Defaults to `True`.
    """
    try:
        if not self.closed:
            self.closed = True
            # These buffers may be missing (e.g. if construction failed part
            # way). Since ``closed`` is already True there is no second
            # chance, so their absence must not prevent the env from being
            # closed below.
            if hasattr(self, "_carrier"):
                del self._carrier
            if self._use_buffers and hasattr(self, "_final_rollout"):
                del self._final_rollout
            if close_env and not self.env.is_closed:
                self.env.close(raise_if_closed=raise_on_error)
            del self.env
        return
    except Exception:
        if raise_on_error:
            raise
def __del__(self):
    """Best-effort cleanup when the collector is garbage-collected."""
    try:
        self.shutdown()
    except Exception:
        # an AttributeError will typically be raised if the collector is deleted when the program ends.
        # In the future, insignificant changes to the close method may change the error type.
        # We explicitly assume that any error raised during closure in
        # __del__ will not affect the program.
        pass
def state_dict(self) -> OrderedDict:
    """Returns the local state_dict of the data collector (environment and policy).

    The env entry holds the transform state for a ``TransformedEnv``, the
    env's own state for a batched env, and is empty otherwise. The policy
    entry is omitted when the collector has no policy module (the
    ``_policy_w_state_dict`` attribute is missing or ``None``).

    Returns:
        an ordered dictionary with fields :obj:`"policy_state_dict"` and
        `"env_state_dict"`, plus the ``"frames"`` and ``"iter"`` counters.

    """
    # Local import, presumably to avoid an import cycle at module load time.
    from torchrl.envs.batched_envs import BatchedEnvBase

    if isinstance(self.env, TransformedEnv):
        env_state_dict = self.env.transform.state_dict()
    elif isinstance(self.env, BatchedEnvBase):
        env_state_dict = self.env.state_dict()
    else:
        env_state_dict = OrderedDict()

    # Treat ``None`` like a missing attribute, consistently with
    # ``_maybe_fallback_update``: calling ``.state_dict()`` on it would fail.
    policy_module = getattr(self, "_policy_w_state_dict", None)
    if policy_module is not None:
        state_dict = OrderedDict(
            policy_state_dict=policy_module.state_dict(),
            env_state_dict=env_state_dict,
        )
    else:
        state_dict = OrderedDict(env_state_dict=env_state_dict)

    state_dict.update({"frames": self._frames, "iter": self._iter})

    return state_dict
def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None:
    """Loads a state_dict on the environment and policy.

    Args:
        state_dict (OrderedDict): ordered dictionary containing the fields
            `"policy_state_dict"` and :obj:`"env_state_dict"`, and the
            ``"frames"`` / ``"iter"`` counters.

    Keyword Args:
        strict (bool, optional): when ``True`` (default) every entry is
            required. When ``False``, missing entries are skipped and the
            matching attribute is left unchanged. All keyword arguments
            (including ``strict``) are forwarded to the underlying
            ``load_state_dict`` calls.

    Raises:
        ValueError: if a policy state must be loaded but the collector has
            no policy module to load it into.
        KeyError: in strict mode, if a required entry is missing.
    """
    strict = kwargs.get("strict", True)
    if strict or "env_state_dict" in state_dict:
        self.env.load_state_dict(state_dict["env_state_dict"], **kwargs)
    if strict or "policy_state_dict" in state_dict:
        # ``None`` is treated like a missing module (see ``state_dict``).
        policy_module = getattr(self, "_policy_w_state_dict", None)
        if policy_module is None:
            raise ValueError(
                "Underlying policy does not have state_dict to load policy_state_dict into."
            )
        policy_module.load_state_dict(state_dict["policy_state_dict"], **kwargs)
    # Counters are mandatory in strict mode and best-effort otherwise,
    # mirroring how the env/policy entries are handled above.
    if strict or "frames" in state_dict:
        self._frames = state_dict["frames"]
    if strict or "iter" in state_dict:
        self._iter = state_dict["iter"]
def __repr__(self) -> str:
    """Multi-line summary of env, policy, (truncated) buffer and exploration type.

    Falls back to ``"<ClassName>(not_init)"`` if any attribute lookup or
    formatting fails, e.g. on a partially initialised collector.
    """
    try:
        pad = " " * 4
        env_line = indent(f"env={self.env}", pad)
        policy_line = indent(f"policy={self._wrapped_policy}", pad)
        buffer_repr = repr(getattr(self, "_final_rollout", None))
        # Keep the buffer summary short: it can be a very large repr.
        if len(buffer_repr) > 50:
            buffer_repr = f"{buffer_repr[:50]}..."
        buffer_line = indent(f"td_out={buffer_repr}", pad)
        fields = "".join(
            f"\n{line}," for line in (env_line, policy_line, buffer_line)
        )
        return (
            f"{self.__class__.__name__}({fields}"
            f"\nexploration={self.exploration_type})"
        )
    except Exception:
        return f"{type(self).__name__}(not_init)"
def increment_version(self):
    """Increment the policy version.

    Does nothing when no version tracker is configured.

    Raises:
        RuntimeError: if the tracker has no ``increment_version`` method.
    """
    tracker = self.policy_version_tracker
    if tracker is None:
        return
    if not hasattr(tracker, "increment_version"):
        raise RuntimeError(
            "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector."
        )
    tracker.increment_version()
@property
def policy_version(self) -> str | int | None:
    """The current policy version.

    ``None`` when the tracker (possibly ``None`` itself) exposes no
    ``version`` attribute.
    """
    tracker = self.policy_version_tracker
    try:
        return tracker.version
    except AttributeError:
        return None
def get_policy_version(self) -> str | int | None:
    """Get the current policy version.

    Method form of the ``policy_version`` property, for remote callers (e.g.
    Ray actors) that cannot access properties through RPC.

    Returns:
        The current version number (int) or UUID (str), or None if version tracking is disabled.
    """
    current_version = self.policy_version
    return current_version
def getattr_policy(self, attr):
    """Get an attribute from the policy.

    Looks ``attr`` up on the wrapped policy. Useful for remote callers.
    """
    wrapped = self._wrapped_policy
    return getattr(wrapped, attr)
def getattr_env(self, attr):
    """Get an attribute from the environment.

    Looks ``attr`` up on the collector's env. Useful for remote callers.
    """
    environment = self.env
    return getattr(environment, attr)
def getattr_rb(self, attr):
    """Get an attribute from the replay buffer.

    Looks ``attr`` up on the collector's replay buffer. Useful for remote
    callers.
    """
    buffer = self.replay_buffer
    return getattr(buffer, attr)
def get_model(self, model_id: str):
    """Get model instance by ID (for weight sync schemes).

    Args:
        model_id: Model identifier (e.g., "policy", "value_net")

    Returns:
        The model instance

    Raises:
        ValueError: If model_id is not recognized
    """
    if model_id != "policy":
        return _resolve_model(self, model_id)
    # The unwrapped policy is returned on purpose: its parameters have the
    # same structure as the weights extracted in the main process, which
    # avoids key mismatches when the policy was auto-wrapped
    # (e.g. WrappablePolicy -> TensorDictModule).
    unwrapped_policy = getattr(self, "policy", None)
    if unwrapped_policy is None:
        raise ValueError(f"No policy found for model_id '{model_id}'")
    return unwrapped_policy
def _receive_weights_scheme(self):
    """Explicit pass-through to the base-class ``_receive_weights_scheme``."""
    # NOTE(review): presumably redefined here (like ``next``) so the method is
    # exposed for remote/RPC invocation on this class. Confirm.
    return super()._receive_weights_scheme()
class SyncDataCollector(Collector, metaclass=_LegacyCollectorMeta):
    """Deprecated version of :class:`~torchrl.collectors.Collector`.

    Kept for backward compatibility. It defines no members of its own and
    inherits everything from :class:`~torchrl.collectors.Collector`. New code
    should use :class:`~torchrl.collectors.Collector` directly.
    """

    # NOTE(review): any deprecation warning is presumably emitted by
    # ``_LegacyCollectorMeta``. Confirm in torchrl.collectors._base.
    ...

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources