DataCollectorBase
- class torchrl.collectors.DataCollectorBase
Base class for data collectors.
- async_shutdown(timeout: float | None = None, close_env: bool = True) → None
Shuts down the collector when it has been started asynchronously with the start() method.
- Parameters:
timeout (float, optional) – The maximum time to wait for the collector to shut down.
close_env (bool, optional) – If True, the collector will close the contained environment. Defaults to True.
See also start().
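The effect of the timeout and close_env parameters can be illustrated with a small, stdlib-only sketch. Everything in it (ToyEnv, ToyAsyncCollector, the thread-based loop) is hypothetical and is not how torchrl implements its collectors; it only mirrors the documented semantics: signal the background collection to stop, wait at most `timeout` seconds, then optionally close the environment.

```python
import threading
import time


class ToyEnv:
    """Hypothetical stand-in for an environment; it only records whether close() ran."""

    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True


class ToyAsyncCollector:
    """Hypothetical collector mimicking the documented timeout/close_env semantics."""

    def __init__(self, env):
        self.env = env
        self._stop = threading.Event()
        self._thread = None

    def _run(self):
        # Background "collection" loop that polls the stop flag.
        while not self._stop.is_set():
            time.sleep(0.001)

    def start(self):
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def async_shutdown(self, timeout=None, close_env=True):
        # Ask the worker to stop, then wait at most `timeout` seconds (None = wait until it exits).
        self._stop.set()
        if self._thread is not None:
            self._thread.join(timeout)
        # Close the environment only if requested.
        if close_env:
            self.env.close()


env = ToyEnv()
collector = ToyAsyncCollector(env)
collector.start()
collector.async_shutdown(timeout=1.0, close_env=True)
```

Passing close_env=False in this sketch leaves the environment open, which is useful when the same environment object is reused after the collector has stopped.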
- init_updater(*args, **kwargs)
Initialize the weight updater with custom arguments.
This method passes the arguments to the weight updater’s init method. If no weight updater is set, this is a no-op.
- Parameters:
*args – Positional arguments for weight updater initialization
**kwargs – Keyword arguments for weight updater initialization
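The forwarding behavior described above can be sketched as follows, assuming (as the text states) that the weight updater exposes an init method accepting arbitrary arguments. ToyWeightUpdater and ToyCollector are hypothetical stand-ins rather than torchrl classes, and the arguments passed in the usage lines are arbitrary placeholders.

```python
class ToyWeightUpdater:
    """Hypothetical weight updater that records how its init() hook was called."""

    def __init__(self):
        self.init_args = None
        self.init_kwargs = None

    def init(self, *args, **kwargs):
        # Store whatever the collector forwarded.
        self.init_args = args
        self.init_kwargs = kwargs


class ToyCollector:
    """Hypothetical collector; weight_updater may be None."""

    def __init__(self, weight_updater=None):
        self.weight_updater = weight_updater

    def init_updater(self, *args, **kwargs):
        # Forward all arguments unchanged; silently do nothing without an updater.
        if self.weight_updater is not None:
            self.weight_updater.init(*args, **kwargs)


updater = ToyWeightUpdater()
ToyCollector(weight_updater=updater).init_updater("cuda:0", num_workers=2)
ToyCollector().init_updater("ignored")  # no updater configured: no-op, no error
```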
- start()
Starts the collector for asynchronous data collection.
This method initiates the background collection of data, allowing for decoupling of data collection and training.
The collected data is typically stored in a replay buffer passed during the collector’s initialization.
Note
After calling this method, it’s essential to shut down the collector using async_shutdown() when you’re done with it to free up resources.
Warning
Asynchronous data collection can significantly impact training performance due to its decoupled nature. Ensure you understand the implications for your specific algorithm before using this mode.
- Raises:
NotImplementedError – If not implemented by a subclass.
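The usage pattern described for start() (background collection into a replay buffer, consumption by the training loop, and a guaranteed shutdown) can be sketched with the standard library. ToyCollectorBase and ToyBufferedCollector are hypothetical, a queue.Queue stands in for a replay buffer, and the base class raising NotImplementedError only mirrors the documented contract; none of this is torchrl code.

```python
import queue
import threading
import time


class ToyCollectorBase:
    """Hypothetical base: as documented for DataCollectorBase, start() is left to subclasses."""

    def start(self):
        raise NotImplementedError("start() must be implemented by a subclass")


class ToyBufferedCollector(ToyCollectorBase):
    """Hypothetical subclass that pushes fake transitions into a replay buffer."""

    def __init__(self, replay_buffer):
        self.replay_buffer = replay_buffer
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._collect, daemon=True)

    def _collect(self):
        step = 0
        while not self._stop.is_set():
            self.replay_buffer.put({"step": step})  # stand-in for one transition
            step += 1
            time.sleep(0.001)

    def start(self):
        self._thread.start()

    def async_shutdown(self, timeout=None):
        self._stop.set()
        self._thread.join(timeout)


replay_buffer = queue.Queue()  # thread-safe stand-in for a replay buffer
collector = ToyBufferedCollector(replay_buffer)
collector.start()
consumed = []
try:
    # Training side: consume data while collection continues in the background.
    for _ in range(5):
        consumed.append(replay_buffer.get(timeout=1.0))
finally:
    # Always shut down, even if the training loop raises.
    collector.async_shutdown(timeout=1.0)
```

The try/finally block is the design point here: it guarantees the shutdown recommended in the Note runs even when training fails midway.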
- update_policy_weights_(policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, *, worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, **kwargs) → None
Updates the policy weights for the data collector, accommodating both local and remote execution contexts.
This method ensures that the policy weights used by the data collector are synchronized with the latest trained weights. It supports both local and remote weight updates, depending on the configuration of the data collector. The local (download) update is performed before the remote (upload) update, so that weights can first be fetched from a server and then transferred to the child workers.
- Parameters:
policy_or_weights (TensorDictBase | TensorDictModuleBase | dict | None) – The weights to update with. Can be:
  - TensorDictModuleBase: a policy module whose weights will be extracted.
  - TensorDictBase: a TensorDict containing weights.
  - dict: a regular dict containing weights.
  - None: the collector will try to get weights from the server using _get_server_weights().
worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional) – Identifiers for the workers that need to be updated. This is relevant when the collector has more than one worker associated with it.
- Raises:
TypeError – If worker_ids is provided but no weight_updater is configured.
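The documented update flow (normalize the input, perform the local update, then push to the selected workers, raising TypeError when worker_ids is given without an updater) can be sketched as below. This is a simplified illustration, not torchrl's implementation: plain dicts stand in for TensorDicts, ToyPolicy (any object with a state_dict() method) stands in for a TensorDictModuleBase, the weight updater is a plain callable, and ToyUpdatingCollector and its toy _get_server_weights() are hypothetical.

```python
class ToyPolicy:
    """Hypothetical stand-in for a TensorDictModuleBase: exposes state_dict()."""

    def __init__(self, weights):
        self._weights = dict(weights)

    def state_dict(self):
        return dict(self._weights)


class ToyUpdatingCollector:
    """Hypothetical collector mimicking the documented update order (not torchrl code)."""

    def __init__(self, server_weights, weight_updater=None):
        self._server_weights = dict(server_weights)
        self.weight_updater = weight_updater
        self.local_weights = {}
        self.log = []

    def _get_server_weights(self):
        # Toy version of the server lookup used when no weights are passed.
        return dict(self._server_weights)

    def update_policy_weights_(self, policy_or_weights=None, *, worker_ids=None):
        if worker_ids is not None and self.weight_updater is None:
            raise TypeError("worker_ids was given but no weight_updater is configured")
        # Normalize the accepted input kinds into a plain dict of weights.
        if policy_or_weights is None:
            weights = self._get_server_weights()
        elif hasattr(policy_or_weights, "state_dict"):
            weights = policy_or_weights.state_dict()
        else:
            weights = dict(policy_or_weights)
        # 1) Local (download) update first ...
        self.local_weights = weights
        self.log.append("local")
        # 2) ... then the remote (upload) update to the selected workers.
        if self.weight_updater is not None:
            self.weight_updater(weights, worker_ids)
            self.log.append("remote")


pushed = []
collector = ToyUpdatingCollector({"w": 0.0}, weight_updater=lambda w, ids: pushed.append((w, ids)))
collector.update_policy_weights_(ToyPolicy({"w": 1.5}), worker_ids=[0, 1])
collector.update_policy_weights_()  # no argument: fall back to the server weights
```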
Note
Users should extend the WeightUpdaterBase classes to customize the weight update logic for specific use cases. This method should not be overridden.
See also LocalWeightsUpdaterBase and RemoteWeightsUpdaterBase.
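The Note recommends customizing behavior by subclassing an updater rather than overriding update_policy_weights_. A minimal sketch of that extension pattern follows: a fixed entry point in a base class delegates the transport step to a subclass hook. ToyWeightUpdaterBase, its push() hook, and InMemoryUpdater are hypothetical and do not reproduce the actual methods of torchrl's WeightUpdaterBase, LocalWeightsUpdaterBase, or RemoteWeightsUpdaterBase.

```python
from abc import ABC, abstractmethod


class ToyWeightUpdaterBase(ABC):
    """Hypothetical updater base: the entry point is fixed, the transport is pluggable."""

    def __call__(self, weights, worker_ids=None):
        # Shared normalization that subclasses inherit instead of rewriting.
        if isinstance(worker_ids, int):
            worker_ids = [worker_ids]
        self.push(dict(weights), worker_ids)

    @abstractmethod
    def push(self, weights, worker_ids):
        """Deliver `weights` to `worker_ids` (None means all workers)."""


class InMemoryUpdater(ToyWeightUpdaterBase):
    """Hypothetical custom updater that keeps one weight dict per worker."""

    def __init__(self, num_workers):
        self.workers = {i: {} for i in range(num_workers)}

    def push(self, weights, worker_ids):
        targets = self.workers.keys() if worker_ids is None else worker_ids
        for i in targets:
            self.workers[i] = dict(weights)


updater = InMemoryUpdater(num_workers=3)
updater({"w": 2.0}, worker_ids=1)  # only worker 1 receives the new weights
```

Keeping the public entry point in the base class means every custom updater gets the same argument handling, while only the delivery mechanism changes.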