Shortcuts

Source code for torchrl.envs.llm.datasets.countdown

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import random
from collections.abc import Callable
from typing import Any, Literal, TYPE_CHECKING

import torch
from tensordict import TensorDict
from torch.utils.data import DataLoader, IterableDataset
from torchrl.envs import StepCounter
from torchrl.envs.llm.chat import DatasetChatEnv
from torchrl.envs.llm.reward.countdown import CountdownRewardParser

if TYPE_CHECKING:
    import transformers


class _CountdownProblemGenerator(IterableDataset):
    """Infinite procedural generator for Countdown problems.

    Each problem picks ``num_count`` numbers from [1, max_number] and
    generates a target that is reachable from those numbers using
    ``+``, ``-``, ``*``, ``/``.
    """

    def __init__(
        self,
        num_count: int = 4,
        max_number: int = 100,
        max_target: int = 1000,
        seed: int | None = None,
    ):
        self.num_count = num_count
        self.max_number = max_number
        self.max_target = max_target
        self.rng = random.Random(seed)

    def __iter__(self):
        return self

    def __next__(self) -> dict[str, Any]:
        numbers = [self.rng.randint(1, self.max_number) for _ in range(self.num_count)]
        target = self._make_target(numbers)
        query = (
            f"Using the numbers {numbers}, create an arithmetic expression that "
            f"equals {target}. You may use each number at most once. "
            f"Only use +, -, *, / and parentheses."
        )
        answer = f"target={target}, numbers={','.join(str(n) for n in numbers)}"
        return {"query": query, "answer": answer}

    def _make_target(self, numbers: list[int]) -> int:
        """Generate a reachable target by randomly combining numbers."""
        ops = [
            lambda a, b: a + b,
            lambda a, b: a - b,
            lambda a, b: a * b,
        ]
        pool = list(numbers)
        self.rng.shuffle(pool)
        result = pool[0]
        for n in pool[1:]:
            op = self.rng.choice(ops)
            result = op(result, n)
        result = abs(result)
        if result == 0:
            result = sum(numbers)
        if result > self.max_target:
            result = sum(numbers)
        return result


def _collate_fn(batch):
    """Stack a list of sample dicts into a single batched TensorDict."""
    tds = [TensorDict.from_dict(sample) for sample in batch]
    return torch.stack(tds)


class CountdownEnv(DatasetChatEnv):
    """Countdown numbers-game environment for LLM post-training.

    Given a set of source numbers and a target, the model must construct an
    arithmetic expression that evaluates to the target using each source
    number at most once.

    Problems are generated procedurally (no external dataset required),
    making this environment ideal for quick experimentation and debugging
    of RL training loops.

    Keyword Args:
        num_count (int): How many source numbers per problem. Defaults to ``4``.
        max_number (int): Maximum value for each source number. Defaults to ``100``.
        max_target (int): Ceiling for the generated target. Defaults to ``1000``.
        shuffle (bool): Ignored (procedural generation is always random).
        num_envs (int): Number of parallel environments. Defaults to ``1``.
        repeats (int | None): Repeats per sample for MC estimation.
        batch_size_dl (int): Dataloader batch size. Defaults to ``1``.
        seed (int | None): Random seed for reproducibility.
        group_repeats (bool): Group repeated samples. Defaults to ``False``.
        tokenizer: Tokenizer for text processing.
        device: Device for computation.
        template_kwargs: Extra kwargs for ``apply_chat_template``.
        apply_template (bool): Apply chat template. Defaults to ``False``.
        compute_reward (bool): Compute rewards. Defaults to ``True``.
        collate_fn: Custom collate function.
        max_steps (int): Max steps per episode. Defaults to ``1``.
        input_mode: ``"history"``, ``"text"``, or ``"tokens"``.

    Examples:
        >>> import transformers
        >>> from torchrl.envs.llm.datasets.countdown import CountdownEnv
        >>> tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
        >>> env = CountdownEnv(tokenizer=tokenizer, apply_template=True, seed=42)
        >>> r = env.reset()
        >>> assert "history" in r
    """

    SYSTEM_PROMPT = (
        "A conversation between User and Assistant. The user gives a set of "
        "numbers and a target. The Assistant must find an arithmetic expression "
        "using each given number at most once that equals the target.\n"
        "The reasoning process and answer are enclosed within <think></think> "
        "and <answer></answer> tags, respectively.\n"
        "The answer should contain ONLY the arithmetic expression (e.g. "
        "(25 + 3) * 4)."
    )

    def __init__(
        self,
        *,
        num_count: int = 4,
        max_number: int = 100,
        max_target: int = 1000,
        shuffle: bool = True,
        num_envs: int = 1,
        repeats: int | None = None,
        batch_size_dl: int = 1,
        seed: int | None = None,
        group_repeats: bool = False,
        tokenizer: transformers.AutoTokenizer | None = None,  # noqa
        device: torch.device | None = None,
        template_kwargs: dict[str, Any] | None = None,
        apply_template: bool | None = False,
        compute_reward: bool = True,
        collate_fn: Callable | None = None,
        max_steps: int = 1,
        input_mode: Literal["history", "text", "tokens"] = "history",
    ):
        # Fall back to the module-level collate when the caller supplies none.
        collate_fn = _collate_fn if collate_fn is None else collate_fn
        self._num_count = num_count
        self._max_number = max_number
        self._max_target = max_target
        self._seed = seed
        env_batch_size = (num_envs,)
        # Infinite procedural stream of problems, batched for the env.
        loader = DataLoader(  # noqa: TOR401
            _CountdownProblemGenerator(
                num_count=num_count,
                max_number=max_number,
                max_target=max_target,
                seed=seed,
            ),
            batch_size=batch_size_dl,
            collate_fn=collate_fn,
        )
        self._from_dataloader(
            self,
            dataloader=loader,
            repeats=repeats,
            device=device,
            group_repeats=group_repeats,
            batch_size=env_batch_size,
            tokenizer=tokenizer,
            template_kwargs=template_kwargs,
            input_mode=input_mode,
        )
        # Optional transforms: episode-length bookkeeping and reward parsing.
        if max_steps:
            self.append_transform(StepCounter(max_steps=max_steps))
        if compute_reward:
            self.append_transform(CountdownRewardParser(tokenizer=tokenizer))

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources