Shortcuts

Source code for torchrl.trainers.algorithms.configs.utils

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from dataclasses import dataclass

from torchrl.trainers.algorithms.configs.common import ConfigBase


@dataclass
class AdamConfig(ConfigBase):
    """Default hyperparameters for the ``torch.optim.Adam`` target.

    The public fields presumably mirror keyword arguments of the optimizer
    named by ``_target_``; see the PyTorch documentation for their meaning.
    """

    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    # NOTE(review): the other Adam-family configs in this module default to
    # eps=1e-8; confirm that 1e-4 is intentional here.
    eps: float = 1e-4
    weight_decay: float = 0.0
    amsgrad: bool = False

    # Dotted import path of the optimizer class. Together with ``_partial_``
    # this is presumably consumed by a Hydra-style ``instantiate`` call --
    # TODO confirm against ConfigBase and its callers.
    _target_: str = "torch.optim.Adam"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass initialization; performs no additional work."""
@dataclass
class AdamWConfig(ConfigBase):
    """Default hyperparameters for the ``torch.optim.AdamW` target.

    The public fields presumably mirror keyword arguments of the optimizer
    named by ``_target_``; see the PyTorch documentation for their meaning.
    """.replace("`torch.optim.AdamW` ", "`torch.optim.AdamW`` ") if False else None

    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 1e-2
    amsgrad: bool = False
    maximize: bool = False
    foreach: bool | None = None
    capturable: bool = False
    differentiable: bool = False
    fused: bool | None = None

    _target_: str = "torch.optim.AdamW"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass initialization; performs no additional work."""
@dataclass
class AdamaxConfig(ConfigBase):
    """Default hyperparameters for the ``torch.optim.Adamax`` target.

    The public fields presumably mirror keyword arguments of the optimizer
    named by ``_target_``; see the PyTorch documentation for their meaning.
    """

    lr: float = 2e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.0

    # Dotted import path of the optimizer class. Together with ``_partial_``
    # this is presumably consumed by a Hydra-style ``instantiate`` call --
    # TODO confirm against ConfigBase and its callers.
    _target_: str = "torch.optim.Adamax"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass initialization; performs no additional work."""
@dataclass
class SGDConfig(ConfigBase):
    """Default hyperparameters for the ``torch.optim.SGD`` target.

    The public fields presumably mirror keyword arguments of the optimizer
    named by ``_target_``; see the PyTorch documentation for their meaning.
    """

    lr: float = 1e-3
    momentum: float = 0.0
    dampening: float = 0.0
    weight_decay: float = 0.0
    nesterov: bool = False
    maximize: bool = False
    foreach: bool | None = None
    differentiable: bool = False

    # Dotted import path of the optimizer class. Together with ``_partial_``
    # this is presumably consumed by a Hydra-style ``instantiate`` call --
    # TODO confirm against ConfigBase and its callers.
    _target_: str = "torch.optim.SGD"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass initialization; performs no additional work."""
@dataclass
class RMSpropConfig(ConfigBase):
    """Default hyperparameters for the ``torch.optim.RMSprop`` target.

    The public fields presumably mirror keyword arguments of the optimizer
    named by ``_target_``; see the PyTorch documentation for their meaning.
    """

    lr: float = 1e-2
    alpha: float = 0.99
    eps: float = 1e-8
    weight_decay: float = 0.0
    momentum: float = 0.0
    centered: bool = False
    maximize: bool = False
    foreach: bool | None = None
    differentiable: bool = False

    # Dotted import path of the optimizer class. Together with ``_partial_``
    # this is presumably consumed by a Hydra-style ``instantiate`` call --
    # TODO confirm against ConfigBase and its callers.
    _target_: str = "torch.optim.RMSprop"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass initialization; performs no additional work."""
@dataclass
class AdagradConfig(ConfigBase):
    """Default hyperparameters for the ``torch.optim.Adagrad`` target.

    The public fields presumably mirror keyword arguments of the optimizer
    named by ``_target_``; see the PyTorch documentation for their meaning.
    """

    lr: float = 1e-2
    lr_decay: float = 0.0
    weight_decay: float = 0.0
    initial_accumulator_value: float = 0.0
    eps: float = 1e-10
    maximize: bool = False
    foreach: bool | None = None
    differentiable: bool = False

    # Dotted import path of the optimizer class. Together with ``_partial_``
    # this is presumably consumed by a Hydra-style ``instantiate`` call --
    # TODO confirm against ConfigBase and its callers.
    _target_: str = "torch.optim.Adagrad"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass initialization; performs no additional work."""
@dataclass
class AdadeltaConfig(ConfigBase):
    """Default hyperparameters for the ``torch.optim.Adadelta`` target.

    The public fields presumably mirror keyword arguments of the optimizer
    named by ``_target_``; see the PyTorch documentation for their meaning.
    """

    lr: float = 1.0
    rho: float = 0.9
    eps: float = 1e-6
    weight_decay: float = 0.0
    foreach: bool | None = None
    maximize: bool = False
    differentiable: bool = False

    # Dotted import path of the optimizer class. Together with ``_partial_``
    # this is presumably consumed by a Hydra-style ``instantiate`` call --
    # TODO confirm against ConfigBase and its callers.
    _target_: str = "torch.optim.Adadelta"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass initialization; performs no additional work."""
