Source code for torchrl.collectors._multi_async

from __future__ import annotations

import time
import warnings
from collections import defaultdict, OrderedDict
from collections.abc import Iterator, Sequence
from copy import deepcopy
from queue import Empty

import torch

from tensordict import TensorDictBase
from tensordict.nn import TensorDictModuleBase
from torchrl._utils import _check_for_faulty_process, accept_remote_rref_udf_invocation
from torchrl.collectors._base import _make_legacy_metaclass
from torchrl.collectors._constants import _MAX_IDLE_COUNT, _TIMEOUT
from torchrl.collectors._multi_base import _MultiCollectorMeta, MultiCollector
from torchrl.collectors.utils import split_trajectories


@accept_remote_rref_udf_invocation
class MultiAsyncCollector(MultiCollector):
    """Runs a given number of DataCollectors on separate processes asynchronously.

    .. aafig::

        +----------------------------------------------------------------------+
        |             "MultiConcurrentCollector"              |                |
        |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~|                |
        |  "Collector 1"  |  "Collector 2"  |  "Collector 3"  |     "Main"     |
        |~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~|
        | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" |                |
        |~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~|
        |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" |                |
        |        |        |        |        |        |        |                |
        |     "actor"     |        |        |     "actor"     |                |
        |                 |        |        |                 |                |
        | "step" | "step" |     "actor"     |        |        |                |
        |        |        |                 |        |        |                |
        |        |        |                 | "step" | "step" |                |
        |        |        |                 |        |        |                |
        |     "actor"     | "step" | "step" |     "actor"     |                |
        |                 |        |        |                 |                |
        | "yield batch 1" |     "actor"     |                 |"collect, train"|
        |                 |                 |                 |                |
        | "step" | "step" |                 | "yield batch 2" |"collect, train"|
        |        |        |                 |                 |                |
        |        |        | "yield batch 3" |                 |"collect, train"|
        +----------------------------------------------------------------------+

    Environment types can be identical or different.

    The collection keeps on occurring on all processes even between the time the
    batch of rollouts is collected and the next call to the iterator.
    This class can be safely used with offline RL sota-implementations.

    .. note::
        Python requires multiprocessed code to be instantiated within a main guard:

            >>> from torchrl.collectors import MultiAsyncCollector
            >>> if __name__ == "__main__":
            ...     # Create your collector here

        See https://docs.python.org/3/library/multiprocessing.html for more info.

    Examples:
        >>> from torchrl.envs.libs.gym import GymEnv
        >>> from tensordict.nn import TensorDictModule
        >>> from torch import nn
        >>> from torchrl.collectors import MultiAsyncCollector
        >>> if __name__ == "__main__":
        ...     env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
        ...     policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
        ...     collector = MultiAsyncCollector(
        ...         create_env_fn=[env_maker, env_maker],
        ...         policy=policy,
        ...         total_frames=2000,
        ...         max_frames_per_traj=50,
        ...         frames_per_batch=200,
        ...         init_random_frames=-1,
        ...         reset_at_each_iter=False,
        ...         device="cpu",
        ...         storing_device="cpu",
        ...         cat_results="stack",
        ...     )
        ...     for i, data in enumerate(collector):
        ...         if i == 2:
        ...             print(data)
        ...             break
        ...     collector.shutdown()
        ...     del collector
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                collector: TensorDict(
                    fields={
                        traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([200]),
                    device=cpu,
                    is_shared=False),
                done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                        reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                        truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([200]),
                    device=cpu,
                    is_shared=False),
                observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False)
    """

    __doc__ += MultiCollector.__doc__

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.out_tensordicts = defaultdict(lambda: None)
        self.running = False

        if self.postprocs is not None and self.replay_buffer is None:
            postproc = self.postprocs
            self.postprocs = {}
            for _device in self.storing_device:
                if _device not in self.postprocs:
                    if hasattr(postproc, "to"):
                        postproc = deepcopy(postproc).to(_device)
                    self.postprocs[_device] = postproc

    # for RPC
    def next(self):
        return super().next()
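Note how ``__init__`` keeps one copy of the post-processing transform per storing device: each worker's batch is then post-processed on the device it arrives on, and a plain callable (which has no ``to`` method) is simply shared across devices. A minimal sketch of what this means for users, assuming the classic ``postproc=`` constructor argument; ``reward_bonus`` is a hypothetical user-defined transform, not part of torchrl:

    from tensordict import TensorDictBase

    def reward_bonus(td: TensorDictBase) -> TensorDictBase:
        # Called on each collected batch; `td` already lives on the batch's
        # storing device, so no manual device handling is needed here.
        td["next", "reward"] = td["next", "reward"] + 0.1
        return td

    # collector = MultiAsyncCollector(..., postproc=reward_bonus, storing_device="cpu")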
    # for RPC
    def shutdown(
        self,
        timeout: float | None = None,
        close_env: bool = True,
        raise_on_error: bool = True,
    ) -> None:
        if hasattr(self, "out_tensordicts"):
            del self.out_tensordicts
        if not close_env:
            raise RuntimeError(
                f"Cannot shutdown {type(self).__name__} collector without environment being closed."
            )
        return super().shutdown(timeout=timeout, raise_on_error=raise_on_error)
    # for RPC
    def set_seed(self, seed: int, static_seed: bool = False) -> int:
        return super().set_seed(seed, static_seed)
    # for RPC
    def state_dict(self) -> OrderedDict:
        return super().state_dict()
    # for RPC
    def load_state_dict(self, state_dict: OrderedDict) -> None:
        return super().load_state_dict(state_dict)
    # for RPC
    def update_policy_weights_(
        self,
        policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
        *,
        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
        **kwargs,
    ) -> None:
        if "policy_weights" in kwargs:
            warnings.warn(
                "`policy_weights` is deprecated. Use `policy_or_weights` instead.",
                DeprecationWarning,
            )
            policy_or_weights = kwargs.pop("policy_weights")
        super().update_policy_weights_(
            policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
        )
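Because workers collect continuously, a weight update only reaches them when pushed explicitly (or when the collector was built with ``update_at_each_batch=True``, in which case ``iterator`` below does it for you). A hedged training-loop sketch; ``collector``, ``replay_buffer``, ``loss_fn`` and ``optim`` are assumed to be defined elsewhere:

    for data in collector:
        replay_buffer.extend(data)
        loss = loss_fn(replay_buffer.sample())
        loss.backward()
        optim.step()
        optim.zero_grad()
        # Without this call, async workers keep acting with stale policy weights.
        collector.update_policy_weights_()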
    def frames_per_batch_worker(self, *, worker_idx: int | None = None) -> int:
        return self.requested_frames_per_batch

    def _get_from_queue(self, timeout=None) -> tuple[int, int, TensorDictBase]:
        new_data, j = self.queue_out.get(timeout=timeout)
        use_buffers = self._use_buffers
        if self.replay_buffer is not None:
            idx = new_data
        elif j == 0 or not use_buffers:
            try:
                data, idx = new_data
                self.out_tensordicts[idx] = data
                if use_buffers is None and j > 0:
                    use_buffers = self._use_buffers = False
            except TypeError:
                if use_buffers is None:
                    use_buffers = self._use_buffers = True
                    idx = new_data
                else:
                    raise
        else:
            idx = new_data
        out = self.out_tensordicts[idx]
        if not self.replay_buffer and (j == 0 or use_buffers):
            # we clone the data to make sure that we'll be working with a fixed copy
            out = out.clone()
        return idx, j, out

    @property
    def _queue_len(self) -> int:
        return 1

    def iterator(self) -> Iterator[TensorDictBase]:
        if self.update_at_each_batch:
            self.update_policy_weights_()

        for i in range(self.num_workers):
            if self.init_random_frames is not None and self.init_random_frames > 0:
                self.pipes[i].send((None, "continue_random"))
            else:
                self.pipes[i].send((None, "continue"))
        self.running = True

        workers_frames = [0 for _ in range(self.num_workers)]
        while self._frames < self.total_frames:
            self._iter += 1
            counter = 0
            while True:
                try:
                    idx, j, out = self._get_from_queue(timeout=_TIMEOUT)
                    break
                except (TimeoutError, Empty):
                    counter += _TIMEOUT
                    _check_for_faulty_process(self.procs)
                    if counter > (_TIMEOUT * _MAX_IDLE_COUNT):
                        raise RuntimeError(
                            f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT} seconds. "
                            f"Increase the MAX_IDLE_COUNT environment variable to bypass this error."
                        )
            if self.replay_buffer is None:
                worker_frames = out.numel()
                if self.split_trajs:
                    out = split_trajectories(out, prefix="collector")
            else:
                worker_frames = self.frames_per_batch_worker()
            self._frames += worker_frames
            workers_frames[idx] = workers_frames[idx] + worker_frames
            if out is not None and self.postprocs:
                out = self.postprocs[out.device](out)

            # the function blocks here until the next item is asked, hence we send the message to the
            # worker to keep on working in the meantime before the yield statement
            if (
                self.init_random_frames is not None
                and self._frames < self.init_random_frames
            ):
                msg = "continue_random"
            else:
                msg = "continue"
            self.pipes[idx].send((idx, msg))
            if out is not None and self._exclude_private_keys:
                excluded_keys = [key for key in out.keys() if key.startswith("_")]
                out = out.exclude(*excluded_keys)
            yield out

        # We don't want to shutdown yet, the user may want to call state_dict before
        # self._shutdown_main()
        self.running = False

    def _shutdown_main(self, *args, **kwargs) -> None:
        if hasattr(self, "out_tensordicts"):
            del self.out_tensordicts
        return super()._shutdown_main(*args, **kwargs)
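``iterator`` is first-come-first-served: it pops whichever worker batch reaches ``queue_out`` first, immediately sends that worker a ``"continue"`` message, and only then yields, so collection overlaps with whatever the consumer does with the batch. The following self-contained sketch mirrors that producer/consumer protocol outside of torchrl (it is an illustration, not torchrl code):

    import multiprocessing as mp

    def worker(idx, queue, pipe):
        # Push a result, then block until the main process asks for more.
        for j in range(3):
            queue.put((idx, f"worker {idx}, batch {j}"))
            if pipe.recv() != "continue":
                break

    if __name__ == "__main__":
        queue = mp.Queue()
        pipes, procs = [], []
        for idx in range(2):
            parent_end, child_end = mp.Pipe()
            proc = mp.Process(target=worker, args=(idx, queue, child_end))
            proc.start()
            pipes.append(parent_end)
            procs.append(proc)
        for _ in range(6):
            idx, batch = queue.get()     # whichever worker finished first
            pipes[idx].send("continue")  # let it keep collecting while we consume
            print(batch)                 # stand-in for training on the batch
        for proc in procs:
            proc.join()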
    def reset(self, reset_idx: Sequence[bool] | None = None) -> None:
        super().reset(reset_idx)
        if self.queue_out.full():
            time.sleep(_TIMEOUT)  # wait until queue is empty
        if self.queue_out.full():
            raise Exception("self.queue_out is full")
        if self.running:
            for idx in range(self.num_workers):
                if (
                    self.init_random_frames is not None
                    and self._frames < self.init_random_frames
                ):
                    self.pipes[idx].send((idx, "continue_random"))
                else:
                    self.pipes[idx].send((idx, "continue"))
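``reset`` does not pause collection: if the collector is running, it immediately re-issues ``"continue"`` messages so the workers resume from fresh environment states. A short usage sketch, assuming a running two-worker collector and that ``reset_idx`` takes one boolean per worker:

    collector.reset()                         # reset every worker's environment
    collector.reset(reset_idx=[True, False])  # or only the first of two workers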
    # for RPC
    def _receive_weights_scheme(self):
        return super()._receive_weights_scheme()

    # for RPC
    def receive_weights(self, policy_or_weights: TensorDictBase | None = None):
        return super().receive_weights(policy_or_weights)
_LegacyMultiAsyncMeta = _make_legacy_metaclass(_MultiCollectorMeta)


class MultiaSyncDataCollector(MultiAsyncCollector, metaclass=_LegacyMultiAsyncMeta):
    """Deprecated version of :class:`~torchrl.collectors.MultiAsyncCollector`."""

    ...
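The legacy spelling remains for backward compatibility. Assuming both names are exported from ``torchrl.collectors`` (as the docstring's example suggests), they resolve to the same implementation:

    from torchrl.collectors import MultiaSyncDataCollector, MultiAsyncCollector

    # The deprecated name is a subclass of (and behaves like) the new one.
    assert issubclass(MultiaSyncDataCollector, MultiAsyncCollector)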
