Shortcuts

Source code for torchrl.trainers.algorithms.configs.utils

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from dataclasses import dataclass

from torchrl.trainers.algorithms.configs.common import ConfigBase


@dataclass
class AdamConfig(ConfigBase):
    """Default settings for an optimizer built from ``torch.optim.Adam``.

    The non-underscore fields are presumably forwarded as same-named keyword
    arguments to the class named by ``_target_``.

    NOTE(review): ``_target_`` / ``_partial_`` look like Hydra-style
    instantiation keys -- confirm against the code that consumes this config.
    """

    # Optimizer hyper-parameters.
    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    # NOTE(review): 1e-4 differs from the 1e-8 used by the other Adam-family
    # configs in this file -- confirm that this is intentional.
    eps: float = 1e-4
    weight_decay: float = 0.0
    amsgrad: bool = False

    # Instantiation metadata.
    _target_: str = "torch.optim.Adam"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass init; no validation or derived fields here."""
@dataclass
class AdamWConfig(ConfigBase):
    """Default settings for an optimizer built from ``torch.optim.AdamW``.

    The non-underscore fields are presumably forwarded as same-named keyword
    arguments to the class named by ``_target_``.

    NOTE(review): ``_target_`` / ``_partial_`` look like Hydra-style
    instantiation keys -- confirm against the code that consumes this config.
    """

    # Optimizer hyper-parameters.
    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 1e-2
    amsgrad: bool = False

    # Implementation switches; ``None`` leaves the choice unset in this config.
    maximize: bool = False
    foreach: bool | None = None
    capturable: bool = False
    differentiable: bool = False
    fused: bool | None = None

    # Instantiation metadata.
    _target_: str = "torch.optim.AdamW"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass init; no validation or derived fields here."""
@dataclass
class AdamaxConfig(ConfigBase):
    """Default settings for an optimizer built from ``torch.optim.Adamax``.

    The non-underscore fields are presumably forwarded as same-named keyword
    arguments to the class named by ``_target_``.

    NOTE(review): ``_target_`` / ``_partial_`` look like Hydra-style
    instantiation keys -- confirm against the code that consumes this config.
    """

    # Optimizer hyper-parameters.
    lr: float = 2e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.0

    # Instantiation metadata.
    _target_: str = "torch.optim.Adamax"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass init; no validation or derived fields here."""
@dataclass
class SGDConfig(ConfigBase):
    """Default settings for an optimizer built from ``torch.optim.SGD``.

    The non-underscore fields are presumably forwarded as same-named keyword
    arguments to the class named by ``_target_``.

    NOTE(review): ``_target_`` / ``_partial_`` look like Hydra-style
    instantiation keys -- confirm against the code that consumes this config.
    """

    # Optimizer hyper-parameters.
    lr: float = 1e-3
    momentum: float = 0.0
    dampening: float = 0.0
    weight_decay: float = 0.0
    nesterov: bool = False

    # Implementation switches; ``None`` leaves the choice unset in this config.
    maximize: bool = False
    foreach: bool | None = None
    differentiable: bool = False

    # Instantiation metadata.
    _target_: str = "torch.optim.SGD"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass init; no validation or derived fields here."""
@dataclass
class RMSpropConfig(ConfigBase):
    """Default settings for an optimizer built from ``torch.optim.RMSprop``.

    The non-underscore fields are presumably forwarded as same-named keyword
    arguments to the class named by ``_target_``.

    NOTE(review): ``_target_`` / ``_partial_`` look like Hydra-style
    instantiation keys -- confirm against the code that consumes this config.
    """

    # Optimizer hyper-parameters.
    lr: float = 1e-2
    alpha: float = 0.99
    eps: float = 1e-8
    weight_decay: float = 0.0
    momentum: float = 0.0
    centered: bool = False

    # Implementation switches; ``None`` leaves the choice unset in this config.
    maximize: bool = False
    foreach: bool | None = None
    differentiable: bool = False

    # Instantiation metadata.
    _target_: str = "torch.optim.RMSprop"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass init; no validation or derived fields here."""
@dataclass
class AdagradConfig(ConfigBase):
    """Default settings for an optimizer built from ``torch.optim.Adagrad``.

    The non-underscore fields are presumably forwarded as same-named keyword
    arguments to the class named by ``_target_``.

    NOTE(review): ``_target_`` / ``_partial_`` look like Hydra-style
    instantiation keys -- confirm against the code that consumes this config.
    """

    # Optimizer hyper-parameters.
    lr: float = 1e-2
    lr_decay: float = 0.0
    weight_decay: float = 0.0
    initial_accumulator_value: float = 0.0
    eps: float = 1e-10

    # Implementation switches; ``None`` leaves the choice unset in this config.
    maximize: bool = False
    foreach: bool | None = None
    differentiable: bool = False

    # Instantiation metadata.
    _target_: str = "torch.optim.Adagrad"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass init; no validation or derived fields here."""
@dataclass
class AdadeltaConfig(ConfigBase):
    """Default settings for an optimizer built from ``torch.optim.Adadelta``.

    The non-underscore fields are presumably forwarded as same-named keyword
    arguments to the class named by ``_target_``.

    NOTE(review): ``_target_`` / ``_partial_`` look like Hydra-style
    instantiation keys -- confirm against the code that consumes this config.
    """

    # Optimizer hyper-parameters.
    lr: float = 1.0
    rho: float = 0.9
    eps: float = 1e-6
    weight_decay: float = 0.0

    # Implementation switches; ``None`` leaves the choice unset in this config.
    foreach: bool | None = None
    maximize: bool = False
    differentiable: bool = False

    # Instantiation metadata.
    _target_: str = "torch.optim.Adadelta"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass init; no validation or derived fields here."""
@dataclass
class RpropConfig(ConfigBase):
    """Default settings for an optimizer built from ``torch.optim.Rprop``.

    The non-underscore fields are presumably forwarded as same-named keyword
    arguments to the class named by ``_target_``.

    NOTE(review): ``_target_`` / ``_partial_`` look like Hydra-style
    instantiation keys -- confirm against the code that consumes this config.
    """

    # Optimizer hyper-parameters.
    lr: float = 1e-2
    etas: tuple[float, float] = (0.5, 1.2)
    step_sizes: tuple[float, float] = (1e-6, 50.0)

    # Implementation switches; ``None`` leaves the choice unset in this config.
    foreach: bool | None = None
    maximize: bool = False
    differentiable: bool = False

    # Instantiation metadata.
    _target_: str = "torch.optim.Rprop"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass init; no validation or derived fields here."""
@dataclass
class ASGDConfig(ConfigBase):
    """Default settings for an optimizer built from ``torch.optim.ASGD``.

    The non-underscore fields are presumably forwarded as same-named keyword
    arguments to the class named by ``_target_``.

    NOTE(review): ``_target_`` / ``_partial_`` look like Hydra-style
    instantiation keys -- confirm against the code that consumes this config.
    """

    # Optimizer hyper-parameters.
    lr: float = 1e-2
    lambd: float = 1e-4
    alpha: float = 0.75
    t0: float = 1e6
    weight_decay: float = 0.0

    # Implementation switches; ``None`` leaves the choice unset in this config.
    foreach: bool | None = None
    maximize: bool = False
    differentiable: bool = False

    # Instantiation metadata.
    _target_: str = "torch.optim.ASGD"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass init; no validation or derived fields here."""
@dataclass
class LBFGSConfig(ConfigBase):
    """Default settings for an optimizer built from ``torch.optim.LBFGS``.

    The non-underscore fields are presumably forwarded as same-named keyword
    arguments to the class named by ``_target_``.

    NOTE(review): ``_target_`` / ``_partial_`` look like Hydra-style
    instantiation keys -- confirm against the code that consumes this config.
    """

    # Optimizer hyper-parameters; ``None`` leaves the value unset in this config.
    lr: float = 1.0
    max_iter: int = 20
    max_eval: int | None = None
    tolerance_grad: float = 1e-7
    tolerance_change: float = 1e-9
    history_size: int = 100
    line_search_fn: str | None = None

    # Instantiation metadata.
    _target_: str = "torch.optim.LBFGS"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass init; no validation or derived fields here."""
@dataclass
class RAdamConfig(ConfigBase):
    """Default settings for an optimizer built from ``torch.optim.RAdam``.

    The non-underscore fields are presumably forwarded as same-named keyword
    arguments to the class named by ``_target_``.

    NOTE(review): ``_target_`` / ``_partial_`` look like Hydra-style
    instantiation keys -- confirm against the code that consumes this config.
    """

    # Optimizer hyper-parameters.
    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.0

    # Instantiation metadata.
    _target_: str = "torch.optim.RAdam"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass init; no validation or derived fields here."""
@dataclass
class NAdamConfig(ConfigBase):
    """Default settings for an optimizer built from ``torch.optim.NAdam``.

    The non-underscore fields are presumably forwarded as same-named keyword
    arguments to the class named by ``_target_``.

    NOTE(review): ``_target_`` / ``_partial_`` look like Hydra-style
    instantiation keys -- confirm against the code that consumes this config.
    """

    # Optimizer hyper-parameters.
    lr: float = 2e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.0
    momentum_decay: float = 4e-3

    # Implementation switch; ``None`` leaves the choice unset in this config.
    foreach: bool | None = None

    # Instantiation metadata.
    _target_: str = "torch.optim.NAdam"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass init; no validation or derived fields here."""
@dataclass
class SparseAdamConfig(ConfigBase):
    """Default settings for an optimizer built from ``torch.optim.SparseAdam``.

    The non-underscore fields are presumably forwarded as same-named keyword
    arguments to the class named by ``_target_``.

    NOTE(review): ``_target_`` / ``_partial_`` look like Hydra-style
    instantiation keys -- confirm against the code that consumes this config.
    """

    # Optimizer hyper-parameters.
    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8

    # Instantiation metadata.
    _target_: str = "torch.optim.SparseAdam"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass init; no validation or derived fields here."""
@dataclass
class LionConfig(ConfigBase):
    """Default settings for a Lion optimizer.

    The non-underscore fields are presumably forwarded as same-named keyword
    arguments to the class named by ``_target_``.

    NOTE(review): ``_target_`` / ``_partial_`` look like Hydra-style
    instantiation keys -- confirm against the code that consumes this config.
    """

    # Optimizer hyper-parameters.
    lr: float = 1e-4
    betas: tuple[float, float] = (0.9, 0.99)
    weight_decay: float = 0.0

    # Instantiation metadata.
    # NOTE(review): released PyTorch versions do not appear to expose a
    # ``Lion`` class under ``torch.optim`` -- confirm that this dotted path
    # resolves in the supported environment before relying on this config.
    _target_: str = "torch.optim.Lion"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """Run after dataclass init; no validation or derived fields here."""

Docs

Lorem ipsum dolor sit amet, consectetur

View Docs

Tutorials

Lorem ipsum dolor sit amet, consectetur

View Tutorials

Resources

Lorem ipsum dolor sit amet, consectetur

View Resources