VanillaWeightUpdater
- class torchrl.collectors.VanillaWeightUpdater(*, weight_getter: Callable[[], TensorDictBase] | None = None, policy_weights: TensorDictBase)
A simple implementation of WeightUpdaterBase for updating local policy weights. The VanillaWeightUpdater class provides a basic mechanism for updating the weights of a local policy by fetching them directly from a specified source. It is typically used when the weight update logic is straightforward and requires no mapping or transformation of the weights.
This class is used by default in SyncDataCollector when no custom weight updater is provided.
- Keyword Arguments:
weight_getter (Callable[[], TensorDictBase], optional) – a callable that returns the weights from the server. If not provided, the weights must be passed to update_weights() directly.
policy_weights (TensorDictBase) – a TensorDictBase containing the policy weights to be updated in-place.
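The update mechanism described above can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the TorchRL implementation: it uses ordinary dicts in place of TensorDictBase, and the class name ToyVanillaUpdater and its update() method are hypothetical.

```python
from typing import Callable, Optional


class ToyVanillaUpdater:
    """Hypothetical stand-in for VanillaWeightUpdater, using dicts instead of TensorDicts."""

    def __init__(self, *, weight_getter: Optional[Callable[[], dict]] = None,
                 policy_weights: dict) -> None:
        self.weight_getter = weight_getter
        self.policy_weights = policy_weights  # updated in place, never rebound

    def update(self, weights: Optional[dict] = None) -> None:
        if weights is None:
            if self.weight_getter is None:
                raise ValueError("no weight_getter set: pass weights explicitly")
            weights = self.weight_getter()  # fetch from the "server"
        # Updating in place keeps every existing reference to policy_weights valid.
        self.policy_weights.update(weights)


server = {"w": 1.0}
local = {"w": 0.0}
updater = ToyVanillaUpdater(weight_getter=lambda: dict(server), policy_weights=local)
updater.update()
print(local)  # {'w': 1.0}
```

Updating in place (rather than replacing the container) matters because the policy keeps reading from the same object it was built with.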
- all_worker_ids() → list[int] | list[torch.device] | None
Gets the list of all worker IDs.
Returns None by default. Subclasses should override to return actual worker IDs.
- Returns:
List of worker IDs or None.
- Return type:
list[int] | list[torch.device] | None
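The default-plus-override pattern for all_worker_ids() can be sketched as follows. Both class names are hypothetical, and the subclass simply assumes a fixed pool of integer worker indices.

```python
from typing import Optional


class BaseUpdaterSketch:
    """Hypothetical base class mirroring the documented default of returning None."""

    def all_worker_ids(self) -> Optional[list[int]]:
        return None


class MultiWorkerUpdaterSketch(BaseUpdaterSketch):
    """Hypothetical subclass that manages a fixed pool of worker indices."""

    def __init__(self, num_workers: int) -> None:
        self.num_workers = num_workers

    def all_worker_ids(self) -> Optional[list[int]]:
        # Report one integer ID per worker.
        return list(range(self.num_workers))


print(BaseUpdaterSketch().all_worker_ids())          # None
print(MultiWorkerUpdaterSketch(3).all_worker_ids())  # [0, 1, 2]
```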
- property collector: torchrl.collectors.DataCollectorBase
The collector or container of this weight updater.
Returns None if the container is out of scope or not set.
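"Out of scope" suggests the collector is held without keeping it alive. The sketch below assumes a weak reference (Python's weakref module) to show how such a property can return None both when nothing was registered and after the collector is garbage-collected. The class names are hypothetical and the weak-reference storage is an assumption, not confirmed by this page.

```python
import gc
import weakref
from typing import Optional


class Collector:
    """Hypothetical placeholder for a DataCollectorBase instance."""


class UpdaterWithCollector:
    """Sketch assuming the collector is held through a weak reference."""

    def __init__(self) -> None:
        self._collector_ref: Optional[weakref.ref] = None

    def register_collector(self, collector: Collector) -> None:
        self._collector_ref = weakref.ref(collector)

    @property
    def collector(self) -> Optional[Collector]:
        # None if never set, or if the collector has been garbage-collected.
        return self._collector_ref() if self._collector_ref is not None else None


updater = UpdaterWithCollector()
print(updater.collector)        # None (not set)
c = Collector()
updater.register_collector(c)
print(updater.collector is c)   # True
del c
gc.collect()
print(updater.collector)        # None (out of scope)
```

A weak reference avoids a reference cycle between the collector and its updater, so the collector can be freed normally.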
- classmethod from_policy(policy: TensorDictModuleBase) → WeightUpdaterBase | None
Creates a VanillaWeightUpdater instance from a policy.
This method creates a weight updater that will update the policy’s weights directly using its state dict.
- Parameters:
policy (TensorDictModuleBase) – The policy to create the weight updater from.
- Returns:
An instance of the weight updater configured to update the policy’s weights.
- Return type:
WeightUpdaterBase | None
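The idea behind from_policy (bind the updater to the policy's own weight storage) can be sketched with plain Python objects. ToyPolicy, ToyUpdater, and the dict-based state_dict() are hypothetical stand-ins; the real method works with TensorDictModuleBase and TensorDict weights.

```python
from typing import Optional


class ToyPolicy:
    """Hypothetical policy exposing its parameters as a dict."""

    def __init__(self) -> None:
        self.params = {"w": 0.0, "b": 0.0}

    def state_dict(self) -> dict:
        # Returns the live dict, so in-place writes reach the policy.
        return self.params


class ToyUpdater:
    def __init__(self, *, weight_getter=None, policy_weights: dict) -> None:
        self.weight_getter = weight_getter
        self.policy_weights = policy_weights

    @classmethod
    def from_policy(cls, policy: ToyPolicy) -> Optional["ToyUpdater"]:
        # Bind the updater directly to the policy's own weight storage.
        return cls(policy_weights=policy.state_dict())


policy = ToyPolicy()
updater = ToyUpdater.from_policy(policy)
updater.policy_weights.update({"w": 2.0})
print(policy.params)  # {'w': 2.0, 'b': 0.0}
```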
- increment_version()
Increment the policy version.
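A policy version counter can be as simple as an integer that is bumped each time new weights are published. This sketch is hypothetical: the attribute name policy_version and the class name are illustrative only.

```python
class VersionedUpdaterSketch:
    """Hypothetical updater that counts how many times weights were published."""

    def __init__(self) -> None:
        self.policy_version = 0

    def increment_version(self) -> None:
        self.policy_version += 1


u = VersionedUpdaterSketch()
for _ in range(3):
    u.increment_version()
print(u.policy_version)  # 3
```

Comparing a stored version with the current one is one plausible way for a consumer to tell whether its weights are stale.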
- init(*args, **kwargs)
Initialize the weight updater with custom arguments.
This method can be overridden by subclasses to handle custom initialization. By default, this is a no-op.
- Parameters:
*args – Positional arguments for initialization
**kwargs – Keyword arguments for initialization
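Overriding the no-op init() hook might look like the sketch below. Both classes are hypothetical, and the device keyword is an illustrative argument, not one defined by this page.

```python
class BaseSketch:
    def init(self, *args, **kwargs) -> None:
        """Default: no-op, as documented."""
        return None


class ConfiguredSketch(BaseSketch):
    """Hypothetical subclass that records a target device name at init time."""

    def __init__(self) -> None:
        self.device = None

    def init(self, *args, device: str = "cpu", **kwargs) -> None:
        # Custom initialization: remember where weights should be sent.
        self.device = device


s = ConfiguredSketch()
s.init(device="cuda:0")
print(s.device)  # cuda:0
```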
- property post_hooks: list[Callable[[], NoneType]]
The list of post-hooks registered to the weight updater.
- push_weights(policy_or_weights: TensorDictModuleBase | TensorDictBase | dict | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None)
Updates the weights of the policy, or of the specified (or all) remote workers.
- Parameters:
policy_or_weights – The source to get weights from. Can be:
  - TensorDictModuleBase: a policy module whose weights will be extracted.
  - TensorDictBase: a TensorDict containing weights.
  - dict: a regular dict containing weights.
  - None: the weights are fetched from the server using _get_server_weights().
worker_ids – An optional worker ID, or list of worker IDs, to update.
- Returns:
None.
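The input dispatch described for policy_or_weights can be sketched with plain Python types. ToyModule and PushSketch are hypothetical; the TensorDictBase and dict cases are merged into one dict branch, and the server_weights callable stands in for _get_server_weights(). worker_ids is accepted but ignored because the sketch has a single local policy.

```python
from typing import Callable, Union


class ToyModule:
    """Hypothetical policy module standing in for TensorDictModuleBase."""

    def __init__(self, params: dict) -> None:
        self.params = params


class PushSketch:
    def __init__(self, server_weights: Callable[[], dict], local: dict) -> None:
        self._server_weights = server_weights  # stands in for _get_server_weights()
        self.local = local

    def push_weights(self, policy_or_weights: Union[ToyModule, dict, None] = None,
                     worker_ids=None) -> None:
        if policy_or_weights is None:
            weights = self._server_weights()      # fetch from the "server"
        elif isinstance(policy_or_weights, ToyModule):
            weights = policy_or_weights.params    # extract weights from the module
        elif isinstance(policy_or_weights, dict):
            weights = policy_or_weights           # already a mapping of weights
        else:
            raise TypeError(f"unsupported source: {type(policy_or_weights)!r}")
        # worker_ids is ignored: this sketch has a single local policy.
        self.local.update(weights)


local = {"w": 0.0}
p = PushSketch(lambda: {"w": 1.0}, local)
p.push_weights()                       # from the "server"
p.push_weights(ToyModule({"w": 2.0}))  # from a module
p.push_weights({"w": 3.0})             # from a plain dict
print(local)  # {'w': 3.0}
```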
- register_collector(collector: DataCollectorBase)
Register a collector in the updater.
Once registered, the updater will not accept another collector.
- Parameters:
collector (DataCollectorBase) – The collector to register.
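The "register once" rule can be sketched as a guard on a single slot. The class is hypothetical, and raising RuntimeError on a second registration is an assumption: this page only states that another collector will not be accepted.

```python
class RegistrationSketch:
    """Hypothetical updater that accepts exactly one collector."""

    def __init__(self) -> None:
        self._collector = None

    def register_collector(self, collector: object) -> None:
        if self._collector is not None:
            # Assumption: a second collector is rejected by raising an error.
            raise RuntimeError("a collector is already registered")
        self._collector = collector


updater = RegistrationSketch()
first = object()
updater.register_collector(first)
try:
    updater.register_collector(object())
except RuntimeError as err:
    print(err)  # a collector is already registered
```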
- register_post_hook(hook: Callable[[], None])
Registers a post-hook to be called after weights are updated.
- Parameters:
hook (Callable[[], None]) – The post-hook to register.
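How register_post_hook and the post_hooks property fit together can be sketched as a list of zero-argument callables run after each update. HookSketch is hypothetical, and running hooks in registration order is an assumption of this sketch.

```python
from typing import Callable


class HookSketch:
    """Hypothetical updater that runs post-hooks after weights change."""

    def __init__(self) -> None:
        self._post_hooks: list[Callable[[], None]] = []
        self.weights: dict = {}

    @property
    def post_hooks(self) -> list[Callable[[], None]]:
        return self._post_hooks

    def register_post_hook(self, hook: Callable[[], None]) -> None:
        self._post_hooks.append(hook)

    def push_weights(self, weights: dict) -> None:
        self.weights.update(weights)
        for hook in self._post_hooks:  # run hooks in registration order
            hook()


calls = []
h = HookSketch()
h.register_post_hook(lambda: calls.append("synced"))
h.push_weights({"w": 1.0})
print(calls)  # ['synced']
```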