MultiaSyncDataCollector
- class torchrl.collectors.MultiaSyncDataCollector(*args, **kwargs)
Runs a given number of DataCollectors on separate processes asynchronously.
Environment types can be identical or different.
The collection keeps on occurring on all processes even between the time the batch of rollouts is collected and the next call to the iterator. This class can be safely used with offline RL sota-implementations.
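To make the asynchronous behaviour concrete, here is a plain-Python analogy that uses threads and a bounded queue instead of processes. The names `worker` and `async_collect` are hypothetical and not part of torchrl. The sketch only shows the general pattern: every worker keeps collecting while the consumer handles the batch it was just given, and batches are delivered in the order they become ready rather than in a fixed worker order.

```python
import queue
import threading
from typing import Iterator, List, Tuple

Batch = List[Tuple[int, int, int]]


def worker(worker_id: int, out_queue: queue.Queue, stop_event: threading.Event,
           frames_per_batch: int) -> None:
    """Hypothetical worker: keeps collecting batches until it is asked to stop."""
    batch_idx = 0
    while not stop_event.is_set():
        # Stand-in for a rollout: (worker id, batch index, step) for each frame.
        batch = [(worker_id, batch_idx, step) for step in range(frames_per_batch)]
        # Hand the batch over; on timeout, re-check the stop flag and retry.
        while not stop_event.is_set():
            try:
                out_queue.put((worker_id, batch), timeout=0.05)
                break
            except queue.Full:
                continue
        batch_idx += 1


def async_collect(num_workers: int = 2, frames_per_batch: int = 4,
                  num_batches: int = 5) -> Iterator[Tuple[int, Batch]]:
    """Yield (worker_id, batch) pairs in the order they become ready."""
    out_queue: queue.Queue = queue.Queue(maxsize=num_workers)
    stop_event = threading.Event()
    threads = [
        threading.Thread(target=worker, args=(i, out_queue, stop_event, frames_per_batch),
                         daemon=True)
        for i in range(num_workers)
    ]
    for thread in threads:
        thread.start()
    try:
        for _ in range(num_batches):
            # Whichever worker finished first is served first; the others keep collecting.
            yield out_queue.get()
    finally:
        stop_event.set()
        for thread in threads:
            thread.join(timeout=1.0)


if __name__ == "__main__":
    for worker_id, batch in async_collect(num_workers=2, frames_per_batch=4, num_batches=3):
        print(f"batch of {len(batch)} frames from worker {worker_id}")
```

The bounded queue is a design choice of the sketch: it keeps roughly one finished batch per worker waiting, so workers run ahead of the consumer without accumulating an unbounded backlog.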
Examples
>>> from torchrl.envs.libs.gym import GymEnv
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> from torchrl.collectors import MultiaSyncDataCollector
>>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
>>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
>>> collector = MultiaSyncDataCollector(
...     create_env_fn=[env_maker, env_maker],
...     policy=policy,
...     total_frames=2000,
...     max_frames_per_traj=50,
...     frames_per_batch=200,
...     init_random_frames=-1,
...     reset_at_each_iter=False,
...     device="cpu",
...     storing_device="cpu",
... )
>>> for i, data in enumerate(collector):
...     if i == 2:
...         print(data)
...         break
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([200]),
    device=cpu,
    is_shared=False)
>>> collector.shutdown()
>>> del collector
Runs a given number of DataCollectors on separate processes.
- Parameters:
create_env_fn (List[Callable]) – list of Callables, each returning an instance of EnvBase.
policy (Callable) – Policy to be executed in the environment. Must accept a tensordict.tensordict.TensorDictBase object as input. If None is provided, the policy used will be a RandomPolicy instance with the environment action_spec. Accepted policies are usually subclasses of TensorDictModuleBase. This is the recommended usage of the collector. Other callables are accepted too: if the policy is not a TensorDictModuleBase (e.g., a regular Module instance), it will be wrapped in an nn.Module first. Then, the collector will try to assess whether these modules require wrapping in a TensorDictModule or not.
- If the policy forward signature matches any of forward(self, tensordict), forward(self, td) or forward(self, <anything>: TensorDictBase) (or any typing with a single argument typed as a subclass of TensorDictBase), then the policy won't be wrapped in a TensorDictModule.
- In all other cases, an attempt will be made to wrap it as follows: TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys).
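The signature rule for deciding whether a policy needs a TensorDictModule wrapper can be sketched with the standard `inspect` module. This is an illustration of the rule as described, not torchrl's actual check: `needs_tensordict_wrapping` is a hypothetical helper, and `TensorDictBase` below is a local stand-in class so the example runs without tensordict installed.

```python
import inspect
from typing import Callable


class TensorDictBase:
    """Hypothetical stand-in for tensordict.TensorDictBase (illustration only)."""


def needs_tensordict_wrapping(forward: Callable) -> bool:
    """Return True if `forward` does not look like it consumes a TensorDictBase."""
    # Ignore `self` so that both bound and unbound methods can be inspected.
    params = [p for p in inspect.signature(forward).parameters.values() if p.name != "self"]
    if len(params) != 1:
        return True
    param = params[0]
    # forward(self, tensordict) or forward(self, td)
    if param.name in ("tensordict", "td"):
        return False
    # forward(self, <anything>: TensorDictBase), or an annotation with a subclass of it
    annotation = param.annotation
    if inspect.isclass(annotation) and issubclass(annotation, TensorDictBase):
        return False
    return True


class TensorDictPolicy:
    def forward(self, tensordict):
        return tensordict


class PlainPolicy:
    def forward(self, observation):
        return observation


print(needs_tensordict_wrapping(TensorDictPolicy().forward))  # False
print(needs_tensordict_wrapping(PlainPolicy().forward))  # True
```

When the check returns True, the text above says the collector then attempts a wrap of the form TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys).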
- Keyword Arguments:
frames_per_batch (int) – A keyword-only argument representing the total number of elements in a batch.
total_frames (int, optional) – A keyword-only argument representing the total number of frames returned by the collector during its lifespan. If total_frames is not divisible by frames_per_batch, an exception is raised. Endless collectors can be created by passing total_frames=-1. Defaults to -1 (never-ending collector).
device (int, str or torch.device, optional) – The generic device of the collector. The device argument fills any non-specified device: if device is not None and any of storing_device, policy_device or env_device is not specified, its value will be set to device. Defaults to None (no default device). Supports a list of devices if one wishes to indicate a different device for each worker. The list must be as long as the number of workers.
storing_device (int, str or torch.device, optional) – The device on which the output TensorDict will be stored. If device is passed and storing_device is None, it will default to the value indicated by device. For long trajectories, it may be necessary to store the data on a different device than the one where the policy and env are executed. Defaults to None (the output tensordict isn't on a specific device; leaf tensors sit on the device where they were created). Supports a list of devices if one wishes to indicate a different device for each worker. The list must be as long as the number of workers.
env_device (int, str or torch.device, optional) – The device on which the environment should be cast (or executed if that functionality is supported). If not specified and the env has a non-None device, env_device will default to that value. If device is passed and env_device=None, it will default to device. If the value of env_device so specified differs from policy_device and one of them is not None, the data will be cast to env_device before being passed to the env (i.e., passing different devices to policy and env is supported). Defaults to None. Supports a list of devices if one wishes to indicate a different device for each worker. The list must be as long as the number of workers.
policy_device (int, str or torch.device, optional) – The device on which the policy should be cast. If device is passed and policy_device=None, it will default to device. If the value of policy_device so specified differs from env_device and one of them is not None, the data will be cast to policy_device before being passed to the policy (i.e., passing different devices to policy and env is supported). Defaults to None. Supports a list of devices if one wishes to indicate a different device for each worker. The list must be as long as the number of workers.
create_env_kwargs (dict, optional) – A dictionary with the keyword arguments used to create an environment. If a list is provided, each of its elements will be assigned to a sub-collector.
max_frames_per_traj (int, optional) – Maximum steps per trajectory. Note that a trajectory can span across multiple batches (unless reset_at_each_iter is set to True, see below). Once a trajectory reaches max_frames_per_traj steps, the environment is reset. If the environment wraps multiple environments together, the number of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. Defaults to None (i.e., no maximum number of steps).
init_random_frames (int, optional) – Number of frames for which the policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. If provided, it will be rounded up to the closest multiple of frames_per_batch. Defaults to None (i.e., no random frames).
reset_at_each_iter (bool, optional) – Whether environments should be reset at the beginning of a batch collection. Defaults to False.
postproc (Callable, optional) – A post-processing transform, such as a Transform or a MultiStep instance. Defaults to None.
split_trajs (bool, optional) – Boolean indicating whether the resulting TensorDict should be split according to the trajectories. See split_trajectories() for more information. Defaults to False.
exploration_type (ExplorationType, optional) – Interaction mode to be used when collecting data. Must be one of torchrl.envs.utils.ExplorationType.RANDOM, torchrl.envs.utils.ExplorationType.MODE or torchrl.envs.utils.ExplorationType.MEAN. Defaults to torchrl.envs.utils.ExplorationType.RANDOM.
reset_when_done (bool, optional) – if True (default), an environment that returns a True value in its "done" or "truncated" entry will be reset at the corresponding indices.
update_at_each_batch (bool, optional) – if True, update_policy_weights_() will be called before (sync) or after (async) each data collection. Defaults to False.
preemptive_threshold (float, optional) – a value between 0.0 and 1.0 that specifies the ratio of workers that will be allowed to finish collecting their rollout before the rest are forced to end early.
num_threads (int, optional) – number of threads for this process. Defaults to the number of workers.
num_sub_threads (int, optional) – number of threads of the subprocesses. Should be equal to one plus the number of processes launched within each subprocess (or one if a single process is launched). Defaults to 1 for safety: if none is indicated, launching multiple workers may overload the CPU and harm performance.
cat_results (str, int or None) – (MultiSyncDataCollector exclusively.) If "stack", the data collected from the workers will be stacked along the first dimension. This is the preferred behaviour as it is the most compatible with the rest of the library. If 0, results will be concatenated along the first dimension of the outputs, which can be the batched dimension if the environments are batched or the time dimension if not. A cat_results value of -1 will always concatenate results along the time dimension. This should be preferred over the default. Intermediate values are also accepted. Defaults to 0.
Note
From v0.5, this argument will default to "stack" for better interoperability with the rest of the library.
set_truncated (bool, optional) – if True, the truncated signals (and corresponding "done" but not "terminated") will be set to True when the last frame of a rollout is reached. If no "truncated" key is found, an exception is raised. Truncated keys can be set through env.add_truncated_keys. Defaults to False.
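The frame-count and device-fallback rules in the keyword arguments above reduce to a little arithmetic, sketched below in plain Python. The helpers `num_batches`, `rounded_init_random_frames`, `per_worker` and `resolve_devices` are hypothetical and only restate the documented rules; devices are plain strings here, and the sketch deliberately omits the case where env_device is inferred from the environment's own device. Treating None or non-positive init_random_frames as zero is an assumption of the sketch.

```python
from typing import Dict, List, Optional, Sequence, Union

# A device is represented by a plain string here (e.g. "cpu"); None means "unspecified".
DeviceArg = Union[None, str, Sequence[Optional[str]]]


def num_batches(total_frames: int, frames_per_batch: int) -> Optional[int]:
    """Number of batches yielded over the collector's lifespan (None if endless)."""
    if total_frames == -1:
        return None  # total_frames=-1 creates a never-ending collector
    if total_frames % frames_per_batch != 0:
        raise ValueError(
            f"total_frames={total_frames} is not divisible by "
            f"frames_per_batch={frames_per_batch}"
        )
    return total_frames // frames_per_batch


def rounded_init_random_frames(init_random_frames: Optional[int], frames_per_batch: int) -> int:
    """Round init_random_frames up to the closest multiple of frames_per_batch."""
    if init_random_frames is None or init_random_frames <= 0:
        return 0  # assumption: None or non-positive means no random frames
    return -(-init_random_frames // frames_per_batch) * frames_per_batch


def per_worker(value: DeviceArg, num_workers: int) -> List[Optional[str]]:
    """Broadcast a single device to every worker, or check a per-worker list."""
    if isinstance(value, (list, tuple)):
        if len(value) != num_workers:
            raise ValueError("a device list must be as long as the number of workers")
        return list(value)
    return [value] * num_workers


def resolve_devices(
    num_workers: int,
    device: DeviceArg = None,
    storing_device: DeviceArg = None,
    env_device: DeviceArg = None,
    policy_device: DeviceArg = None,
) -> Dict[str, List[Optional[str]]]:
    """Fill every unspecified storing/env/policy device with `device`."""
    generic = per_worker(device, num_workers)
    specific = {
        "storing_device": storing_device,
        "env_device": env_device,
        "policy_device": policy_device,
    }
    return {
        name: [s if s is not None else g
               for s, g in zip(per_worker(value, num_workers), generic)]
        for name, value in specific.items()
    }


print(num_batches(2000, 200))                # 10
print(rounded_init_random_frames(250, 200))  # 400
print(resolve_devices(2, device="cpu", storing_device=["cpu", "cuda:0"]))
```

With the settings of the example at the top of this page (total_frames=2000, frames_per_batch=200), the rule gives 10 batches.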
- load_state_dict(state_dict: OrderedDict) → None
Loads the state_dict on the workers.
- Parameters:
state_dict (OrderedDict) – state_dict of the form
{"worker0": state_dict0, "worker1": state_dict1}.
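The expected layout of the argument can be illustrated with plain dictionaries: one entry per worker, keyed "worker0", "worker1", and so on, each dispatched to the matching worker. `FakeWorker` and `load_worker_states` are hypothetical stand-ins, not torchrl classes, and raising KeyError on mismatched keys is an assumption of the sketch.

```python
from collections import OrderedDict
from typing import Any, Dict, List


class FakeWorker:
    """Hypothetical stand-in for one sub-collector; not a torchrl class."""

    def __init__(self) -> None:
        self.state: Dict[str, Any] = {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.state = dict(state_dict)


def load_worker_states(workers: List[FakeWorker], state_dict: Dict[str, Dict[str, Any]]) -> None:
    """Dispatch state_dict["worker{i}"] to the i-th worker."""
    expected = [f"worker{i}" for i in range(len(workers))]
    if set(state_dict) != set(expected):
        raise KeyError(f"expected keys {expected}, got {list(state_dict)}")
    for key, worker in zip(expected, workers):
        worker.load_state_dict(state_dict[key])


workers = [FakeWorker(), FakeWorker()]
load_worker_states(workers, OrderedDict(worker0={"step": 10}, worker1={"step": 12}))
print(workers[1].state)  # {'step': 12}
```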
- reset(reset_idx: Optional[Sequence[bool]] = None) → None
Resets the environments to a new initial state.
- Parameters:
reset_idx – Optional. Sequence indicating which environments have to be reset. If None, all environments are reset.
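The meaning of reset_idx as a boolean mask can be sketched as follows. `envs_to_reset` is a hypothetical helper that only lists which indices would be reset; requiring one entry per environment is an assumption of the sketch.

```python
from typing import List, Optional, Sequence


def envs_to_reset(num_envs: int, reset_idx: Optional[Sequence[bool]] = None) -> List[int]:
    """Indices of the environments that reset(reset_idx) would reset."""
    if reset_idx is None:
        return list(range(num_envs))  # None: every environment is reset
    if len(reset_idx) != num_envs:
        raise ValueError("reset_idx must have one entry per environment")
    # True entries mark the environments to reset.
    return [i for i, flag in enumerate(reset_idx) if flag]


print(envs_to_reset(3, [True, False, True]))  # [0, 2]
print(envs_to_reset(3))  # [0, 1, 2]
```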
- set_seed(seed: int, static_seed: bool = False) → int
Sets the seeds of the environments stored in the DataCollector.
- Parameters:
seed – integer representing the seed to be used for the environment.
static_seed (bool, optional) – if True, the seed is not incremented. Defaults to False.
- Returns:
Output seed. This is useful when more than one environment is contained in the DataCollector, as the seed will be incremented for each of these. The resulting seed is the seed of the last environment.
Examples
>>> from torchrl.envs import ParallelEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.collectors import SyncDataCollector
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> env_fn = lambda: GymEnv("Pendulum-v1")
>>> env_fn_parallel = lambda: ParallelEnv(6, env_fn)
>>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
>>> collector = SyncDataCollector(env_fn_parallel, policy, frames_per_batch=100, total_frames=300)
>>> out_seed = collector.set_seed(1)  # out_seed = 6
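The effect of static_seed on the returned value can be sketched in plain Python. `assign_seeds` is hypothetical, and the plain +1 step between environments is a simplifying assumption that mirrors the docstring's word "incremented"; torchrl may derive successive seeds differently.

```python
from typing import List, Tuple


def assign_seeds(seed: int, num_envs: int, static_seed: bool = False) -> Tuple[List[int], int]:
    """Return the per-environment seeds and the seed of the last environment."""
    if num_envs < 1:
        raise ValueError("at least one environment is required")
    seeds = []
    current = seed
    for _ in range(num_envs):
        seeds.append(current)
        if not static_seed:
            current += 1  # assumption: a plain +1 increment between environments
    return seeds, seeds[-1]


print(assign_seeds(1, 6))  # ([1, 2, 3, 4, 5, 6], 6)
print(assign_seeds(1, 6, static_seed=True))  # ([1, 1, 1, 1, 1, 1], 1)
```

Under this assumption, seeding six parallel environments with 1 returns 6, matching the comment in the example above.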
- state_dict() → OrderedDict
Returns the state_dict of the data collector.
Each field represents a worker containing its own state_dict.
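The shape of the returned object, one field per worker, can be sketched with an OrderedDict. `collect_worker_states` is a hypothetical helper; the per-worker state dicts are plain dictionaries standing in for the workers' real state.

```python
from collections import OrderedDict
from typing import Any, Dict, List


def collect_worker_states(worker_states: List[Dict[str, Any]]) -> "OrderedDict[str, Dict[str, Any]]":
    """Gather one state_dict per worker under the keys "worker0", "worker1", ..."""
    return OrderedDict((f"worker{i}", state) for i, state in enumerate(worker_states))


state = collect_worker_states([{"step": 10}, {"step": 12}])
print(list(state))  # ['worker0', 'worker1']
```

The resulting mapping has the same form that load_state_dict() expects.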
- update_policy_weights_(policy_weights: Optional[TensorDictBase] = None) → None
Updates the policy weights if the policy of the data collector and the trained policy live on different devices.
- Parameters:
policy_weights (TensorDictBase, optional) – if provided, a TensorDict containing the weights of the policy to be used for the update.
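As a plain-Python analogy without torch, the trailing underscore in update_policy_weights_ suggests an in-place update: the collector-side copy of the weights is overwritten rather than replaced, so anything holding a reference to it sees the new values. `WeightSync` is hypothetical, weights are lists of floats, and falling back to the trained weights when policy_weights is None is an assumption of the sketch.

```python
from typing import Dict, List, Optional

Weights = Dict[str, List[float]]


class WeightSync:
    """Hypothetical helper: a collector-side copy of the weights, updated in place."""

    def __init__(self, trained_weights: Weights) -> None:
        self.trained_weights = trained_weights  # reference to the learner's weights
        # Separate copy, standing in for weights living on the collector's device.
        self.collector_weights = {k: list(v) for k, v in trained_weights.items()}

    def update_policy_weights_(self, policy_weights: Optional[Weights] = None) -> None:
        # Assumption: with no argument, pull from the trained weights.
        source = self.trained_weights if policy_weights is None else policy_weights
        for name, values in source.items():
            # Slice assignment overwrites in place, so existing references see the update.
            self.collector_weights[name][:] = values


trained = {"w": [0.0, 0.0]}
sync = WeightSync(trained)
held = sync.collector_weights["w"]  # e.g. a reference held by a worker
trained["w"][:] = [1.0, 2.0]        # the learner updates its weights
sync.update_policy_weights_()
print(held)  # [1.0, 2.0]
```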