WeightStrategy

class torchrl.weight_update.WeightStrategy(extract_as: Literal['tensordict', 'state_dict'] = 'tensordict')

Unified strategy for weight transmission.

This strategy handles both the extraction and the application of weights, and supports two formats: TensorDict and PyTorch state_dict.

Parameters:

extract_as (str) – Format for extracting weights. Can be:
  - "tensordict" (default): extract weights as a TensorDict.
  - "state_dict": extract weights as a PyTorch state_dict.

The application format is automatically detected based on the type of weights received (dict -> state_dict, TensorDict -> tensordict).

apply_weights(destination: Any, weights: Any, inplace: bool = True) → None

Apply weights to the destination model.

The format is automatically detected from the type of the weights:
  - dict -> state_dict format
  - TensorDictBase -> tensordict format

Parameters:
  • destination – The model to apply weights to. Can be:
      - nn.Module: a PyTorch module
      - TensorDictBase: a TensorDict
      - dict: a state dictionary

  • weights – The weights to apply (dict or TensorDictBase).

  • inplace – Whether to apply weights in place.

extract_weights(source: Any) → TensorDictBase | dict | None

Extract weights from the source model in the specified format.

Parameters:

source – The model to extract weights from. Can be:
  - nn.Module: a PyTorch module
  - TensorDictBase: a TensorDict
  - dict: a state dictionary

Returns:

Weights in the format specified by the extract_as constructor argument.
