
VanillaWeightUpdater

class torchrl.collectors.VanillaWeightUpdater(*, weight_getter: Callable[[], TensorDictBase] | None = None, policy_weights: TensorDictBase)[source]

A simple implementation of WeightUpdaterBase for updating local policy weights.

The VanillaWeightUpdater class provides a basic mechanism for updating the weights of a local policy by fetching them directly from a specified source. It is typically used when the weight update logic is straightforward and requires no mapping or transformation of the weights.

This class is used by default in the SyncDataCollector when no custom weight updater is provided.

Keyword Arguments:
  • weight_getter (Callable[[], TensorDictBase], optional) – a callable that returns the weights from the server. If not provided, the weights must be passed to update_weights() directly.

  • policy_weights (TensorDictBase) – a TensorDictBase containing the policy weights to be updated in-place.
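The interplay between the two keyword arguments can be sketched with a toy stand-in. This is not TorchRL's implementation: `ToyInPlaceUpdater`, its `update` method, and the use of plain dicts of lists in place of a TensorDictBase are all assumptions made for illustration. The sketch only shows why an in-place update matters: any object that already holds a reference to the weight storage sees the new values without being rebound.

```python
from typing import Callable, Dict, List, Optional

# Stand-in for TensorDictBase: parameter name -> list of floats (assumption).
Weights = Dict[str, List[float]]


class ToyInPlaceUpdater:
    """Hypothetical stand-in, NOT torchrl's VanillaWeightUpdater."""

    def __init__(
        self,
        *,
        weight_getter: Optional[Callable[[], Weights]] = None,
        policy_weights: Weights,
    ) -> None:
        self.weight_getter = weight_getter
        self.policy_weights = policy_weights

    def update(self, weights: Optional[Weights] = None) -> None:
        # Fall back to the getter only when no weights are passed explicitly.
        if weights is None:
            if self.weight_getter is None:
                raise ValueError("no weight_getter set; pass weights explicitly")
            weights = self.weight_getter()
        for name, values in weights.items():
            # Slice assignment overwrites the existing list object in place,
            # so existing references to it observe the new values.
            self.policy_weights[name][:] = values


server_weights = {"linear.weight": [0.5, -0.5], "linear.bias": [0.1]}
policy_weights = {"linear.weight": [0.0, 0.0], "linear.bias": [0.0]}

# A reference taken before the update, as a policy holding this buffer would have.
alias = policy_weights["linear.weight"]

updater = ToyInPlaceUpdater(
    weight_getter=lambda: server_weights,
    policy_weights=policy_weights,
)
updater.update()
print(alias)
```

When no getter is configured, the toy class requires the weights to be passed explicitly, mirroring the documented requirement for `weight_getter`.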

all_worker_ids() list[int] | list[torch.device] | None

Gets list of all worker IDs.

Returns None by default. Subclasses should override to return actual worker IDs.

Returns:

List of worker IDs or None.

Return type:

list[int] | list[torch.device] | None

property collector: torchrl.collectors.DataCollectorBase

The collector or container of the receiver.

Returns None if the container is out-of-scope or not set.

classmethod from_policy(policy: TensorDictModuleBase) WeightUpdaterBase | None[source]

Creates a VanillaWeightUpdater instance from a policy.

This method creates a weight updater that will update the policy’s weights directly using its state dict.

Parameters:

policy (TensorDictModuleBase) – The policy to create the weight updater from.

Returns:

An instance of the weight updater configured to update the policy’s weights.

Return type:

VanillaWeightUpdater

increment_version()

Increment the policy version.

init(*args, **kwargs)

Initialize the weight updater with custom arguments.

This method can be overridden by subclasses to handle custom initialization. By default, this is a no-op.

Parameters:
  • *args – Positional arguments for initialization

  • **kwargs – Keyword arguments for initialization

property post_hooks: list[Callable[[], NoneType]]

The list of post-hooks registered to the weight updater.

push_weights(policy_or_weights: TensorDictModuleBase | TensorDictBase | dict | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None)

Updates the weights of the local policy, or of the specified (or all) remote workers.

Parameters:
  • policy_or_weights – The source to get weights from. Can be:
      - TensorDictModuleBase: A policy module whose weights will be extracted
      - TensorDictBase: A TensorDict containing weights
      - dict: A regular dict containing weights
      - None: Will try to get weights from server using _get_server_weights()

  • worker_ids – An optional list of workers to update.

Returns: nothing.
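The way the accepted source types could be resolved into a single weight mapping can be sketched as follows. This is a simplified illustration, not the library's code: `ToyModule`, `resolve_weights`, and the `weight_getter` argument are hypothetical names, plain dicts stand in for TensorDictBase, and the `worker_ids` argument is left out entirely.

```python
from typing import Any, Callable, Dict, Optional


class ToyModule:
    """Hypothetical stand-in for a policy module exposing state_dict()."""

    def __init__(self, params: Dict[str, Any]) -> None:
        self._params = dict(params)

    def state_dict(self) -> Dict[str, Any]:
        return dict(self._params)


def resolve_weights(
    policy_or_weights: Any = None,
    weight_getter: Optional[Callable[[], Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    """Turn any of the accepted sources into a plain weight mapping."""
    if policy_or_weights is None:
        # No explicit source: fall back to fetching from the "server".
        if weight_getter is None:
            raise ValueError("no weights passed and no weight_getter available")
        return weight_getter()
    if isinstance(policy_or_weights, dict):
        # Already a mapping of weights: use it as-is.
        return policy_or_weights
    if hasattr(policy_or_weights, "state_dict"):
        # A module-like object: extract its weights.
        return policy_or_weights.state_dict()
    raise TypeError(f"unsupported source: {type(policy_or_weights).__name__}")


from_module = resolve_weights(ToyModule({"w": 1.0}))
from_dict = resolve_weights({"w": 2.0})
from_getter = resolve_weights(None, weight_getter=lambda: {"w": 3.0})
print(from_module, from_dict, from_getter)
```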

register_collector(collector: DataCollectorBase)

Register a collector in the updater.

Once registered, the updater will not accept another collector.

Parameters:

collector (DataCollectorBase) – The collector to register.

register_post_hook(hook: Callable[[], None])

Registers a post-hook to be called after weights are updated.

Parameters:

hook (Callable[[], None]) – The post-hook to register.
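The post-hook mechanism can be illustrated with a small stand-in class. `ToyHookedUpdater` is hypothetical and only mirrors the documented method names; running the hooks in registration order after each update is an assumption of this sketch, not a documented guarantee.

```python
from typing import Callable, Dict, List


class ToyHookedUpdater:
    """Hypothetical stand-in showing post-hooks; NOT torchrl's class."""

    def __init__(self, policy_weights: Dict[str, float]) -> None:
        self.policy_weights = policy_weights
        self._post_hooks: List[Callable[[], None]] = []

    @property
    def post_hooks(self) -> List[Callable[[], None]]:
        return self._post_hooks

    def register_post_hook(self, hook: Callable[[], None]) -> None:
        self._post_hooks.append(hook)

    def push_weights(self, weights: Dict[str, float]) -> None:
        # Update the stored weights first, then run every registered hook.
        self.policy_weights.update(weights)
        for hook in self._post_hooks:  # registration order (assumption)
            hook()


events: List[str] = []
updater = ToyHookedUpdater({"w": 0.0})
updater.register_post_hook(lambda: events.append("weights updated"))

updater.push_weights({"w": 1.0})
updater.push_weights({"w": 2.0})
print(events)
```

A hook takes no arguments and returns nothing, matching the `Callable[[], None]` signature above, so any state it needs has to be captured by closure.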
