Source code for torchrl.trainers.algorithms.configs.utils
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
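"""Optimizer configuration dataclasses for TorchRL's trainer configuration system.

Each config mirrors the keyword arguments of a ``torch.optim`` optimizer and carries
Hydra-style ``_target_`` and ``_partial_`` fields, so that instantiating the config
yields a partially applied optimizer constructor to which the model parameters are
passed later (see the usage sketch at the end of this file).
"""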
from __future__ import annotations

from dataclasses import dataclass

from torchrl.trainers.algorithms.configs.common import ConfigBase


@dataclass
class AdamConfig(ConfigBase):
    """Configuration for Adam optimizer."""

    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-4
    weight_decay: float = 0.0
    amsgrad: bool = False
    _target_: str = "torch.optim.Adam"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Post-initialization hook for Adam optimizer configurations."""


@dataclass
class AdamWConfig(ConfigBase):
    """Configuration for AdamW optimizer."""

    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 1e-2
    amsgrad: bool = False
    maximize: bool = False
    foreach: bool | None = None
    capturable: bool = False
    differentiable: bool = False
    fused: bool | None = None
    _target_: str = "torch.optim.AdamW"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Post-initialization hook for AdamW optimizer configurations."""


@dataclass
class AdamaxConfig(ConfigBase):
    """Configuration for Adamax optimizer."""

    lr: float = 2e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.0
    _target_: str = "torch.optim.Adamax"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Post-initialization hook for Adamax optimizer configurations."""


@dataclass
class SGDConfig(ConfigBase):
    """Configuration for SGD optimizer."""

    lr: float = 1e-3
    momentum: float = 0.0
    dampening: float = 0.0
    weight_decay: float = 0.0
    nesterov: bool = False
    maximize: bool = False
    foreach: bool | None = None
    differentiable: bool = False
    _target_: str = "torch.optim.SGD"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Post-initialization hook for SGD optimizer configurations."""


@dataclass
class RMSpropConfig(ConfigBase):
    """Configuration for RMSprop optimizer."""

    lr: float = 1e-2
    alpha: float = 0.99
    eps: float = 1e-8
    weight_decay: float = 0.0
    momentum: float = 0.0
    centered: bool = False
    maximize: bool = False
    foreach: bool | None = None
    differentiable: bool = False
    _target_: str = "torch.optim.RMSprop"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Post-initialization hook for RMSprop optimizer configurations."""


@dataclass
class AdagradConfig(ConfigBase):
    """Configuration for Adagrad optimizer."""

    lr: float = 1e-2
    lr_decay: float = 0.0
    weight_decay: float = 0.0
    initial_accumulator_value: float = 0.0
    eps: float = 1e-10
    maximize: bool = False
    foreach: bool | None = None
    differentiable: bool = False
    _target_: str = "torch.optim.Adagrad"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Post-initialization hook for Adagrad optimizer configurations."""


@dataclass
class AdadeltaConfig(ConfigBase):
    """Configuration for Adadelta optimizer."""

    lr: float = 1.0
    rho: float = 0.9
    eps: float = 1e-6
    weight_decay: float = 0.0
    foreach: bool | None = None
    maximize: bool = False
    differentiable: bool = False
    _target_: str = "torch.optim.Adadelta"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Post-initialization hook for Adadelta optimizer configurations."""


@dataclass
class RpropConfig(ConfigBase):
    """Configuration for Rprop optimizer."""

    lr: float = 1e-2
    etas: tuple[float, float] = (0.5, 1.2)
    step_sizes: tuple[float, float] = (1e-6, 50.0)
    foreach: bool | None = None
    maximize: bool = False
    differentiable: bool = False
    _target_: str = "torch.optim.Rprop"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Post-initialization hook for Rprop optimizer configurations."""


@dataclass
class ASGDConfig(ConfigBase):
    """Configuration for ASGD optimizer."""

    lr: float = 1e-2
    lambd: float = 1e-4
    alpha: float = 0.75
    t0: float = 1e6
    weight_decay: float = 0.0
    foreach: bool | None = None
    maximize: bool = False
    differentiable: bool = False
    _target_: str = "torch.optim.ASGD"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Post-initialization hook for ASGD optimizer configurations."""


@dataclass
class LBFGSConfig(ConfigBase):
    """Configuration for LBFGS optimizer."""

    lr: float = 1.0
    max_iter: int = 20
    max_eval: int | None = None
    tolerance_grad: float = 1e-7
    tolerance_change: float = 1e-9
    history_size: int = 100
    line_search_fn: str | None = None
    _target_: str = "torch.optim.LBFGS"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Post-initialization hook for LBFGS optimizer configurations."""


@dataclass
class RAdamConfig(ConfigBase):
    """Configuration for RAdam optimizer."""

    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.0
    _target_: str = "torch.optim.RAdam"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Post-initialization hook for RAdam optimizer configurations."""


@dataclass
class NAdamConfig(ConfigBase):
    """Configuration for NAdam optimizer."""

    lr: float = 2e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.0
    momentum_decay: float = 4e-3
    foreach: bool | None = None
    _target_: str = "torch.optim.NAdam"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Post-initialization hook for NAdam optimizer configurations."""


@dataclass
class SparseAdamConfig(ConfigBase):
    """Configuration for SparseAdam optimizer."""

    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    _target_: str = "torch.optim.SparseAdam"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Post-initialization hook for SparseAdam optimizer configurations."""


@dataclass
class LionConfig(ConfigBase):
    """Configuration for Lion optimizer."""

    lr: float = 1e-4
    betas: tuple[float, float] = (0.9, 0.99)
    weight_decay: float = 0.0
    # Note: Lion is not shipped in core ``torch.optim`` at the time of writing, so
    # resolving this target may require an external Lion implementation on the path.
    _target_: str = "torch.optim.Lion"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Post-initialization hook for Lion optimizer configurations."""