Shortcuts

Source code for torchrl.trainers.algorithms.configs.modules

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from dataclasses import dataclass, field
from functools import partial
from typing import Any

import torch

from omegaconf import MISSING

from torchrl.trainers.algorithms.configs.common import ConfigBase


@dataclass
class ActivationConfig(ConfigBase):
    """Hydra configuration for an activation-function module.

    Instantiates :class:`torch.nn.Tanh` unless ``_target_`` is overridden.

    .. seealso:: :class:`torch.nn.Tanh`
    """

    # Dotted path of the activation class hydra should instantiate.
    _target_: str = "torch.nn.Tanh"
    # When True, instantiation yields a partial instead of a built object.
    _partial_: bool = False

    def __post_init__(self) -> None:
        """No-op hook, kept for interface parity with sibling configs."""

@dataclass
class LayerConfig(ConfigBase):
    """Hydra configuration for a single network layer.

    Instantiates :class:`torch.nn.Linear` unless ``_target_`` is overridden.

    .. seealso:: :class:`torch.nn.Linear`
    """

    # Dotted path of the layer class hydra should instantiate.
    _target_: str = "torch.nn.Linear"
    # When True, instantiation yields a partial instead of a built object.
    _partial_: bool = False

    def __post_init__(self) -> None:
        """No-op hook, kept for interface parity with sibling configs."""

@dataclass
class NetworkConfig(ConfigBase):
    """Base class shared by all network configurations (MLP, ConvNet, ...)."""

    # When True, instantiation yields a partial instead of a built object.
    _partial_: bool = False

    def __post_init__(self) -> None:
        """No-op hook; subclasses override to normalize their fields."""
@dataclass
class MLPConfig(NetworkConfig):
    """Configuration for a multi-layer perceptron.

    Example:
        >>> cfg = MLPConfig(in_features=10, out_features=5, depth=2, num_cells=32)
        >>> net = instantiate(cfg)
        >>> y = net(torch.randn(1, 10))
        >>> assert y.shape == (1, 5)

    .. seealso:: :class:`torchrl.modules.MLP`
    """

    in_features: int | None = None
    out_features: Any = None
    depth: int | None = None
    num_cells: Any = None
    # Defaults to a partial Tanh so MLP can build one instance per layer.
    activation_class: ActivationConfig = field(
        default_factory=partial(
            ActivationConfig, _target_="torch.nn.Tanh", _partial_=True
        )
    )
    activation_kwargs: Any = None
    norm_class: Any = None
    norm_kwargs: Any = None
    dropout: float | None = None
    bias_last_layer: bool = True
    single_bias_last_layer: bool = False
    # Defaults to a partial Linear so MLP can build one instance per layer.
    layer_class: LayerConfig = field(
        default_factory=partial(LayerConfig, _target_="torch.nn.Linear", _partial_=True)
    )
    layer_kwargs: dict | None = None
    activate_last_layer: bool = False
    device: Any = None
    _target_: str = "torchrl.modules.MLP"

    def __post_init__(self):
        """Promote plain string targets (e.g. from YAML) to partial sub-configs."""
        if isinstance(self.activation_class, str):
            self.activation_class = ActivationConfig(
                _target_=self.activation_class, _partial_=True
            )
        if isinstance(self.layer_class, str):
            self.layer_class = LayerConfig(_target_=self.layer_class, _partial_=True)
@dataclass
class NormConfig(ConfigBase):
    """Hydra configuration for a normalization layer.

    Instantiates :class:`torch.nn.BatchNorm1d` unless ``_target_`` is overridden.

    .. seealso:: :class:`torch.nn.BatchNorm1d`
    """

    # Dotted path of the normalization class hydra should instantiate.
    _target_: str = "torch.nn.BatchNorm1d"
    # When True, instantiation yields a partial instead of a built object.
    _partial_: bool = False

    def __post_init__(self) -> None:
        """No-op hook, kept for interface parity with sibling configs."""


@dataclass
class AggregatorConfig(ConfigBase):
    """Hydra configuration for an aggregator layer.

    Instantiates :class:`torchrl.modules.models.utils.SquashDims` unless
    ``_target_`` is overridden.

    .. seealso:: :class:`torchrl.modules.models.utils.SquashDims`
    """

    # Dotted path of the aggregator class hydra should instantiate.
    _target_: str = "torchrl.modules.models.utils.SquashDims"
    # When True, instantiation yields a partial instead of a built object.
    _partial_: bool = False

    def __post_init__(self) -> None:
        """No-op hook, kept for interface parity with sibling configs."""
