DataCollectorBase

class torchrl.collectors.DataCollectorBase

Base class for data collectors.

update_policy_weights_(policy_weights: TensorDictBase | None = None, *, worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, **kwargs) → None

Updates the policy weights for the data collector, accommodating both local and remote execution contexts.

This method keeps the policy weights used by the data collector synchronized with the latest trained weights. Whether a local update, a remote update, or both take place depends on how the data collector is configured. The local (download) update is performed before the remote (upload) update, so that weights obtained from a server can then be transferred to the child workers.

Parameters:
  • policy_weights (TensorDictBase, optional) – A TensorDict containing the weights of the policy to be used for the update. If not provided, the method will attempt to fetch the weights using the configured weight updater.

  • worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional) – Identifiers for the workers that need to be updated. This is relevant when the collector has more than one worker associated with it.

Raises:

TypeError – If worker_ids is provided but no weight_updater is configured.
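
The argument handling described above can be illustrated with a small, self-contained toy model. This is only a sketch and not TorchRL's implementation: ToyCollector and fetch_weights are hypothetical names, plain dicts stand in for a TensorDict, device-based worker identifiers are left out, and treating worker_ids=None as "all workers" (and doing nothing when neither weights nor an updater are available) are assumptions made for this example.

```python
from __future__ import annotations


class ToyCollector:
    """Hypothetical stand-in for a multi-worker collector (not TorchRL code)."""

    def __init__(self, num_workers: int, fetch_weights=None):
        # Plain dicts stand in for the TensorDict of policy weights.
        self.worker_weights = [{} for _ in range(num_workers)]
        # Hypothetical callable playing the role of a configured weight updater.
        self.fetch_weights = fetch_weights

    def update_policy_weights_(self, policy_weights=None, *, worker_ids=None):
        # Documented error case: worker_ids given without a weight updater.
        if worker_ids is not None and self.fetch_weights is None:
            raise TypeError("worker_ids requires a configured weight updater")
        # No explicit weights: fall back to the updater. Sketch assumption:
        # do nothing if there is no updater either.
        if policy_weights is None:
            if self.fetch_weights is None:
                return
            policy_weights = self.fetch_weights()
        # Sketch assumption: None targets every worker, an int targets one.
        if worker_ids is None:
            worker_ids = range(len(self.worker_weights))
        elif isinstance(worker_ids, int):
            worker_ids = [worker_ids]
        for idx in worker_ids:
            # Copy so that workers do not share one mutable object.
            self.worker_weights[idx] = dict(policy_weights)


collector = ToyCollector(num_workers=3, fetch_weights=lambda: {"w": 1.0})
collector.update_policy_weights_()                          # fetched weights, all workers
collector.update_policy_weights_({"w": 2.0}, worker_ids=1)  # explicit weights, one worker
print(collector.worker_weights)
```

With a real collector, the call shape is the same as documented above: update_policy_weights_() with no arguments, or with policy_weights and, optionally, worker_ids.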

Note

To customize the weight update logic for a specific use case, extend the WeightUpdaterBase classes. This method itself should not be overridden.

See also

LocalWeightsUpdaterBase and RemoteWeightsUpdaterBase.
