
WeightUpdaterBase

class torchrl.collectors.WeightUpdaterBase[source]

A base class for updating remote policy weights on inference workers.

The weight updater is the central piece of the weight update scheme:

  • In leaf collector nodes, it is responsible for sending the weights to the policy, which can be as simple as updating a state-dict, or more complex if an inference server is being used.

  • In server collector nodes, it is responsible for sending the weights to the leaf collectors.

In a collector, the updater is called within update_policy_weights_().

The main method of this class is _push_weights(), which updates the policy weights in the worker or policy. It is called by push_weights(), which also runs the post-hooks, so child classes should implement _push_weights() rather than override push_weights().

To extend this class, implement or override the following methods:

  • _get_server_weights (optional): Define how to retrieve the weights from the server if they are not passed to the updater directly.

    This method is only called if the weights (or a handle to them) are not passed directly.

  • _sync_weights_with_worker: Define how to synchronize weights with a specific worker.

    This method must be implemented by child classes.

  • _maybe_map_weights: Optionally transform the server weights before distribution.

    By default, this method returns the weights unchanged.

  • all_worker_ids: Provide a list of all worker identifiers.

    Returns None by default (no worker id).

  • from_policy (optional classmethod): Define how to create an instance of the weight updater from a policy.

    If implemented, this method will be called before falling back to the default constructor when initializing a weight updater in a collector.
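The way these hooks fit together can be sketched with a small standalone class. This is a toy stand-in written only for illustration, not torchrl's implementation: ToyWeightUpdater, its constructor arguments, and the exact method signatures are assumptions, and plain dicts stand in for real weights and workers.

```python
from typing import Optional


class ToyWeightUpdater:
    """Standalone stand-in that mimics the documented hooks (not torchrl code)."""

    def __init__(self, server_weights: dict, workers: dict):
        self._server_weights = server_weights  # hypothetical "server" copy
        self._workers = workers  # hypothetical worker_id -> weights store

    def _get_server_weights(self) -> dict:
        # Only used when no weights are passed to push_weights().
        return self._server_weights

    def _maybe_map_weights(self, weights: dict) -> dict:
        # Default behaviour described above: return the weights unchanged.
        return weights

    def _sync_weights_with_worker(self, worker_id: int, weights: dict) -> None:
        # Copy so that a worker never shares a mutable dict with the server.
        self._workers[worker_id] = dict(weights)

    def all_worker_ids(self) -> Optional[list]:
        return sorted(self._workers)

    def push_weights(self, weights=None, worker_ids=None) -> None:
        if weights is None:
            weights = self._get_server_weights()
        weights = self._maybe_map_weights(weights)
        if worker_ids is None:
            worker_ids = self.all_worker_ids()
        elif isinstance(worker_ids, int):
            worker_ids = [worker_ids]
        for worker_id in worker_ids:
            self._sync_weights_with_worker(worker_id, weights)


workers = {0: {}, 1: {}}
updater = ToyWeightUpdater({"w": 1.0}, workers)
updater.push_weights()  # no weights given: falls back to the server weights
updater.push_weights({"w": 2.0}, worker_ids=1)  # explicit weights, one worker
print(workers)  # {0: {'w': 1.0}, 1: {'w': 2.0}}
```

In this sketch a subclass would typically only replace _sync_weights_with_worker (and, if needed, _maybe_map_weights), while the orchestration in push_weights stays fixed.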

Variables:

collector – The collector (or any container) of the weight receiver. The collector is registered via register_collector().

push_weights()[source]

Updates the weights on specified or all remote workers. The __call__ method is a proxy to push_weights.

register_collector()[source]

Registers the collector (or any container) in the receiver through a weakref. This will be called automatically by the collector upon registration of the updater.

from_policy()[source]

Optional classmethod to create an instance from a policy.

Post-hooks:
  • register_post_hook: Registers a post-hook to be called after the weights are updated.

    The post-hook must be a callable that takes no arguments. Post-hooks run after the weights are updated, in the same process as the weight updater, and in the order in which they were registered.
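The post-hook contract (zero-argument callables, run in registration order after the update) can be illustrated with a minimal standalone sketch. HookedUpdater is a hypothetical toy class, not torchrl's; the TypeError check is an assumption added for the example.

```python
from typing import Callable


class HookedUpdater:
    """Toy stand-in showing the post-hook contract (not torchrl's class)."""

    def __init__(self) -> None:
        self._post_hooks: list = []

    def register_post_hook(self, hook: Callable[[], None]) -> None:
        if not callable(hook):
            raise TypeError("post-hook must be callable")
        self._post_hooks.append(hook)

    def push_weights(self) -> None:
        # ... the weight transfer itself would happen here ...
        for hook in self._post_hooks:  # same order as registration
            hook()  # hooks receive no arguments


calls = []
updater = HookedUpdater()
updater.register_post_hook(lambda: calls.append("log"))
updater.register_post_hook(lambda: calls.append("bump_version"))
updater.push_weights()
print(calls)  # ['log', 'bump_version']
```

Because hooks take no arguments, any state they need (a logger, a counter) has to be captured through a closure or a bound method, as the lambdas do here.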

all_worker_ids() → list[int] | list[torch.device] | None[source]

Gets list of all worker IDs.

Returns None by default. Subclasses should override to return actual worker IDs.

Returns:

List of worker IDs or None.

Return type:

list[int] | list[torch.device] | None

property collector: torchrl.collectors.DataCollectorBase

The collector or container of the receiver.

Returns None if the container is out-of-scope or not set.

classmethod from_policy(policy: TensorDictModuleBase) → WeightUpdaterBase | None[source]

Optional classmethod to create a weight updater instance from a policy.

This method can be implemented by subclasses to provide custom initialization logic based on the policy. If implemented, this method will be called before falling back to the default constructor when initializing a weight updater in a collector.

Parameters:

policy (TensorDictModuleBase) – The policy to create the weight updater from.

Returns:

An instance of the weight updater, or None if the policy cannot be used to create an instance.

Return type:

WeightUpdaterBase | None
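The "try from_policy first, then fall back to the default constructor" behaviour can be sketched as follows. Everything here is hypothetical and standalone: LinearPolicy, ToyUpdater, and make_updater are illustrative names, and how a collector actually performs the fallback is not shown in this page.

```python
from typing import Any, Optional


class LinearPolicy:
    """Hypothetical policy type that the toy updater knows how to handle."""

    def __init__(self) -> None:
        self.weights = {"w": 0.0}


class ToyUpdater:
    def __init__(self, policy_weights: Optional[dict] = None) -> None:
        self.policy_weights = policy_weights

    @classmethod
    def from_policy(cls, policy: Any) -> Optional["ToyUpdater"]:
        # Returning None signals "cannot build an updater from this policy".
        if isinstance(policy, LinearPolicy):
            return cls(policy_weights=policy.weights)
        return None


def make_updater(updater_cls, policy):
    """Hypothetical helper: prefer from_policy, else the default constructor."""
    updater = updater_cls.from_policy(policy)
    if updater is None:
        updater = updater_cls()
    return updater


policy = LinearPolicy()
from_known = make_updater(ToyUpdater, policy)      # built via from_policy
from_unknown = make_updater(ToyUpdater, object())  # default constructor
print(from_known.policy_weights is policy.weights, from_unknown.policy_weights)
```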

increment_version()[source]

Increment the policy version.

init(*args, **kwargs)[source]

Initialize the weight updater with custom arguments.

This method can be overridden by subclasses to handle custom initialization. By default, this is a no-op.

Parameters:
  • *args – Positional arguments for initialization

  • **kwargs – Keyword arguments for initialization

property post_hooks: list[Callable[[], NoneType]]

The list of post-hooks registered to the weight updater.

push_weights(policy_or_weights: TensorDictModuleBase | TensorDictBase | dict | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None)[source]

Updates the weights of the policy, or on specified / all remote workers.

Parameters:
  • policy_or_weights – The source to get weights from. Can be:

      - TensorDictModuleBase: A policy module whose weights will be extracted.
      - TensorDictBase: A TensorDict containing weights.
      - dict: A regular dict containing weights.
      - None: The weights are retrieved from the server using _get_server_weights().

  • worker_ids – An optional worker ID (int or torch.device), or a list of worker IDs, to update. If None, all workers are updated.

Returns: nothing.
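The argument handling described above can be approximated with two small helpers. This is a duck-typed sketch under stated assumptions, not the library's logic: resolve_weights and normalize_worker_ids are hypothetical names, any object with a state_dict() method stands in for a policy module, and any Mapping stands in for a TensorDict or dict of weights.

```python
from collections.abc import Mapping
from typing import Any, Callable


def resolve_weights(source: Any, get_server_weights: Callable[[], Mapping]) -> dict:
    """Turn a policy-like object, a mapping, or None into a plain dict."""
    if source is None:
        # No weights passed: ask the (hypothetical) server for them.
        return dict(get_server_weights())
    if isinstance(source, Mapping):
        return dict(source)
    if hasattr(source, "state_dict"):
        # Assumed stand-in for "extract the weights from a policy module".
        return dict(source.state_dict())
    raise TypeError(f"unsupported weight source: {type(source).__name__}")


def normalize_worker_ids(worker_ids: Any, all_ids: list) -> list:
    """Accept None (all workers), a single ID, or a list of IDs."""
    if worker_ids is None:
        return list(all_ids)
    if isinstance(worker_ids, (list, tuple)):
        return list(worker_ids)
    return [worker_ids]


class FakePolicy:
    """Hypothetical module-like object exposing state_dict()."""

    def state_dict(self) -> dict:
        return {"w": 3.0}


server = lambda: {"w": 1.0}
print(resolve_weights(None, server))          # {'w': 1.0}
print(resolve_weights({"w": 2.0}, server))    # {'w': 2.0}
print(resolve_weights(FakePolicy(), server))  # {'w': 3.0}
print(normalize_worker_ids(1, [0, 1, 2]))     # [1]
```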

register_collector(collector: DataCollectorBase)[source]

Register a collector in the updater.

Once registered, the updater will not accept another collector.

Parameters:

collector (DataCollectorBase) – The collector to register.
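Holding the collector through a weak reference, and refusing a second registration, can be sketched with the standard weakref module. ToyCollector and ToyUpdater are hypothetical stand-ins, and raising RuntimeError on a second registration is an assumption about how "will not accept another collector" might be enforced.

```python
import gc
import weakref


class ToyCollector:
    """Hypothetical stand-in for a collector or other container."""


class ToyUpdater:
    def __init__(self) -> None:
        self._collector_ref = None

    def register_collector(self, collector) -> None:
        if self._collector_ref is not None:
            # Assumed behaviour: only one registration is allowed.
            raise RuntimeError("a collector is already registered")
        # A weak reference does not keep the collector alive.
        self._collector_ref = weakref.ref(collector)

    @property
    def collector(self):
        if self._collector_ref is None:
            return None  # nothing registered yet
        return self._collector_ref()  # None once the collector is gone


updater = ToyUpdater()
collector = ToyCollector()
updater.register_collector(collector)
assert updater.collector is collector

del collector  # drop the only strong reference
gc.collect()   # make collection explicit on non-refcounting interpreters
print(updater.collector)  # None
```

The weak reference avoids a reference cycle between the collector and its updater, which is why the property can legitimately return None.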

register_post_hook(hook: Callable[[], None])[source]

Registers a post-hook to be called after weights are updated.

Parameters:

hook (Callable[[], None]) – The post-hook to register.