@dataclass
class ConvNetConfig(NetworkConfig):
    """A class to configure a convolutional network.

    Defaults to :class:`torchrl.modules.ConvNet`.

    Example:
        >>> cfg = ConvNetConfig(in_features=3, depth=2, num_cells=[32, 64], kernel_sizes=[3, 5], strides=[1, 2], paddings=[1, 2])
        >>> net = instantiate(cfg)
        >>> y = net(torch.randn(1, 3, 32, 32))
        >>> assert y.shape == (1, 64)

    .. seealso:: :class:`torchrl.modules.ConvNet`
    """

    in_features: int | None = None
    depth: int | None = None
    num_cells: Any = None
    kernel_sizes: Any = 3
    strides: Any = 1
    paddings: Any = 0
    # Defaults to a partial ELU so ConvNet can build one instance per layer.
    activation_class: ActivationConfig = field(
        default_factory=partial(
            ActivationConfig, _target_="torch.nn.ELU", _partial_=True
        )
    )
    activation_kwargs: Any = None
    norm_class: NormConfig | None = None
    norm_kwargs: Any = None
    bias_last_layer: bool = True
    aggregator_class: AggregatorConfig = field(
        default_factory=partial(
            AggregatorConfig,
            _target_="torchrl.modules.models.utils.SquashDims",
            _partial_=True,
        )
    )
    aggregator_kwargs: dict | None = None
    squeeze_output: bool = False
    device: Any = None
    _target_: str = "torchrl.modules.ConvNet"

    def __post_init__(self):
        """Promote plain string targets (e.g. from YAML) to partial sub-configs.

        Bug fix: the previous guards read ``x is None and isinstance(x, str)``,
        which is always False (``None`` is never a ``str``), so string targets
        were silently left unconverted. The correct guard is the isinstance
        check alone, mirroring :meth:`MLPConfig.__post_init__`.
        """
        if isinstance(self.activation_class, str):
            self.activation_class = ActivationConfig(
                _target_=self.activation_class, _partial_=True
            )
        if isinstance(self.norm_class, str):
            self.norm_class = NormConfig(_target_=self.norm_class, _partial_=True)
        if isinstance(self.aggregator_class, str):
            self.aggregator_class = AggregatorConfig(
                _target_=self.aggregator_class, _partial_=True
            )
@dataclass
class ModelConfig(ConfigBase):
    """Base class for model configurations.

    A model may combine several networks; instantiating one always produces a
    :class:`~tensordict.nn.TensorDictModuleBase` instance.

    .. seealso:: :class:`TanhNormalModelConfig`, :class:`ValueModelConfig`
    """

    # When True, instantiation yields a partial instead of a built object.
    _partial_: bool = False
    # TensorDict keys the model reads from / writes to.
    in_keys: Any = None
    out_keys: Any = None

    def __post_init__(self) -> None:
        """No-op hook; subclasses override to fill in default keys."""
@dataclass
class TensorDictModuleConfig(ModelConfig):
    """Configuration for a :class:`tensordict.nn.TensorDictModule`.

    Example:
        >>> cfg = TensorDictModuleConfig(module=MLPConfig(in_features=10, out_features=10, depth=2, num_cells=32), in_keys=["observation"], out_keys=["action"])
        >>> module = instantiate(cfg)
        >>> assert isinstance(module, TensorDictModule)
        >>> assert module(observation=torch.randn(10, 10)).shape == (10, 10)

    .. seealso:: :class:`tensordict.nn.TensorDictModule`
    """

    # Config of the inner network wrapped by the TensorDictModule (required).
    module: MLPConfig = MISSING
    _target_: str = "tensordict.nn.TensorDictModule"
    _partial_: bool = False

    def __post_init__(self) -> None:
        """Delegate to the parent hook."""
        super().__post_init__()
