
WeightUpdaterBase

class torchrl.collectors.WeightUpdaterBase

A base class for updating remote policy weights on inference workers.

The weight updater is the central piece of the weight update scheme:

  • In leaf collector nodes, it is responsible for sending the weights to the policy, which can be as simple as updating a state-dict, or more complex if an inference server is being used.

  • In server collector nodes, it is responsible for sending the weights to the leaf collectors.

In a collector, the updater is called within update_policy_weights_().
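To make the two roles concrete, here is a toy, in-process model of the scheme. `ToyPolicy`, `LeafUpdater` and `ServerUpdater` are hypothetical names invented for this sketch, not torchrl classes. A real leaf would load a state dict into a policy module or contact an inference server, and a real server would push across process or machine boundaries; here both steps are reduced to plain dictionary updates.

```python
# Hypothetical toy model of the two-tier update scheme; none of these
# classes exist in torchrl.
from typing import Dict, List

Weights = Dict[str, float]


class ToyPolicy:
    """Stands in for a policy module; its 'state dict' is a plain dict."""

    def __init__(self) -> None:
        self.state: Weights = {"w": 0.0}

    def load_state(self, weights: Weights) -> None:
        self.state.update(weights)


class LeafUpdater:
    """Leaf node: writes the weights straight into its local policy."""

    def __init__(self, policy: ToyPolicy) -> None:
        self.policy = policy

    def push_weights(self, weights: Weights) -> None:
        self.policy.load_state(weights)


class ServerUpdater:
    """Server node: forwards the weights to every leaf updater it knows."""

    def __init__(self, leaves: List[LeafUpdater]) -> None:
        self.leaves = leaves

    def push_weights(self, weights: Weights) -> None:
        for leaf in self.leaves:
            leaf.push_weights(weights)


policies = [ToyPolicy(), ToyPolicy()]
server = ServerUpdater([LeafUpdater(p) for p in policies])
server.push_weights({"w": 1.5})
print([p.state["w"] for p in policies])  # [1.5, 1.5]
```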

The main method of this class is push_weights(), which updates the weights of the policy or of the workers.

To extend this class, implement the following methods:

  • _get_server_weights (optional): Define how to retrieve the weights from the server when they are not passed to the updater directly. This method is only called if the weights (or a handle to them) are not passed directly.

  • _sync_weights_with_worker: Define how to synchronize weights with a specific worker. This method must be implemented by child classes.

  • _maybe_map_weights: Optionally transform the server weights before distribution. By default, this method returns the weights unchanged.

  • all_worker_ids: Provide a list of all worker identifiers. Returns None by default (no worker id).
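The following is a minimal, self-contained sketch of how the hooks listed above might fit together. `MiniWeightUpdater` and `DictWorkersUpdater` are hypothetical classes written for illustration, not torchrl's implementation. The dispatch order inside `push_weights` (fetch the weights if none are given, map them, then sync each worker, or sync once when there are no worker ids) is an assumption based on the descriptions on this page.

```python
import abc
from typing import Any, Dict, List, Optional


class MiniWeightUpdater(abc.ABC):
    """Hypothetical re-creation of the hook layout described above."""

    def _get_server_weights(self) -> Any:
        # Optional hook: only needed if callers never pass weights explicitly.
        raise NotImplementedError("pass weights explicitly or override this hook")

    @abc.abstractmethod
    def _sync_weights_with_worker(self, *, worker_id: Optional[int], server_weights: Any) -> None:
        """Required hook: deliver the weights to one worker."""

    def _maybe_map_weights(self, server_weights: Any) -> Any:
        # Default: leave the weights unchanged.
        return server_weights

    def all_worker_ids(self) -> Optional[List[int]]:
        # Default: no worker ids.
        return None

    def push_weights(self, *, weights: Any = None, worker_ids: Optional[List[int]] = None) -> None:
        if weights is None:
            weights = self._get_server_weights()
        weights = self._maybe_map_weights(weights)
        if worker_ids is None:
            worker_ids = self.all_worker_ids()
        if worker_ids is None:
            # No workers: a single sync call targets the local policy.
            self._sync_weights_with_worker(worker_id=None, server_weights=weights)
            return
        for worker_id in worker_ids:
            self._sync_weights_with_worker(worker_id=worker_id, server_weights=weights)

    def __call__(self, **kwargs: Any) -> None:
        # __call__ is a thin proxy to push_weights().
        return self.push_weights(**kwargs)


class DictWorkersUpdater(MiniWeightUpdater):
    """Example subclass: each 'worker' is a dict living in the same process."""

    def __init__(self, source: Dict[str, float], n_workers: int) -> None:
        self.source = source
        self.workers: List[Dict[str, float]] = [{} for _ in range(n_workers)]

    def _get_server_weights(self) -> Dict[str, float]:
        return self.source

    def _maybe_map_weights(self, server_weights: Dict[str, float]) -> Dict[str, float]:
        # Copy so that workers never alias the server's dict.
        return dict(server_weights)

    def _sync_weights_with_worker(self, *, worker_id: Optional[int], server_weights: Dict[str, float]) -> None:
        assert worker_id is not None  # this subclass always reports worker ids
        self.workers[worker_id].update(server_weights)

    def all_worker_ids(self) -> List[int]:
        return list(range(len(self.workers)))


updater = DictWorkersUpdater({"w": 2.0}, n_workers=3)
updater()  # weights fetched via _get_server_weights()
updater.push_weights(weights={"w": 3.0}, worker_ids=[0])
print(updater.workers)  # [{'w': 3.0}, {'w': 2.0}, {'w': 2.0}]
```

Only `_sync_weights_with_worker` is declared abstract in the sketch, mirroring the note that it must be implemented by child classes; the other hooks fall back on the documented defaults.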

Variables:

collector – The collector (or any container) of the weight updater. The collector is registered via register_collector().

push_weights()

Updates the weights on the specified workers, or on all remote workers. The __call__ method is a proxy to push_weights().

register_collector()

Registers the collector (or any container) in the updater through a weakref. This is called automatically by the collector when the updater is registered with it.

all_worker_ids() -> list[int] | list[torch.device] | None

Returns a list of all worker identifiers, or None if no workers are associated with the updater.

property collector: torchrl.collectors.DataCollectorBase

The collector or container of the updater.

Returns None if the container has gone out of scope (its weak reference is dead) or has not been set.

push_weights(*, weights: Any | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None)

Updates the weights of the policy, or those of the specified (or all) remote workers.

Parameters:
  • weights (Any) – The source weights to push to the policy / workers. If None, the weights are retrieved with _get_server_weights().

  • worker_ids (torch.device | int | list[int] | list[torch.device] | None = None) – An optional worker identifier, or list of identifiers, to update. If None, all workers are updated.

Returns: None.
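The `worker_ids` argument accepts a single identifier, a list, or None. A plausible way to normalize it is sketched below. `resolve_worker_ids` is a hypothetical helper, and the frozen `Device` dataclass is a stand-in for `torch.device` so the snippet runs without PyTorch. Falling back on `all_worker_ids()` when `worker_ids` is None is an assumption consistent with push_weights() being described as updating specified or all workers.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Sequence, Union


@dataclass(frozen=True)
class Device:
    """Hypothetical stand-in for torch.device."""

    name: str


WorkerId = Union[int, Device]


def resolve_worker_ids(
    worker_ids: Union[WorkerId, Sequence[WorkerId], None],
    all_worker_ids: Callable[[], Optional[List[WorkerId]]],
) -> Optional[List[WorkerId]]:
    """Hypothetical helper: normalize worker_ids to a list of ids, or None."""
    if isinstance(worker_ids, (int, Device)):
        # A single identifier is wrapped in a one-element list.
        return [worker_ids]
    if worker_ids is None:
        # No explicit target: fall back on every known worker (may itself be None).
        return all_worker_ids()
    # Any other sequence of identifiers is copied into a fresh list.
    return list(worker_ids)


print(resolve_worker_ids(2, lambda: [0, 1, 2]))  # [2]
print(resolve_worker_ids(None, lambda: [0, 1, 2]))  # [0, 1, 2]
print(resolve_worker_ids(None, lambda: None))  # None
print(resolve_worker_ids(Device("cuda:0"), lambda: None))  # [Device(name='cuda:0')]
```

A None result means there are no worker ids at all, in which case a caller would perform a single update rather than loop over workers.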

register_collector(collector: DataCollectorBase)

Registers a collector with the updater.

Once registered, the updater will not accept another collector.

Parameters:

collector (DataCollectorBase) – The collector to register.
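The registration behavior described on this page (a weak reference to the collector, a `collector` property that returns None once the collector is gone or was never set, and refusal of a second collector) can be sketched as follows. `WeakCollectorSlot` and `DummyCollector` are hypothetical names, and raising `RuntimeError` on a second registration is an assumption; the page only states that another collector will not be accepted.

```python
import gc
import weakref
from typing import Any, Optional


class WeakCollectorSlot:
    """Hypothetical sketch of register_collector() and the collector property."""

    def __init__(self) -> None:
        self._collector_wr: Optional[weakref.ref] = None

    def register_collector(self, collector: Any) -> None:
        if self._collector_wr is not None:
            # Once a collector is registered, a second one is rejected.
            raise RuntimeError("a collector is already registered")
        # Keep only a weak reference so the updater does not keep the collector alive.
        self._collector_wr = weakref.ref(collector)

    @property
    def collector(self) -> Optional[Any]:
        # None if nothing was registered or the collector was garbage-collected.
        if self._collector_wr is None:
            return None
        return self._collector_wr()


class DummyCollector:
    pass


slot = WeakCollectorSlot()
collector = DummyCollector()
slot.register_collector(collector)
print(slot.collector is collector)  # True
try:
    slot.register_collector(DummyCollector())
except RuntimeError as err:
    print(err)  # a collector is already registered
del collector
gc.collect()
print(slot.collector)  # None
```

Holding only a weak reference means the updater never keeps a collector alive on its own; the `gc.collect()` call just makes the final check independent of when the interpreter reclaims the object.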
