VanillaWeightUpdater
- class torchrl.collectors.VanillaWeightUpdater(*, weight_getter: Callable[[], TensorDictBase] | None = None, policy_weights: TensorDictBase)
A simple implementation of WeightUpdaterBase for updating local policy weights. The VanillaWeightUpdater class provides a basic mechanism for updating the weights of a local policy by fetching them directly from a specified source. It is typically used when the weight update logic is straightforward and the weights need no mapping or transformation.
This class is used by default in the SyncDataCollector when no custom weight updater is provided.
See also
WeightUpdateReceiverBase and SyncDataCollector.
- Keyword Arguments:
weight_getter (Callable[[], TensorDictBase], optional) – a callable that returns the weights from the server. If not provided, the weights must be passed directly to update_weights().
policy_weights (TensorDictBase) – a TensorDictBase containing the policy weights to be updated in place.
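The two keyword arguments imply two update paths: pull weights from weight_getter when none are passed, or use the explicitly passed weights, then copy them into policy_weights in place. The sketch below illustrates that idea in plain Python and is not torchrl's implementation: dicts of float lists stand in for TensorDictBase, and the class ToyVanillaUpdater is hypothetical.

```python
from typing import Callable, Optional

# Hypothetical stand-in for a TensorDict of parameters: name -> list of floats.
Weights = dict[str, list[float]]


class ToyVanillaUpdater:
    """Hypothetical sketch of a vanilla weight updater; not torchrl's code."""

    def __init__(
        self,
        *,
        policy_weights: Weights,
        weight_getter: Optional[Callable[[], Weights]] = None,
    ) -> None:
        self.policy_weights = policy_weights
        self.weight_getter = weight_getter

    def update_weights(self, weights: Optional[Weights] = None) -> None:
        if weights is None:
            if self.weight_getter is None:
                raise ValueError("pass weights explicitly or provide a weight_getter")
            weights = self.weight_getter()  # pull from the "server"
        # Slice assignment mutates the existing lists, so every holder of a
        # reference to them (e.g. the policy) sees the new values.
        for name, values in weights.items():
            self.policy_weights[name][:] = values


policy = {"w": [0.0, 0.0]}
server = {"w": [1.0, 2.0]}
alias = policy["w"]  # a reference held elsewhere, e.g. by the policy module
updater = ToyVanillaUpdater(policy_weights=policy, weight_getter=lambda: server)
updater.update_weights()                    # no weights given: use the getter
after_pull = list(alias)                    # [1.0, 2.0]
updater.update_weights({"w": [3.0, 4.0]})   # explicit weights bypass the getter
```

Because the copy is in place, `alias` and `policy["w"]` remain the same object after both updates.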
- all_worker_ids() → list[int] | list[torch.device] | None
Returns a list of all worker identifiers, or None if no workers are associated.
- property collector: torchrl.collectors.DataCollectorBase
The collector or container of the updater.
Returns None if the container is out of scope or not set.
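The phrase "out of scope" suggests that the updater does not keep its collector alive. In Python, a common way to get that behavior is a weak reference. The sketch below assumes that design (this page does not confirm it) and uses hypothetical ToyCollector and ToyWeakUpdater classes.

```python
import gc
import weakref
from typing import Optional


class ToyCollector:
    """Hypothetical placeholder for a DataCollectorBase instance."""


class ToyWeakUpdater:
    """Sketch: store the collector through a weak reference (an assumption)."""

    def __init__(self) -> None:
        self._collector_ref: Optional[weakref.ref] = None

    def register_collector(self, collector: ToyCollector) -> None:
        self._collector_ref = weakref.ref(collector)

    @property
    def collector(self) -> Optional[ToyCollector]:
        # Not set -> None; referent garbage-collected -> the weakref yields None.
        return None if self._collector_ref is None else self._collector_ref()


updater = ToyWeakUpdater()
unset = updater.collector                  # nothing registered yet
collector = ToyCollector()
updater.register_collector(collector)
alive = updater.collector is collector     # True while the collector exists
del collector
gc.collect()
gone = updater.collector                   # collector went out of scope
```

The weak reference avoids a reference cycle between the collector and its updater, at the cost of having to handle a None result.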
- push_weights(*, weights: Any | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None)
Updates the weights of the local policy, or of the specified (or all) remote workers.
- Parameters:
weights (Any) – The source weights to push to the policy / workers.
worker_ids (torch.device | int | list[int] | list[torch.device] | None = None) – an optional worker identifier, or list of identifiers, to update. If None, all workers are updated.
Returns: None.
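Because worker_ids accepts a single identifier, a list, or None, it is convenient to normalize it before dispatching. The following sketch assumes that None means "all workers" and that an updater with no associated workers updates the local policy, which is how the description above reads. The helpers select_targets and toy_push_weights are hypothetical, a dict stands in for the real destinations, and strings such as "cuda:0" stand in for torch.device objects.

```python
from typing import Any, Optional, Union

# Stand-in for `int | torch.device`: a device is represented by a string such as "cuda:0".
WorkerId = Union[int, str]


def select_targets(
    worker_ids: Union[WorkerId, list[WorkerId], None],
    all_worker_ids: list[WorkerId],
) -> list[WorkerId]:
    """Hypothetical helper: normalize the accepted worker_ids forms to a list."""
    if worker_ids is None:
        return list(all_worker_ids)  # None: update every worker
    if isinstance(worker_ids, list):
        return list(worker_ids)
    return [worker_ids]  # a single int or device


def toy_push_weights(
    weights: Any,
    store: dict,
    *,
    worker_ids: Union[WorkerId, list[WorkerId], None] = None,
    all_worker_ids: Optional[list[WorkerId]] = None,
) -> None:
    """Hypothetical dispatcher: `store` records which destination received `weights`."""
    if all_worker_ids is None:
        store["local"] = weights  # no associated workers: update the local policy
        return
    for wid in select_targets(worker_ids, all_worker_ids):
        store[wid] = weights


local: dict = {}
toy_push_weights({"w": 1.0}, local)
one: dict = {}
toy_push_weights({"w": 2.0}, one, worker_ids=1, all_worker_ids=[0, 1, 2])
every: dict = {}
toy_push_weights({"w": 3.0}, every, all_worker_ids=["cuda:0", "cuda:1"])
```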
- register_collector(collector: DataCollectorBase)
Registers a collector with the updater.
Once registered, the updater will not accept another collector.
- Parameters:
collector (DataCollectorBase) – The collector to register.
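The register-once rule can be enforced with a simple guard. This sketch is hypothetical: ToyRegisterOnceUpdater is an illustrative class, and raising RuntimeError on a second registration is an assumed behavior, since this page only states that another collector is not accepted.

```python
from typing import Optional


class ToyRegisterOnceUpdater:
    """Hypothetical register-once guard; not torchrl's code."""

    def __init__(self) -> None:
        self._collector: Optional[object] = None

    def register_collector(self, collector: object) -> None:
        if self._collector is not None:
            # Assumption: a second registration is rejected by raising.
            raise RuntimeError("a collector is already registered")
        self._collector = collector


updater = ToyRegisterOnceUpdater()
first = object()
updater.register_collector(first)
try:
    updater.register_collector(object())
    second_accepted = True
except RuntimeError:
    second_accepted = False  # the first collector stays registered
```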