
Source code for torchrl.collectors._single_async

from __future__ import annotations

from collections import OrderedDict
from collections.abc import Callable, Sequence
from typing import Any

from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule

from torchrl._utils import accept_remote_rref_udf_invocation
from torchrl.collectors._base import _make_legacy_metaclass
from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE, ExplorationType
from torchrl.collectors._multi_async import MultiAsyncCollector
from torchrl.collectors._multi_base import _MultiCollectorMeta
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs import EnvBase


@accept_remote_rref_udf_invocation
class AsyncCollector(MultiAsyncCollector):
    """Runs a single DataCollector on a separate process.

    This is mostly useful for offline RL paradigms where the policy being
    trained can differ from the policy used to collect data. In online
    settings, a regular DataCollector should be preferred. This class is
    merely a wrapper around a MultiAsyncCollector where a single process is
    being created.

    Args:
        create_env_fn (Callable): Callable returning an instance of EnvBase.
        policy (Callable): Policy to be executed in the environment.
            Must accept a :class:`tensordict.tensordict.TensorDictBase` object as input.
            If ``None`` is provided, the policy used will be a
            :class:`~torchrl.collectors.RandomPolicy` instance with the environment
            ``action_spec``.
            Accepted policies are usually subclasses of
            :class:`~tensordict.nn.TensorDictModuleBase`. This is the recommended
            usage of the collector.
            Other callables are accepted too:
            if the policy is not a ``TensorDictModuleBase`` (e.g., a regular
            :class:`~torch.nn.Module` instance), it will be wrapped in an
            `nn.Module` first. Then, the collector will try to assess if these
            modules require wrapping in a :class:`~tensordict.nn.TensorDictModule`
            or not.

            - If the policy forward signature matches any of ``forward(self, tensordict)``,
              ``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
              any typing with a single argument typed as a subclass of ``TensorDictBase``),
              then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.

            - In all other cases, the policy will be wrapped as follows:
              ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

            .. note:: If the policy needs to be passed as a policy factory (e.g., in case
              it mustn't be serialized / pickled directly), the ``policy_factory`` argument
              should be used instead.

    Keyword Args:
        policy_factory (Callable[[], Callable], optional): a callable that returns
            a policy instance. This is exclusive with the ``policy`` argument.

            .. note:: ``policy_factory`` comes in handy whenever the policy cannot
              be serialized.

        frames_per_batch (int): A keyword-only argument representing the total
            number of elements in a batch.
        total_frames (int, optional): A keyword-only argument representing the
            total number of frames returned by the collector during its lifespan.
            If ``total_frames`` is not divisible by ``frames_per_batch``, an
            exception is raised.
            Endless collectors can be created by passing ``total_frames=-1``.
            Defaults to ``-1`` (never-ending collector).
        device (int, str or torch.device, optional): The generic device of the
            collector. The ``device`` argument fills any non-specified device: if
            ``device`` is not ``None`` and any of ``storing_device``,
            ``policy_device`` or ``env_device`` is not specified, its value will
            be set to ``device``. Defaults to ``None`` (no default device).
            Supports a list of devices if one wishes to indicate a different
            device for each worker. The list must be as long as the number of
            workers.
        storing_device (int, str or torch.device, optional): The device on which
            the output :class:`~tensordict.TensorDict` will be stored.
            If ``device`` is passed and ``storing_device`` is ``None``, it will
            default to the value indicated by ``device``.
            For long trajectories, it may be necessary to store the data on a
            different device than the one where the policy and env are executed.
            Defaults to ``None`` (the output tensordict isn't on a specific
            device, leaf tensors sit on the device where they were created).
            Supports a list of devices if one wishes to indicate a different
            device for each worker. The list must be as long as the number of
            workers.
        env_device (int, str or torch.device, optional): The device on which the
            environment should be cast (or executed if that functionality is
            supported). If not specified and the env has a non-``None`` device,
            ``env_device`` will default to that value. If ``device`` is passed
            and ``env_device=None``, it will default to ``device``. If the
            specified value of ``env_device`` differs from ``policy_device`` and
            one of them is not ``None``, the data will be cast to ``env_device``
            before being passed to the env (i.e., passing different devices to
            policy and env is supported). Defaults to ``None``.
            Supports a list of devices if one wishes to indicate a different
            device for each worker. The list must be as long as the number of
            workers.
        policy_device (int, str or torch.device, optional): The device on which
            the policy should be cast. If ``device`` is passed and
            ``policy_device=None``, it will default to ``device``. If the
            specified value of ``policy_device`` differs from ``env_device`` and
            one of them is not ``None``, the data will be cast to
            ``policy_device`` before being passed to the policy (i.e., passing
            different devices to policy and env is supported).
            Defaults to ``None``.
            Supports a list of devices if one wishes to indicate a different
            device for each worker. The list must be as long as the number of
            workers.
        create_env_kwargs (dict, optional): A dictionary with the keyword
            arguments used to create an environment. If a list is provided, each
            of its elements will be assigned to a sub-collector.
        max_frames_per_traj (int, optional): Maximum steps per trajectory.
            Note that a trajectory can span across multiple batches (unless
            ``reset_at_each_iter`` is set to ``True``, see below).
            Once a trajectory reaches ``max_frames_per_traj``, the environment is
            reset. If the environment wraps multiple environments together, the
            number of steps is tracked for each environment independently.
            Negative values are allowed, in which case this argument is ignored.
            Defaults to ``None`` (i.e., no maximum number of steps).
        init_random_frames (int, optional): Number of frames for which the policy
            is ignored before it is called. This feature is mainly intended to be
            used in offline/model-based settings, where a batch of random
            trajectories can be used to initialize training.
            If provided, it will be rounded up to the closest multiple of
            ``frames_per_batch``. Defaults to ``None`` (i.e., no random frames).
        reset_at_each_iter (bool, optional): Whether environments should be reset
            at the beginning of a batch collection. Defaults to ``False``.
        postproc (Callable, optional): A post-processing transform, such as a
            :class:`~torchrl.envs.Transform` or a
            :class:`~torchrl.data.postprocs.MultiStep` instance.
            Defaults to ``None``.
        split_trajs (bool, optional): Boolean indicating whether the resulting
            TensorDict should be split according to the trajectories.
            See :func:`~torchrl.collectors.utils.split_trajectories` for more
            information. Defaults to ``False``.
        exploration_type (ExplorationType, optional): interaction mode to be used
            when collecting data. Must be one of
            ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
            ``torchrl.envs.utils.ExplorationType.RANDOM``,
            ``torchrl.envs.utils.ExplorationType.MODE`` or
            ``torchrl.envs.utils.ExplorationType.MEAN``.
        reset_when_done (bool, optional): if ``True`` (default), an environment
            that returns a ``True`` value in its ``"done"`` or ``"truncated"``
            entry will be reset at the corresponding indices.
        update_at_each_batch (bool, optional): if ``True``,
            :meth:`update_policy_weights_()` will be called before (sync) or
            after (async) each data collection. Defaults to ``False``.
        preemptive_threshold (:obj:`float`, optional): a value between 0.0 and
            1.0 that specifies the ratio of workers that will be allowed to
            finish collecting their rollout before the rest are forced to end
            early.
        num_threads (int, optional): number of threads for this process.
            Defaults to the number of workers.
        num_sub_threads (int, optional): number of threads of the subprocesses.
            Should be equal to one plus the number of processes launched within
            each subprocess (or one if a single process is launched).
            Defaults to 1 for safety: if none is indicated, launching multiple
            workers may overload the CPU and harm performance.
        set_truncated (bool, optional): if ``True``, the truncated signals (and
            corresponding ``"done"`` but not ``"terminated"``) will be set to
            ``True`` when the last frame of a rollout is reached. If no
            ``"truncated"`` key is found, an exception is raised. Truncated keys
            can be set through ``env.add_truncated_keys``.
            Defaults to ``False``.
        track_policy_version (bool or PolicyVersion, optional): if ``True``, the
            collector will track the version of the policy. This will be mediated
            by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion`
            transform, which will be added to the environment. Alternatively, a
            :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion`
            instance can be passed, which will be used to track the policy
            version. Defaults to ``False``.
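
    Examples:
        A minimal usage sketch; the environment name and frame counts are
        illustrative choices and assume a Gym backend is installed:

        >>> from torchrl.collectors import AsyncCollector
        >>> from torchrl.envs.libs.gym import GymEnv
        >>> collector = AsyncCollector(
        ...     create_env_fn=lambda: GymEnv("Pendulum-v1"),
        ...     policy=None,  # falls back to a random policy over action_spec
        ...     frames_per_batch=200,
        ...     total_frames=2000,
        ... )
        >>> for data in collector:
        ...     # each ``data`` is a TensorDict holding a batch of transitions
        ...     pass
        >>> collector.shutdown()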
""" def __init__( self, create_env_fn: Callable[[], EnvBase], policy: None | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, *, policy_factory: Callable[[], Callable] | None = None, frames_per_batch: int, total_frames: int | None = -1, device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, create_env_kwargs: Sequence[dict[str, Any]] | None = None, max_frames_per_traj: int | None = None, init_random_frames: int | None = None, reset_at_each_iter: bool = False, postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, split_trajs: bool | None = None, exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, reset_when_done: bool = True, update_at_each_batch: bool = False, preemptive_threshold: float | None = None, num_threads: int | None = None, num_sub_threads: int = 1, set_truncated: bool = False, track_policy_version: bool = False, **kwargs, ): super().__init__( create_env_fn=[create_env_fn], policy=policy, policy_factory=policy_factory, total_frames=total_frames, create_env_kwargs=[create_env_kwargs] if create_env_kwargs else create_env_kwargs, max_frames_per_traj=max_frames_per_traj, frames_per_batch=frames_per_batch, reset_at_each_iter=reset_at_each_iter, init_random_frames=init_random_frames, postproc=postproc, split_trajs=split_trajs, device=device, policy_device=policy_device, env_device=env_device, storing_device=storing_device, exploration_type=exploration_type, reset_when_done=reset_when_done, update_at_each_batch=update_at_each_batch, preemptive_threshold=preemptive_threshold, num_threads=num_threads, num_sub_threads=num_sub_threads, set_truncated=set_truncated, track_policy_version=track_policy_version, **kwargs, ) # for RPC def next(self): return super().next() # for RPC
    # for RPC
    def shutdown(
        self,
        timeout: float | None = None,
        close_env: bool = True,
        raise_on_error: bool = True,
    ) -> None:
        return super().shutdown(
            timeout=timeout, close_env=close_env, raise_on_error=raise_on_error
        )

    # for RPC
    def set_seed(self, seed: int, static_seed: bool = False) -> int:
        return super().set_seed(seed, static_seed)

    # for RPC
    def state_dict(self) -> OrderedDict:
        return super().state_dict()

    # for RPC
    def load_state_dict(self, state_dict: OrderedDict) -> None:
        return super().load_state_dict(state_dict)

_LegacyAsyncCollectorMeta = _make_legacy_metaclass(_MultiCollectorMeta)


class aSyncDataCollector(AsyncCollector, metaclass=_LegacyAsyncCollectorMeta):
    """Deprecated version of :class:`~torchrl.collectors.AsyncCollector`."""

    ...
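
The docstring above describes how non-``TensorDictModuleBase`` policies are wrapped and when ``policy_factory`` should be preferred. The sketch below (not part of the module source) illustrates passing an explicit :class:`~tensordict.nn.TensorDictModule` policy; the key names, layer sizes and environment are assumptions chosen for illustration, not requirements of this module.

import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import AsyncCollector
from torchrl.envs.libs.gym import GymEnv

# An explicit TensorDictModule policy: reads "observation", writes "action".
# Pendulum-v1 observations are 3-dimensional and actions 1-dimensional.
policy = TensorDictModule(
    nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
)

collector = AsyncCollector(
    create_env_fn=lambda: GymEnv("Pendulum-v1"),
    policy=policy,
    # If the policy could not be pickled, one would instead pass
    # ``policy_factory=make_policy`` (a zero-argument callable building the
    # policy on the worker) and omit ``policy``; the two are exclusive.
    frames_per_batch=200,
    total_frames=2000,
)
for data in collector:
    pass  # each ``data`` is a TensorDict of collected transitions
collector.shutdown()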
