WeightStrategy

class torchrl.weight_update.WeightStrategy(extract_as: Literal['tensordict', 'state_dict'] = 'tensordict')

Unified strategy for weight transmission.

This strategy handles extraction and application of weights, and supports both the TensorDict and state_dict formats.

Parameters:

extract_as (str) – Format for extracting weights. Can be:

  • "tensordict" (default): extract weights as a TensorDict

  • "state_dict": extract weights as a PyTorch state_dict

The application format is automatically detected based on the type of weights received (dict -> state_dict, TensorDict -> tensordict).
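The type-based detection described above can be illustrated with a minimal sketch. It does not use torchrl: FakeTensorDict and detect_format are hypothetical names invented for this illustration, and the real class may perform its check differently.

```python
class FakeTensorDict:
    """Hypothetical stand-in for TensorDictBase, used only in this sketch."""

    def __init__(self, **entries):
        self.entries = dict(entries)


def detect_format(weights):
    """Return the format name implied by the type of `weights`."""
    if isinstance(weights, dict):
        # A plain dict is treated as a state_dict.
        return "state_dict"
    if isinstance(weights, FakeTensorDict):
        # The TensorDict stand-in is treated as the tensordict format.
        return "tensordict"
    raise TypeError(f"unsupported weights type: {type(weights).__name__}")


format_for_dict = detect_format({"weight": [1.0, 2.0]})
format_for_td = detect_format(FakeTensorDict(weight=[1.0, 2.0]))
```

Because the format is inferred from the received object, the sender's extract_as setting does not need to be communicated separately to the receiver.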

apply_weights(destination: Any, weights: Any, inplace: bool = True) -> None

Apply weights to the destination model.

The format is automatically detected from the type of weights:

  • dict -> state_dict format

  • TensorDictBase -> tensordict format

Parameters:
  • destination – The model to apply weights to. Can be:
      - nn.Module: PyTorch module
      - TensorDictBase: TensorDict
      - dict: state dictionary

  • weights – The weights to apply (dict or TensorDictBase).

  • inplace – Whether to apply weights in place.
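The general difference between in-place and out-of-place application can be sketched as follows. This is an illustration of the concept only, not torchrl's implementation: apply_state_dict is a hypothetical helper, and plain Python lists stand in for tensors so the sketch runs without torch.

```python
def apply_state_dict(destination, weights, inplace=True):
    """Apply `weights` (a plain dict) onto `destination` (a plain dict).

    Lists stand in for tensors in this sketch.
    """
    for name, new_value in weights.items():
        if inplace:
            # Overwrite the existing buffer; other references to it see the update.
            destination[name][:] = new_value
        else:
            # Rebind the key to a fresh copy; the old buffer is left untouched.
            destination[name] = list(new_value)


shared = [0.0, 0.0]
model = {"weight": shared}

# In place: `shared` is the same object as model["weight"], so it is updated too.
apply_state_dict(model, {"weight": [1.0, 2.0]}, inplace=True)

# Out of place: model["weight"] becomes a new list; `shared` keeps its old values.
apply_state_dict(model, {"weight": [3.0, 4.0]}, inplace=False)
```

In-place updates matter when other objects hold references to the same parameter buffers, since those references only observe the new values if the buffers themselves are overwritten.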

extract_weights(source: Any) -> tensordict.base.TensorDictBase | dict | None

Extract weights from the source model in the specified format.

Parameters:

source – The model to extract weights from. Can be:

  • nn.Module: PyTorch module

  • TensorDictBase: TensorDict

  • dict: state dictionary

Returns:

Weights in the format specified by the extract_as constructor argument.
