UpdateWeights
- class torchrl.trainers.UpdateWeights(collector: DataCollectorBase, update_weights_interval: int, policy_weights_getter: collections.abc.Callable[[Any], Any] | None = None, weight_update_map: dict[str, str] | None = None, trainer: torchrl.trainers.trainers.Trainer | None = None)
A collector weights update hook class.
This hook must be used whenever the collector's policy weights sit on a different device than the policy weights being trained by the Trainer. In that case, the weights must be synced across devices at regular intervals. If the devices match, the hook is a no-op.
- Parameters:
collector (DataCollectorBase) – A data collector where the policy weights must be synced.
update_weights_interval (int) – Interval, in number of collected batches, at which the sync takes place.
policy_weights_getter (Callable, optional) – A callable that returns the policy weights to sync. Used for backward compatibility. If both this and weight_update_map are provided, weight_update_map takes precedence.
weight_update_map (dict[str, str], optional) –
A mapping from destination paths (keys in the collector's weight_sync_schemes) to source paths on the trainer. Example: {"policy": "loss_module.actor_network", "replay_buffer.transforms[0]": "loss_module.critic_network"}
trainer (Trainer, optional) – The trainer instance, required when using weight_update_map to resolve source paths.
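The source paths in weight_update_map are dotted attribute paths resolved against the trainer. The sketch below illustrates one way such a path could be resolved; resolve_path and the SimpleNamespace stand-ins are hypothetical names introduced for illustration and are not part of torchrl, whose actual resolution logic may differ.

```python
import re
from types import SimpleNamespace

# Hypothetical helper (not a torchrl function): walks a dotted path whose
# segments may carry integer indices, e.g. "replay_buffer.transforms[0]".
_SEGMENT = re.compile(r"([A-Za-z_]\w*)((?:\[\d+\])*)")


def resolve_path(root, path):
    obj = root
    for part in path.split("."):
        match = _SEGMENT.fullmatch(part)
        if match is None:
            raise ValueError(f"malformed path segment: {part!r}")
        name, indices = match.groups()
        obj = getattr(obj, name)
        # Apply any trailing [i] indices in order.
        for index in re.findall(r"\[(\d+)\]", indices):
            obj = obj[int(index)]
    return obj


# Stand-in objects; a real Trainer would expose actual modules here.
trainer = SimpleNamespace(
    loss_module=SimpleNamespace(actor_network="actor", critic_network="critic"),
    replay_buffer=SimpleNamespace(transforms=["transform0"]),
)

weight_update_map = {
    "policy": "loss_module.actor_network",
    "replay_buffer.transforms[0]": "loss_module.critic_network",
}

# Map each destination key to the object found at its source path.
resolved = {
    dest: resolve_path(trainer, src) for dest, src in weight_update_map.items()
}
```

In this sketch only the values (source paths) are resolved on the trainer; the keys are kept as opaque destination identifiers, matching the description of weight_update_map above.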
Examples
>>> # Legacy usage with policy_weights_getter
>>> update_weights = UpdateWeights(
...     trainer.collector, T,
...     policy_weights_getter=lambda: TensorDict.from_module(policy)
... )
>>> trainer.register_op("post_steps", update_weights)

>>> # New usage with weight_update_map
>>> update_weights = UpdateWeights(
...     trainer.collector, T,
...     weight_update_map={
...         "policy": "loss_module.actor_network",
...         "replay_buffer.transforms[0]": "loss_module.critic_network"
...     },
...     trainer=trainer
... )
>>> trainer.register_op("post_steps", update_weights)
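To make the role of update_weights_interval (T in the examples above) concrete, the following is a minimal sketch of an interval-gated hook. IntervalSync is a hypothetical stand-in, not the torchrl class, and it assumes the hook is called once per collected batch and syncs on every interval-th call.

```python
class IntervalSync:
    """Hypothetical stand-in: calls ``sync_fn`` once every ``interval`` calls."""

    def __init__(self, sync_fn, interval):
        if interval < 1:
            raise ValueError("interval must be a positive integer")
        self.sync_fn = sync_fn
        self.interval = interval
        self.counter = 0

    def __call__(self):
        # One call corresponds to one collected batch in this sketch.
        self.counter += 1
        if self.counter % self.interval == 0:
            self.sync_fn()


# Record the call counts at which a sync happened.
sync_steps = []
hook = IntervalSync(lambda: sync_steps.append(hook.counter), interval=3)
for _ in range(7):
    hook()
```

With interval=3 and seven calls, the sync callback fires on the third and sixth calls only; the remaining calls just advance the counter.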
- register(trainer: Trainer, name: str = 'update_weights')
Registers the hook in the trainer at a default location.
- Parameters:
trainer (Trainer) – the trainer where the hook must be registered.
name (str) – the name of the hook.
Note
To register the hook at a location other than the default, use register_op().