# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from dataclasses import dataclass, field
from functools import partial
from typing import Any

import torch
from omegaconf import MISSING

from torchrl.trainers.algorithms.configs.common import ConfigBase

@dataclass
class ActivationConfig(ConfigBase):
"""A class to configure an activation function.
Defaults to :class:`torch.nn.Tanh`.
.. seealso:: :class:`torch.nn.Tanh`
"""
_target_: str = "torch.nn.Tanh"
_partial_: bool = False
def __post_init__(self) -> None:
"""Post-initialization hook for activation configurations."""
@dataclass
class LayerConfig(ConfigBase):
"""A class to configure a layer.
Defaults to :class:`torch.nn.Linear`.
.. seealso:: :class:`torch.nn.Linear`
"""
_target_: str = "torch.nn.Linear"
_partial_: bool = False
def __post_init__(self) -> None:
"""Post-initialization hook for layer configurations."""

@dataclass
class NetworkConfig(ConfigBase):
"""Parent class to configure a network."""
_partial_: bool = False
def __post_init__(self) -> None:
"""Post-initialization hook for network configurations."""

@dataclass
class MLPConfig(NetworkConfig):
"""A class to configure a multi-layer perceptron.
Example:
>>> cfg = MLPConfig(in_features=10, out_features=5, depth=2, num_cells=32)
>>> net = instantiate(cfg)
>>> y = net(torch.randn(1, 10))
>>> assert y.shape == (1, 5)
.. seealso:: :class:`torchrl.modules.MLP`
"""
in_features: int | None = None
out_features: Any = None
depth: int | None = None
num_cells: Any = None
activation_class: ActivationConfig = field(
default_factory=partial(
ActivationConfig, _target_="torch.nn.Tanh", _partial_=True
)
)
activation_kwargs: Any = None
norm_class: Any = None
norm_kwargs: Any = None
dropout: float | None = None
bias_last_layer: bool = True
single_bias_last_layer: bool = False
layer_class: LayerConfig = field(
default_factory=partial(LayerConfig, _target_="torch.nn.Linear", _partial_=True)
)
layer_kwargs: dict | None = None
activate_last_layer: bool = False
device: Any = None
_target_: str = "torchrl.modules.MLP"
def __post_init__(self):
if isinstance(self.activation_class, str):
self.activation_class = ActivationConfig(
_target_=self.activation_class, _partial_=True
)
if isinstance(self.layer_class, str):
self.layer_class = LayerConfig(_target_=self.layer_class, _partial_=True)
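
# Sketch of the string-coercion convenience above: dotted paths passed in
# place of config objects are wrapped into partial configs by
# ``__post_init__``, which is handy when the values come from YAML.
#
#     >>> cfg = MLPConfig(in_features=4, out_features=2,
#     ...     activation_class="torch.nn.ReLU")
#     >>> isinstance(cfg.activation_class, ActivationConfig)
#     True
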
@dataclass
class NormConfig(ConfigBase):
"""A class to configure a normalization layer.
Defaults to :class:`torch.nn.BatchNorm1d`.
.. seealso:: :class:`torch.nn.BatchNorm1d`
"""
_target_: str = "torch.nn.BatchNorm1d"
_partial_: bool = False
def __post_init__(self) -> None:
"""Post-initialization hook for normalization configurations."""

@dataclass
class AggregatorConfig(ConfigBase):
"""A class to configure an aggregator layer.
Defaults to :class:`torchrl.modules.models.utils.SquashDims`.
.. seealso:: :class:`torchrl.modules.models.utils.SquashDims`
"""
_target_: str = "torchrl.modules.models.utils.SquashDims"
_partial_: bool = False
def __post_init__(self) -> None:
"""Post-initialization hook for aggregator configurations."""

@dataclass
class ConvNetConfig(NetworkConfig):
"""A class to configure a convolutional network.
Defaults to :class:`torchrl.modules.ConvNet`.
Example:
>>> cfg = ConvNetConfig(in_features=3, depth=2, num_cells=[32, 64], kernel_sizes=[3, 5], strides=[1, 2], paddings=[1, 2])
>>> net = instantiate(cfg)
>>> y = net(torch.randn(1, 3, 32, 32))
>>> assert y.shape == (1, 64)
.. seealso:: :class:`torchrl.modules.ConvNet`
"""
in_features: int | None = None
depth: int | None = None
num_cells: Any = None
kernel_sizes: Any = 3
strides: Any = 1
paddings: Any = 0
activation_class: ActivationConfig = field(
default_factory=partial(
ActivationConfig, _target_="torch.nn.ELU", _partial_=True
)
)
activation_kwargs: Any = None
norm_class: NormConfig | None = None
norm_kwargs: Any = None
bias_last_layer: bool = True
aggregator_class: AggregatorConfig = field(
default_factory=partial(
AggregatorConfig,
_target_="torchrl.modules.models.utils.SquashDims",
_partial_=True,
)
)
aggregator_kwargs: dict | None = None
squeeze_output: bool = False
device: Any = None
_target_: str = "torchrl.modules.ConvNet"
    def __post_init__(self):
        if isinstance(self.activation_class, str):
            self.activation_class = ActivationConfig(
                _target_=self.activation_class, _partial_=True
            )
        if isinstance(self.norm_class, str):
            self.norm_class = NormConfig(_target_=self.norm_class, _partial_=True)
        if isinstance(self.aggregator_class, str):
            self.aggregator_class = AggregatorConfig(
                _target_=self.aggregator_class, _partial_=True
            )
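
# The same coercion applies here. A small sketch exercising it (the values
# are illustrative only):
#
#     >>> cfg = ConvNetConfig(in_features=3, depth=1, num_cells=[16],
#     ...     activation_class="torch.nn.ReLU")
#     >>> isinstance(cfg.activation_class, ActivationConfig)
#     True
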
@dataclass
class ModelConfig(ConfigBase):
"""Parent class to configure a model.
A model can be made of several networks. It is always a :class:`~tensordict.nn.TensorDictModuleBase` instance.
.. seealso:: :class:`TanhNormalModelConfig`, :class:`ValueModelConfig`
"""
_partial_: bool = False
in_keys: Any = None
out_keys: Any = None
def __post_init__(self) -> None:
"""Post-initialization hook for model configurations."""

@dataclass
class TensorDictModuleConfig(ModelConfig):
"""A class to configure a TensorDictModule.
Example:
>>> cfg = TensorDictModuleConfig(module=MLPConfig(in_features=10, out_features=10, depth=2, num_cells=32), in_keys=["observation"], out_keys=["action"])
>>> module = instantiate(cfg)
>>> assert isinstance(module, TensorDictModule)
>>> assert module(observation=torch.randn(10, 10)).shape == (10, 10)
.. seealso:: :class:`tensordict.nn.TensorDictModule`
"""
module: MLPConfig = MISSING
_target_: str = "tensordict.nn.TensorDictModule"
_partial_: bool = False
def __post_init__(self) -> None:
"""Post-initialization hook for TensorDict module configurations."""
super().__post_init__()

@dataclass
class TanhNormalModelConfig(ModelConfig):
"""A class to configure a TanhNormal model.
Example:
>>> cfg = TanhNormalModelConfig(network=MLPConfig(in_features=10, out_features=5, depth=2, num_cells=32))
>>> net = instantiate(cfg)
>>> y = net(torch.randn(1, 10))
>>> assert y.shape == (1, 5)
.. seealso:: :class:`torchrl.modules.TanhNormal`
"""
network: MLPConfig = MISSING
eval_mode: bool = False
extract_normal_params: bool = True
scale_mapping: str = "biased_softplus_1.0"
scale_lb: float = 1e-4
param_keys: Any = None
exploration_type: Any = "RANDOM"
return_log_prob: bool = False
_target_: str = (
"torchrl.trainers.algorithms.configs.modules._make_tanh_normal_model"
)
def __post_init__(self):
"""Post-initialization hook for TanhNormal model configurations."""
super().__post_init__()
if self.in_keys is None:
self.in_keys = ["observation"]
if self.param_keys is None:
self.param_keys = ["loc", "scale"]
if self.out_keys is None:
self.out_keys = ["action"]
@dataclass
class ValueModelConfig(ModelConfig):
"""A class to configure a Value model.
Example:
>>> cfg = ValueModelConfig(network=MLPConfig(in_features=10, out_features=5, depth=2, num_cells=32))
>>> net = instantiate(cfg)
>>> y = net(torch.randn(1, 10))
>>> assert y.shape == (1, 5)
.. seealso:: :class:`torchrl.modules.ValueOperator`
"""
_target_: str = "torchrl.trainers.algorithms.configs.modules._make_value_model"
network: NetworkConfig = MISSING
def __post_init__(self) -> None:
"""Post-initialization hook for value model configurations."""
super().__post_init__()

def _make_tanh_normal_model(*args, **kwargs):
"""Helper function to create a TanhNormal model with ProbabilisticTensorDictSequential."""
from hydra.utils import instantiate
from tensordict.nn import (
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
TensorDictModule,
)
from torchrl.modules import NormalParamExtractor, TanhNormal
# Extract parameters
network = kwargs.pop("network")
in_keys = list(kwargs.pop("in_keys", ["observation"]))
param_keys = list(kwargs.pop("param_keys", ["loc", "scale"]))
out_keys = list(kwargs.pop("out_keys", ["action"]))
extract_normal_params = kwargs.pop("extract_normal_params", True)
scale_mapping = kwargs.pop("scale_mapping", "biased_softplus_1.0")
scale_lb = kwargs.pop("scale_lb", 1e-4)
return_log_prob = kwargs.pop("return_log_prob", False)
eval_mode = kwargs.pop("eval_mode", False)
exploration_type = kwargs.pop("exploration_type", "RANDOM")
# Now instantiate the network
if hasattr(network, "_target_"):
network = instantiate(network)
elif callable(network) and hasattr(network, "func"): # partial function
network = network()
# Create the sequential
if extract_normal_params:
# Add NormalParamExtractor to split the output
network = torch.nn.Sequential(
network,
NormalParamExtractor(scale_mapping=scale_mapping, scale_lb=scale_lb),
)
module = TensorDictModule(network, in_keys=in_keys, out_keys=param_keys)
# Create ProbabilisticTensorDictModule
prob_module = ProbabilisticTensorDictModule(
in_keys=param_keys,
out_keys=out_keys,
distribution_class=TanhNormal,
return_log_prob=return_log_prob,
default_interaction_type=exploration_type,
**kwargs,
)
result = ProbabilisticTensorDictSequential(module, prob_module)
if eval_mode:
result.eval()
return result
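
# Minimal usage sketch for the helper above, bypassing Hydra (the MLP shape
# is illustrative): the network output is split in half by
# ``NormalParamExtractor`` into ``loc`` and ``scale``.
#
#     >>> from tensordict import TensorDict
#     >>> from torchrl.modules import MLP
#     >>> model = _make_tanh_normal_model(network=MLP(in_features=10, out_features=10))
#     >>> td = model(TensorDict({"observation": torch.randn(1, 10)}, batch_size=[1]))
#     >>> td["action"].shape
#     torch.Size([1, 5])
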
def _make_value_model(*args, **kwargs):
"""Helper function to create a ValueOperator with the given network."""
from torchrl.modules import ValueOperator
network = kwargs.pop("network")
return ValueOperator(network, **kwargs)
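
# Minimal usage sketch for the helper above (the network is illustrative):
# extra keyword arguments are forwarded to ``ValueOperator``, which defaults
# to ``in_keys=["observation"]`` and ``out_keys=["state_value"]``.
#
#     >>> from torchrl.modules import MLP
#     >>> value = _make_value_model(network=MLP(in_features=10, out_features=1))
#     >>> value(observation=torch.randn(1, 10)).shape
#     torch.Size([1, 1])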