Source code for torchtune.rlhf.loss.ppo

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple

import torch
import torch.nn as nn
from torchtune import rlhf


class PPOLoss(nn.Module):
    """
    Proximal Policy Optimization (PPO) Loss module.
    This implementation uses the following references:

    https://arxiv.org/abs/1707.06347 eqn. 7

    https://github.com/vwxyzjn/lm-human-preference-details/blob/ccc19538e817e98a60d3253242ac15e2a562cb49/lm_human_preference_details/train_policy_accelerate.py#L719

    https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/model.py#L68-L75


    Args:
        epsilon (float): clipping range for PPO update.
        value_clip_range (float): clipping range for value function update.
        value_coeff (float): coefficient for the value function loss contribution.
    """

    def __init__(
        self,
        epsilon: float = 0.1,
        value_clip_range: float = 0.2,
        value_coeff: float = 0.1,
    ):
        super().__init__()
        self.epsilon = epsilon
        self.value_clip_range = value_clip_range
        self.value_coeff = value_coeff

    def forward(
        self,
        pi_old_logprobs: torch.Tensor,
        pi_logprobs: torch.Tensor,
        advantages: torch.Tensor,
        phi_old_values: torch.Tensor,
        phi_values: torch.Tensor,
        returns: torch.Tensor,
        padding_masks: Optional[torch.Tensor] = None,
        value_padding_masks: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass of the PPO loss module.

        Args:
            pi_old_logprobs (torch.Tensor): Log probabilities of the old policy.
            pi_logprobs (torch.Tensor): Log probabilities of the current policy.
            advantages (torch.Tensor): Advantage values.
            phi_old_values (torch.Tensor): Value predictions of the old value function.
            phi_values (torch.Tensor): Value predictions of the current value function.
            returns (torch.Tensor): Return values.
            padding_masks (Optional[torch.Tensor]): Padding token masks of the same shape as ``pi_logprobs``,
                where True indicates the corresponding loss values should participate in policy loss calculation.
            value_padding_masks (Optional[torch.Tensor]): Padding token masks of the same shape as ``pi_logprobs``,
                where True indicates the corresponding loss values should participate in value loss calculation.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: A tuple of five tensors:
                - loss: The total PPO loss.
                - policy_loss: The policy function loss.
                - value_loss: The value function loss.
                - ratios: The ratio between the current and old policy probabilities.
                - clipfrac: The fraction of ratios that were clipped.
        """
        ratios = torch.exp(pi_logprobs - pi_old_logprobs)
        clipped_ratios = torch.clamp(ratios, 1.0 - self.epsilon, 1.0 + self.epsilon)

        policy_losses_clipped = -advantages * clipped_ratios
        policy_losses_unclipped = -advantages * ratios

        clipfrac = (policy_losses_clipped > policy_losses_unclipped).float()
        clipfrac = (
            clipfrac.mean()
            if padding_masks is None
            else rlhf.masked_mean(clipfrac, padding_masks)
        )

        policy_loss = torch.maximum(policy_losses_clipped, policy_losses_unclipped)
        policy_loss = (
            policy_loss.mean()
            if padding_masks is None
            else rlhf.masked_mean(policy_loss, padding_masks)
        )

        values_clipped = torch.clamp(
            phi_values,
            phi_old_values - self.value_clip_range,
            phi_old_values + self.value_clip_range,
        )
        value_loss = torch.maximum(
            (phi_values - returns) ** 2, (values_clipped - returns) ** 2
        )
        value_loss = (
            0.5 * value_loss.mean()
            if value_padding_masks is None
            else 0.5 * rlhf.masked_mean(value_loss, value_padding_masks)
        )

        loss = policy_loss + (value_loss * self.value_coeff)
        return (
            loss,
            policy_loss.detach(),
            value_loss.detach(),
            ratios.mean().detach(),
            clipfrac.detach(),
        )
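
For reference, a minimal usage sketch of this loss follows. The import path mirrors the module path above; the tensor shapes, random inputs, and mask layout are illustrative assumptions rather than values from a torchtune recipe. In practice, the log probabilities, values, advantages, and returns come from policy and value model forward passes and return estimation during PPO optimization.

import torch
from torchtune.rlhf.loss import PPOLoss

torch.manual_seed(0)
batch_size, seq_len = 2, 8  # assumed illustrative shapes, not recipe values

loss_fn = PPOLoss(epsilon=0.1, value_clip_range=0.2, value_coeff=0.1)

# Stand-ins for quantities produced during a PPO step; requires_grad is set
# on the "current" policy/value tensors so the total loss can backpropagate.
pi_old_logprobs = torch.randn(batch_size, seq_len)
pi_logprobs = (pi_old_logprobs + 0.01 * torch.randn(batch_size, seq_len)).requires_grad_()
advantages = torch.randn(batch_size, seq_len)
phi_old_values = torch.randn(batch_size, seq_len)
phi_values = (phi_old_values + 0.01 * torch.randn(batch_size, seq_len)).requires_grad_()
returns = torch.randn(batch_size, seq_len)

# True marks tokens that participate in the loss; here the final token of
# each sequence is treated as padding.
padding_masks = torch.ones(batch_size, seq_len, dtype=torch.bool)
padding_masks[:, -1] = False

loss, policy_loss, value_loss, ratios, clipfrac = loss_fn(
    pi_old_logprobs,
    pi_logprobs,
    advantages,
    phi_old_values,
    phi_values,
    returns,
    padding_masks=padding_masks,
    value_padding_masks=padding_masks,
)
loss.backward()  # only ``loss`` carries gradients; the other four outputs are detached

Note that only the first returned tensor participates in optimization: the policy loss, value loss, mean ratio, and clip fraction are detached and intended for logging.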
