
DataCollectorBase

class torchrl.collectors.DataCollectorBase

Base class for data collectors.

async_shutdown(timeout: float | None = None) -> None

Shuts down a collector that was started asynchronously with the start() method.

Parameters:

  • timeout (float, optional) – The maximum time to wait for the collector to shut down.

See also

start()

pause()

Context manager that pauses the collector while it is running freely in the background (that is, after start() has been called).
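To illustrate what pausing a free-running collector means, here is a minimal sketch of a pause context manager built on `threading.Event`. `ToyPausableCollector`, its `_run` loop and its `stop` helper are hypothetical stand-ins written for this example; they are not TorchRL code, and the real collector's internals may differ.

```python
import contextlib
import threading
import time


class ToyPausableCollector:
    """Hypothetical free-running collector, used only to illustrate pause()."""

    def __init__(self):
        self.frames = 0
        self._running = threading.Event()  # set => collection may proceed
        self._running.set()
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def _run(self):
        while not self._stop.is_set():
            # While paused, wait() times out and we loop back to check _stop.
            if not self._running.wait(timeout=0.01):
                continue
            self.frames += 1  # stand-in for collecting one frame
            time.sleep(0.001)

    def start(self):
        self._thread.start()

    def stop(self):
        self._stop.set()
        self._thread.join()

    @contextlib.contextmanager
    def pause(self):
        # Clear the flag so the loop blocks, and always restore it on exit.
        self._running.clear()
        try:
            yield
        finally:
            self._running.set()
```

A step already in flight when `pause()` is entered may still complete, so the frame count settles shortly after entering the `with` block rather than instantly.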

start()

Starts the collector for asynchronous data collection.

This method starts collecting data in the background, decoupling data collection from training.

The collected data is typically stored in a replay buffer passed to the collector at initialization.

Note

After calling this method, shut the collector down with async_shutdown() once you are done with it, to free its resources.

Warning

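Asynchronous data collection can significantly affect training because collection runs decoupled from the training loop; for example, data may be gathered with policy weights that lag behind the ones being trained. Make sure you understand the implications for your specific algorithm before using this mode.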

Raises:

NotImplementedError – If not implemented by a subclass.
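The start/async_shutdown life cycle can be sketched as follows, assuming a background thread that appends into a plain Python list standing in for the replay buffer. `ToyCollectorBase`, `ToyAsyncCollector` and `_collect_one` are hypothetical names for this illustration; TorchRL's collectors manage their own threads or processes and write into `ReplayBuffer` objects. The `try`/`finally` makes sure the collector is shut down even if training raises, as the note above recommends.

```python
import threading
import time


class ToyCollectorBase:
    """Hypothetical base class mirroring the documented start() contract."""

    def start(self):
        # Subclasses must provide asynchronous collection.
        raise NotImplementedError


class ToyAsyncCollector(ToyCollectorBase):
    def __init__(self, replay_buffer):
        self.replay_buffer = replay_buffer  # a list stands in for a replay buffer
        self._stop = threading.Event()
        self._thread = None

    def _collect_one(self):
        # Placeholder for one environment step; here just a step counter.
        return {"step": len(self.replay_buffer)}

    def _loop(self):
        while not self._stop.is_set():
            self.replay_buffer.append(self._collect_one())
            time.sleep(0.001)

    def start(self):
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def async_shutdown(self, timeout=None):
        self._stop.set()
        if self._thread is not None:
            # Wait at most `timeout` seconds (forever if None).
            self._thread.join(timeout=timeout)


buffer = []
collector = ToyAsyncCollector(buffer)
collector.start()
try:
    time.sleep(0.05)  # a training loop would sample from `buffer` here
finally:
    collector.async_shutdown(timeout=1.0)
```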

update_policy_weights_(policy_weights: TensorDictBase | None = None, *, worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, **kwargs) -> None

Updates the policy weights for the data collector, accommodating both local and remote execution contexts.

This method synchronizes the policy weights used by the data collector with the latest trained weights. Depending on the collector's configuration, it performs a local update, a remote update, or both. The local (download) update runs before the remote (upload) update, so that weights fetched from a server can then be transferred to the child workers.

Parameters:
  • policy_weights (TensorDictBase, optional) – A TensorDict containing the weights of the policy to be used for the update. If not provided, the method will attempt to fetch the weights using the configured weight updater.

  • worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional) – Identifiers for the workers that need to be updated. This is relevant when the collector has more than one worker associated with it.

Raises:

TypeError – If worker_ids is provided but no weight_updater is configured.
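The documented control flow can be sketched in plain Python: use `policy_weights` when given, otherwise fetch them through the updater; perform the local (download) step before the remote (upload) step; and reject `worker_ids` when no updater is configured. `ToyWeightUpdater`, its `download`/`upload` methods, `ToyCollector`, and the use of plain dicts instead of `TensorDictBase` are all hypothetical simplifications, not TorchRL's implementation (which also accepts `torch.device` worker identifiers).

```python
class ToyWeightUpdater:
    """Hypothetical updater: fetches weights from a 'server', pushes them to workers."""

    def __init__(self, server_weights, num_workers):
        self.server_weights = server_weights
        self.worker_weights = [None] * num_workers
        self.calls = []  # records the order of operations

    def download(self):
        self.calls.append("download")
        return dict(self.server_weights)

    def upload(self, weights, worker_ids):
        self.calls.append("upload")
        for i in worker_ids:
            self.worker_weights[i] = dict(weights)


class ToyCollector:
    def __init__(self, weight_updater=None):
        self.weight_updater = weight_updater
        self.policy_weights = {}

    def update_policy_weights_(self, policy_weights=None, *, worker_ids=None):
        updater = self.weight_updater
        if worker_ids is not None and updater is None:
            raise TypeError("worker_ids was given but no weight updater is configured")
        if policy_weights is None and updater is not None:
            policy_weights = updater.download()  # local (download) step comes first
        if policy_weights is not None:
            self.policy_weights = dict(policy_weights)
        if updater is not None:
            if worker_ids is None:
                worker_ids = range(len(updater.worker_weights))  # default: all workers
            elif isinstance(worker_ids, int):
                worker_ids = [worker_ids]
            updater.upload(self.policy_weights, worker_ids)  # then the remote (upload) step
```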

Note

Users should extend the WeightUpdaterBase classes to customize the weight update logic for specific use cases. This method should not be overridden.
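The split described in this note (a fixed collector method that delegates to a customizable updater) can be sketched as below. `ToyWeightUpdaterBase`, its abstract methods `get_weights`/`send_weights`, `VersionedUpdater` and `ToyDelegatingCollector` are hypothetical; TorchRL's actual `WeightUpdaterBase` defines its own abstract methods, which may differ.

```python
import abc


class ToyWeightUpdaterBase(abc.ABC):
    """Hypothetical extension point: subclasses decide how weights move."""

    @abc.abstractmethod
    def get_weights(self):
        """Return the latest policy weights (e.g. from a parameter server)."""

    @abc.abstractmethod
    def send_weights(self, weights, worker_ids):
        """Deliver `weights` to the selected workers."""


class VersionedUpdater(ToyWeightUpdaterBase):
    """Custom logic lives in the updater, here a simple version counter."""

    def __init__(self):
        self.version = 0
        self.delivered = []  # (worker_ids, weights) pairs, in delivery order

    def get_weights(self):
        self.version += 1
        return {"version": self.version}

    def send_weights(self, weights, worker_ids):
        self.delivered.append((worker_ids, weights))


class ToyDelegatingCollector:
    def __init__(self, weight_updater):
        self.weight_updater = weight_updater

    def update_policy_weights_(self, policy_weights=None, *, worker_ids=None):
        # Fixed driver logic, never overridden: it only delegates to the updater.
        if policy_weights is None:
            policy_weights = self.weight_updater.get_weights()
        self.weight_updater.send_weights(policy_weights, worker_ids)
```

With this design, changing how weights are fetched or transported means writing a new updater subclass while the collector method stays untouched.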

See also

LocalWeightsUpdaterBase and RemoteWeightsUpdaterBase.