@dataclass
class RpropConfig(ConfigBase):
    """Default hyperparameters for the ``torch.optim.Rprop`` target.

    The public fields presumably mirror keyword arguments of the optimizer
    named by ``_target_``; see the PyTorch documentation for their meaning.
    """

    lr: float = 1e-2
    etas: tuple[float, float] = (0.5, 1.2)
    step_sizes: tuple[float, float] = (1e-6, 50.0)
    foreach: bool | None = None
    maximize: bool = False
    differentiable: bool = False

    # Dotted import path of the optimizer class. Together with ``_partial_``
    # this is presumably consumed by a Hydra-style ``instantiate`` call --
    # TODO confirm against ConfigBase and its callers.
    _target_: str = "torch.optim.Rprop"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass initialization; performs no additional work."""
@dataclass
class ASGDConfig(ConfigBase):
    """Default hyperparameters for the ``torch.optim.ASGD`` target.

    The public fields presumably mirror keyword arguments of the optimizer
    named by ``_target_``; see the PyTorch documentation for their meaning.
    """

    lr: float = 1e-2
    lambd: float = 1e-4
    alpha: float = 0.75
    t0: float = 1e6
    weight_decay: float = 0.0
    foreach: bool | None = None
    maximize: bool = False
    differentiable: bool = False

    # Dotted import path of the optimizer class. Together with ``_partial_``
    # this is presumably consumed by a Hydra-style ``instantiate`` call --
    # TODO confirm against ConfigBase and its callers.
    _target_: str = "torch.optim.ASGD"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass initialization; performs no additional work."""
@dataclass
class LBFGSConfig(ConfigBase):
    """Default hyperparameters for the ``torch.optim.LBFGS`` target.

    The public fields presumably mirror keyword arguments of the optimizer
    named by ``_target_``; see the PyTorch documentation for their meaning.
    """

    lr: float = 1.0
    max_iter: int = 20
    max_eval: int | None = None
    tolerance_grad: float = 1e-7
    tolerance_change: float = 1e-9
    history_size: int = 100
    line_search_fn: str | None = None

    # Dotted import path of the optimizer class. Together with ``_partial_``
    # this is presumably consumed by a Hydra-style ``instantiate`` call --
    # TODO confirm against ConfigBase and its callers.
    _target_: str = "torch.optim.LBFGS"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass initialization; performs no additional work."""
@dataclass
class RAdamConfig(ConfigBase):
    """Default hyperparameters for the ``torch.optim.RAdam`` target.

    The public fields presumably mirror keyword arguments of the optimizer
    named by ``_target_``; see the PyTorch documentation for their meaning.
    """

    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.0

    # Dotted import path of the optimizer class. Together with ``_partial_``
    # this is presumably consumed by a Hydra-style ``instantiate`` call --
    # TODO confirm against ConfigBase and its callers.
    _target_: str = "torch.optim.RAdam"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass initialization; performs no additional work."""
@dataclass
class NAdamConfig(ConfigBase):
    """Default hyperparameters for the ``torch.optim.NAdam`` target.

    The public fields presumably mirror keyword arguments of the optimizer
    named by ``_target_``; see the PyTorch documentation for their meaning.
    """

    lr: float = 2e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.0
    momentum_decay: float = 4e-3
    foreach: bool | None = None

    # Dotted import path of the optimizer class. Together with ``_partial_``
    # this is presumably consumed by a Hydra-style ``instantiate`` call --
    # TODO confirm against ConfigBase and its callers.
    _target_: str = "torch.optim.NAdam"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass initialization; performs no additional work."""
@dataclass
class SparseAdamConfig(ConfigBase):
    """Default hyperparameters for the ``torch.optim.SparseAdam`` target.

    The public fields presumably mirror keyword arguments of the optimizer
    named by ``_target_``; see the PyTorch documentation for their meaning.
    """

    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8

    # Dotted import path of the optimizer class. Together with ``_partial_``
    # this is presumably consumed by a Hydra-style ``instantiate`` call --
    # TODO confirm against ConfigBase and its callers.
    _target_: str = "torch.optim.SparseAdam"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass initialization; performs no additional work."""
@dataclass
class LionConfig(ConfigBase):
    """Default hyperparameters for the ``torch.optim.Lion`` target.

    The public fields presumably mirror keyword arguments of the optimizer
    named by ``_target_``.
    """

    lr: float = 1e-4
    betas: tuple[float, float] = (0.9, 0.99)
    weight_decay: float = 0.0

    # Dotted import path of the optimizer class. Together with ``_partial_``
    # this is presumably consumed by a Hydra-style ``instantiate`` call --
    # TODO confirm against ConfigBase and its callers.
    # NOTE(review): verify that the installed PyTorch actually exposes
    # ``torch.optim.Lion``; if it does not, resolving this target will fail.
    _target_: str = "torch.optim.Lion"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass initialization; performs no additional work."""

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources