Source code for torchrl.trainers.algorithms.configs.objectives
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from torchrl.objectives.ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss
from torchrl.trainers.algorithms.configs.common import ConfigBase
[docs]@dataclass
class LossConfig(ConfigBase):
    """A class to configure a loss.
    Args:
        loss_type: The type of loss to use.
    """
    _partial_: bool = False
    def __post_init__(self) -> None:
        """Post-initialization hook for loss configurations."""
[docs]@dataclass
class PPOLossConfig(LossConfig):
    """A class to configure a PPO loss.
    Args:
        loss_type: The type of loss to use.
    """
    actor_network: Any = None
    critic_network: Any = None
    loss_type: str = "clip"
    entropy_bonus: bool = True
    samples_mc_entropy: int = 1
    entropy_coeff: float | None = None
    log_explained_variance: bool = True
    critic_coeff: float = 0.25
    loss_critic_type: str = "smooth_l1"
    normalize_advantage: bool = True
    normalize_advantage_exclude_dims: tuple = ()
    gamma: float | None = None
    separate_losses: bool = False
    advantage_key: str | None = None
    value_target_key: str | None = None
    value_key: str | None = None
    functional: bool = True
    actor: Any = None
    critic: Any = None
    reduction: str | None = None
    clip_value: float | None = None
    device: Any = None
    _target_: str = "torchrl.trainers.algorithms.configs.objectives._make_ppo_loss"
    def __post_init__(self) -> None:
        """Post-initialization hook for PPO loss configurations."""
        super().__post_init__()
def _make_ppo_loss(*args, **kwargs) -> PPOLoss:
    loss_type = kwargs.pop("loss_type", "clip")
    if loss_type == "clip":
        return ClipPPOLoss(*args, **kwargs)
    elif loss_type == "kl":
        return KLPENPPOLoss(*args, **kwargs)
    elif loss_type == "ppo":
        return PPOLoss(*args, **kwargs)
    else:
        raise ValueError(f"Invalid loss type: {loss_type}")