DataCollectorBase
- class torchrl.collectors.DataCollectorBase
Base class for data collectors.
- update_policy_weights_(policy_weights: TensorDictBase | None = None, *, worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, **kwargs) → None
Updates the policy weights for the data collector, accommodating both local and remote execution contexts.
This method ensures that the policy weights used by the data collector are synchronized with the latest trained weights. It supports both local and remote weight updates, depending on how the data collector is configured. The local (download) update is performed before the remote (upload) update, so that weights can first be fetched (for example, from a server) and then transferred to the child workers.
- Parameters:
policy_weights (TensorDictBase, optional) – A TensorDict containing the weights of the policy to be used for the update. If not provided, the method will attempt to fetch the weights using the configured weight updater.
worker_ids (int | list[int] | torch.device | list[torch.device] | None, optional) – Identifiers of the workers to update. This is only relevant when the collector has more than one worker.
- Raises:
TypeError – If worker_ids is provided but no weight_updater is configured.
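The update order and the error condition described above can be modelled in plain Python. This is a minimal, hypothetical sketch, not torchrl's implementation: `ToyCollector`, `pull`, `push`, `server` and `workers` are illustrative names, and plain dicts stand in for `TensorDictBase` weights. It shows the local (download) step running before the remote (upload) step, weights being fetched when none are passed, and a `TypeError` when `worker_ids` is given without an uploader.

```python
from typing import Callable, Optional, Sequence, Union

# Plain dicts stand in for TensorDict weights (illustrative assumption).
Weights = dict
WorkerIds = Union[int, Sequence[int], None]


class ToyCollector:
    """Hypothetical collector mirroring the documented update order."""

    def __init__(
        self,
        local_updater: Optional[Callable[[Optional[Weights]], Weights]] = None,
        remote_updater: Optional[Callable[[Weights, WorkerIds], None]] = None,
    ):
        self.local_updater = local_updater
        self.remote_updater = remote_updater
        self.log: list = []

    def update_policy_weights_(
        self, policy_weights: Optional[Weights] = None, *, worker_ids: WorkerIds = None
    ) -> None:
        # 1. Local (download) step runs first, e.g. pulling fresh weights
        #    from a server into this process.
        if self.local_updater is not None:
            policy_weights = self.local_updater(policy_weights)
            self.log.append("local")
        # 2. Remote (upload) step then pushes those weights to the workers.
        if self.remote_updater is not None:
            self.remote_updater(policy_weights, worker_ids)
            self.log.append("remote")
        elif worker_ids is not None:
            # Selecting workers is meaningless without an uploader.
            raise TypeError("worker_ids was given but no remote updater is configured")


# Illustrative "server" and worker replicas.
server = {"w": 1.0}
workers = {0: {}, 1: {}}


def pull(weights: Optional[Weights]) -> Weights:
    # Fetch from the server only when no weights were passed explicitly.
    return dict(server) if weights is None else weights


def push(weights: Weights, worker_ids: WorkerIds) -> None:
    # Normalize worker_ids: None means "all workers", an int means one worker.
    if worker_ids is None:
        ids = list(workers)
    elif isinstance(worker_ids, int):
        ids = [worker_ids]
    else:
        ids = list(worker_ids)
    for i in ids:
        workers[i] = dict(weights)


collector = ToyCollector(local_updater=pull, remote_updater=push)
collector.update_policy_weights_(worker_ids=[1])
print(collector.log)  # ['local', 'remote']
print(workers)        # {0: {}, 1: {'w': 1.0}}
```

Only worker 1 receives the downloaded weights because it was the only id requested; calling the same method on a collector built with `local_updater=pull` alone and `worker_ids=0` raises `TypeError`.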
Note
Users should extend the WeightUpdaterBase classes to customize the weight update logic for specific use cases. This method should not be overridden.
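The recommended design, customizing a pluggable updater object rather than overriding `update_policy_weights_`, is essentially the strategy pattern. The sketch below illustrates it in plain Python; `ToyWeightUpdater`, `CopyToWorkers` and `ToyCollector` are hypothetical names and do not reproduce the actual interface of torchrl's `WeightUpdaterBase` classes.

```python
from typing import List, Optional, Sequence


class ToyWeightUpdater:
    """Hypothetical updater base: the extension point users subclass."""

    def __call__(self, weights: dict, worker_ids: Optional[Sequence[int]] = None) -> None:
        self.push(weights, worker_ids)

    def push(self, weights: dict, worker_ids: Optional[Sequence[int]]) -> None:
        raise NotImplementedError


class CopyToWorkers(ToyWeightUpdater):
    """Custom logic lives here: copy weights into in-memory worker slots."""

    def __init__(self, num_workers: int):
        self.workers: List[Optional[dict]] = [None] * num_workers

    def push(self, weights, worker_ids):
        # None means "update every worker".
        ids = range(len(self.workers)) if worker_ids is None else worker_ids
        for i in ids:
            self.workers[i] = dict(weights)


class ToyCollector:
    """Hypothetical collector whose update method is never overridden."""

    def __init__(self, updater: ToyWeightUpdater):
        self.updater = updater

    def update_policy_weights_(self, policy_weights=None, *, worker_ids=None) -> None:
        # Fixed entry point: all customization is delegated to the updater.
        self.updater(policy_weights, worker_ids)


updater = CopyToWorkers(num_workers=3)
collector = ToyCollector(updater)
collector.update_policy_weights_({"w": 2.0}, worker_ids=[0, 2])
print(updater.workers)  # [{'w': 2.0}, None, {'w': 2.0}]
```

Keeping the collector's entry point fixed means every collector exposes the same update call, while transport details (shared memory, RPC, a parameter server) vary only inside the updater subclass.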
See also
LocalWeightsUpdaterBase and RemoteWeightsUpdaterBase.