Source code for torchrl.data.llm.topk
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from collections import defaultdict, deque
from typing import Any
import torch
from tensordict import NestedKey, TensorDictBase
from torchrl._utils import logger as torchrl_logger
from torchrl.envs.transforms import Transform


class TopKRewardSelector(Transform):
"""A replay-buffer transform that selects the top-k rewards for each prompt.
Args:
total_dialog_turns (int): Number of dialog turns to keep in memory for the top-k selection.
topk_size (int): Number of top-k rewards to select. Must be smaller than or equal to total_dialog_turns.
prompt_key (NestedKey): Key to the prompt in the tensordict. Defaults to "text".
rewards_key (NestedKey): Key to the rewards in the tensordict. Defaults to ("next", "reward").
done_key (NestedKey): Key to the done state in the tensordict. Defaults to ("next", "done").
        verbose (bool): Whether to print verbose information. Defaults to `True`.

    Example:
>>> from torchrl.data import ReplayBuffer, LazyStackStorage, SamplerWithoutReplacement
>>> from tensordict import TensorDict, lazy_stack
>>> import torch
>>> from torchrl.data.llm.topk import TopKRewardSelector
>>> # Create a replay buffer with 50 items, a sampler that samples without replacement, and a batch size of 5
>>> rb = ReplayBuffer(
... storage=LazyStackStorage(50),
        ...     sampler=SamplerWithoutReplacement(),
... batch_size=5,
... )
        >>> # Create a lazy-stacked tensordict of 50 dialog turns: 10 distinct prompts with 5 completions each
>>> td = lazy_stack(
... [
... TensorDict(
... {
... ("next", "done"): torch.full((1, 1), True),
        ...                 # Reward of value i, one entry per token (i + 5 tokens)
... ("next", "reward"): torch.full((i + 5, 1), i),
        ...                 # 5 dialog turns per prompt (10 distinct prompts)
... "text": f"Prompt {i // 5}",
... }
... )
... for i in range(50)
... ]
... )
>>> # Create a top-k reward selector with 5 dialog turns and a top-k size of 3
>>> topk = TopKRewardSelector(total_dialog_turns=5, topk_size=3)
>>> rb.append_transform(topk)
>>> for _td in td.chunk(25):
... rb.extend(_td)
        >>> # Only the top-3 of each of the 10 groups of 5 were written (30 items in total)
>>> assert rb.write_count == 30
>>> assert len(rb) == 30
>>> r3 = rb[:3].get(("next", "reward"), as_padded_tensor=True).squeeze()
>>> # 0 and 1 are missing because they're not part of the top-k
>>> assert (
... r3 == torch.tensor(
... [
... [4, 4, 4, 4, 4, 4, 4, 4, 4],
... [3, 3, 3, 3, 3, 3, 3, 3, 0],
... [2, 2, 2, 2, 2, 2, 2, 0, 0],
... ]
... )
... ).all()
"""

    def __init__(
self,
total_dialog_turns: int,
topk_size: int,
prompt_key: NestedKey = "text",
rewards_key: NestedKey = ("next", "reward"),
done_key: NestedKey = ("next", "done"),
verbose: bool = True,
):
super().__init__()
self.in_keys = [prompt_key, rewards_key, done_key]
self.prompt_key = prompt_key
self.rewards_key = rewards_key
self.done_key = done_key
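        # One bounded FIFO queue of completed trajectories per prompt; top-k selection
        # in `_inv_call` triggers once a queue holds `total_dialog_turns` entries.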
self.queues = defaultdict(lambda: deque(maxlen=total_dialog_turns))
self.total_dialog_turns = total_dialog_turns
self.topk_size = topk_size
if topk_size > total_dialog_turns:
raise ValueError(
f"topk_size must be smaller than or equal to total_dialog_turns, got {topk_size=} and {total_dialog_turns=}"
)
self.verbose = verbose

    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
# Tensordict can be any number of dims, but it must contain entire trajectories
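        # 1D inputs may hold one or several trajectories (delimited by done flags);
        # higher-dimensional inputs are flattened so that only the trailing time dim remains.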
if tensordict.ndim == 1:
# Check how many done states we have
num_done = tensordict[self.done_key].sum()
if num_done > 1:
done_idx = tensordict[self.done_key].nonzero(as_tuple=True)[0] + 1
splits = torch.cat([done_idx.new_zeros((1,)), done_idx], dim=0).diff()
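                # e.g., done at indices 2 and 5 of a 6-step tensordict gives
                # done_idx = [3, 6] and splits = [3, 3]: two trajectories of length 3.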
                tensordicts = tensordict.split(splits.tolist())
tensordicts = [self._inv_call(td) for td in tensordicts]
tensordicts = [td for td in tensordicts if td is not None]
return torch.cat(tensordicts) if tensordicts else None
# Then we have a single trajectory
if not tensordict[-1][self.done_key].all():
raise RuntimeError("Expected the trajectory to be done.")
prompt = tensordict[0][self.prompt_key]
if not isinstance(prompt, str):
raise TypeError(f"Expected a string as prompt, got {type(prompt)=}")
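            # Buffer the trajectory under its prompt; top-k selection only runs once
            # `total_dialog_turns` completions of the same prompt have been collected.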
self.queues[prompt].append(tensordict)
if len(self.queues[prompt]) == self.total_dialog_turns:
if self.verbose:
torchrl_logger.info(f"Getting top-k rewards for {prompt=}")
# Cat is the most robust way to combine the trajs
tds = torch.cat(list(self.queues[prompt]), -1)
# Collect rewards
reward = tds.get(self.rewards_key, as_nested_tensor=True)
reward = self._aggregate_rewards(reward)
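                # The ragged per-token rewards are reduced to one scalar per dialog turn,
                # so `reward` now has shape (total_dialog_turns,).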
# Check if all rewards are equal
if (reward == reward[0]).all():
# If all rewards are equal, we can't select top-k
if self.verbose:
torchrl_logger.warning(
f"All rewards are equal ({reward.unique()=})"
)
return
# Filter out rewards below median
median_reward = reward.median(dim=-1, keepdim=True)[0]
mask = reward > median_reward
filtered_reward = reward[mask]
filtered_indices = mask.nonzero(as_tuple=True)[0]
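                # Fewer than `topk_size` rewards may survive the strict (>) median filter,
                # hence the min() when choosing k below.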
# Get top-k from filtered rewards
topk_reward = filtered_reward.topk(
k=min(self.topk_size, len(filtered_indices)), dim=-1
)
if not topk_reward.indices.numel():
if self.verbose:
torchrl_logger.warning(
f"No top-{self.topk_size} rewards found ({reward=})"
)
return
# Map back to original indices
selected_indices = filtered_indices[topk_reward.indices]
tds = tds[selected_indices]
if self.verbose:
torchrl_logger.info(
f"Selected top-{self.topk_size} rewards, with reward {topk_reward.values=}"
)
return tds
return
elif tensordict.ndim > 2:
# keep the time dim at the end
tensordict = tensordict.flatten(0, -2)
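            # After flattening, the tensordict is 2D: (num_trajectories, time).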
        # Iterate over the trajectories (one per leading batch index, time dim last)
        trajs = tensordict.unbind(0)
result = []
for traj in trajs:
td_out = self._inv_call(traj)
if td_out is None:
continue
result.append(td_out)
if result:
return torch.cat(result, -1)
return

    def _aggregate_rewards(self, reward: torch.Tensor) -> torch.Tensor:
        """Aggregate the rewards across the dialog turns.

        `reward` is expected to be a nested tensor.

        The default implementation is to take the mean of the rewards across the dialog turns.
        """
# reward = reward.to_padded_tensor(padding=0.0)
if reward.ndim < 2 or reward.ndim > 3:
raise ValueError(
f"Expected reward to be a 2D or 3D tensor, got {reward.ndim}D tensor"
)
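        # For a 3D nested reward of shape (total_dialog_turns, num_tokens, 1), averaging over
        # the token dimension and squeezing the trailing singleton yields one scalar per turn.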
return reward.mean(dim=-2).squeeze(-1)