@dataclass
class TanhNormalModelConfig(ModelConfig):
    """Configuration for a policy with a TanhNormal output distribution.

    Example:
        >>> cfg = TanhNormalModelConfig(network=MLPConfig(in_features=10, out_features=5, depth=2, num_cells=32))
        >>> net = instantiate(cfg)
        >>> y = net(torch.randn(1, 10))
        >>> assert y.shape == (1, 5)

    .. seealso:: :class:`torchrl.modules.TanhNormal`
    """

    # Config of the backbone network producing the distribution parameters.
    network: MLPConfig = MISSING
    # When True, the built model is switched to eval() mode.
    eval_mode: bool = False
    # When True, a NormalParamExtractor splits the output into (loc, scale).
    extract_normal_params: bool = True
    scale_mapping: str = "biased_softplus_1.0"
    scale_lb: float = 1e-4
    param_keys: Any = None
    exploration_type: Any = "RANDOM"
    return_log_prob: bool = False
    _target_: str = (
        "torchrl.trainers.algorithms.configs.modules._make_tanh_normal_model"
    )

    def __post_init__(self):
        """Fill in default key names when none were provided."""
        super().__post_init__()
        if self.in_keys is None:
            self.in_keys = ["observation"]
        if self.param_keys is None:
            self.param_keys = ["loc", "scale"]
        if self.out_keys is None:
            self.out_keys = ["action"]
@dataclass
class ValueModelConfig(ModelConfig):
    """Configuration for a value model.

    Example:
        >>> cfg = ValueModelConfig(network=MLPConfig(in_features=10, out_features=5, depth=2, num_cells=32))
        >>> net = instantiate(cfg)
        >>> y = net(torch.randn(1, 10))
        >>> assert y.shape == (1, 5)

    .. seealso:: :class:`torchrl.modules.ValueOperator`
    """

    _target_: str = "torchrl.trainers.algorithms.configs.modules._make_value_model"
    # Config of the network wrapped by the ValueOperator (required).
    network: NetworkConfig = MISSING

    def __post_init__(self) -> None:
        """Delegate to the parent hook."""
        super().__post_init__()
def _make_tanh_normal_model(*args, **kwargs):
    """Build a TanhNormal policy as a ProbabilisticTensorDictSequential.

    Consumes the config-level keyword arguments; any remaining kwargs are
    forwarded to :class:`~tensordict.nn.ProbabilisticTensorDictModule`.
    """
    from hydra.utils import instantiate
    from tensordict.nn import (
        ProbabilisticTensorDictModule,
        ProbabilisticTensorDictSequential,
        TensorDictModule,
    )
    from torchrl.modules import NormalParamExtractor, TanhNormal

    # Pop our own options off kwargs so the leftovers can be forwarded.
    network = kwargs.pop("network")
    in_keys = list(kwargs.pop("in_keys", ["observation"]))
    param_keys = list(kwargs.pop("param_keys", ["loc", "scale"]))
    out_keys = list(kwargs.pop("out_keys", ["action"]))
    extract_normal_params = kwargs.pop("extract_normal_params", True)
    scale_mapping = kwargs.pop("scale_mapping", "biased_softplus_1.0")
    scale_lb = kwargs.pop("scale_lb", 1e-4)
    return_log_prob = kwargs.pop("return_log_prob", False)
    eval_mode = kwargs.pop("eval_mode", False)
    exploration_type = kwargs.pop("exploration_type", "RANDOM")

    # The network may arrive as a hydra config (has _target_) or as a
    # partial function; materialize it either way.
    if hasattr(network, "_target_"):
        network = instantiate(network)
    elif callable(network) and hasattr(network, "func"):
        network = network()

    if extract_normal_params:
        # Append an extractor that splits the raw output into (loc, scale).
        network = torch.nn.Sequential(
            network,
            NormalParamExtractor(scale_mapping=scale_mapping, scale_lb=scale_lb),
        )

    backbone = TensorDictModule(network, in_keys=in_keys, out_keys=param_keys)
    dist_module = ProbabilisticTensorDictModule(
        in_keys=param_keys,
        out_keys=out_keys,
        distribution_class=TanhNormal,
        return_log_prob=return_log_prob,
        default_interaction_type=exploration_type,
        **kwargs,
    )

    model = ProbabilisticTensorDictSequential(backbone, dist_module)
    if eval_mode:
        model.eval()
    return model


def _make_value_model(*args, **kwargs):
    """Wrap the configured network in a :class:`torchrl.modules.ValueOperator`."""
    from torchrl.modules import ValueOperator

    net = kwargs.pop("network")
    return ValueOperator(net, **kwargs)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources