Shortcuts

Source code for torchrl.envs.llm.datasets.math

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from collections.abc import Callable
from typing import Any, Literal, TYPE_CHECKING

import torch
from tensordict import TensorDict
from torchrl.envs import StepCounter
from torchrl.envs.llm.chat import DatasetChatEnv
from torchrl.envs.llm.reward.math import MATHRewardParser

if TYPE_CHECKING:
    import transformers


def _collate_fn(batch):
    """Collate a list of dataset samples into a single stacked TensorDict.

    Each sample is converted with ``TensorDict.from_dict`` and the results are
    stacked with ``torch.stack``. The ``"problem"`` and ``"solution"`` entries
    are then renamed in place to ``"query"`` and ``"answer"`` respectively.

    Args:
        batch: Iterable of per-sample mappings.
            NOTE(review): presumably each sample carries ``"problem"`` and
            ``"solution"`` keys as in the MATH dataset -- confirm.

    Returns:
        The stacked TensorDict with the renamed keys.
    """
    samples = [TensorDict.from_dict(sample) for sample in batch]
    stacked = torch.stack(samples)
    # Rename the dataset's column names to the query/answer keys.
    for old_key, new_key in (("problem", "query"), ("solution", "answer")):
        stacked.rename_key_(old_key, new_key)
    return stacked


class MATHEnv(DatasetChatEnv):
    r"""MATH (competition mathematics) dataset environment.

    Uses the ``DigitalLearningGmbH/MATH-lighteval`` dataset on Hugging Face
    (a drop-in replacement for the original ``hendrycks/competition_math``).
    Answers are in LaTeX ``\boxed{}`` format.

    When ``math-verify`` is installed the reward parser uses symbolic
    equivalence checking; otherwise it falls back to normalised string
    comparison.

    Keyword Args:
        dataset (str, optional): HuggingFace dataset name.
            Defaults to ``"DigitalLearningGmbH/MATH-lighteval"``.
        shuffle (bool, optional): Shuffle the dataset. Defaults to ``True``.
        num_envs (int, optional): Number of parallel envs. Defaults to ``1``.
        repeats (int | None, optional): Repeats per sample for MC estimation.
        batch_size_dl (int, optional): Dataloader batch size. Defaults to ``1``.
        seed (int | None, optional): Random seed.
        group_repeats (bool, optional): Group repeated samples.
            Defaults to ``False``.
        tokenizer: Tokenizer for text processing. Also forwarded to the
            reward parser when ``compute_reward`` is enabled.
        device: Device for computation.
        template_kwargs: Extra kwargs for ``apply_chat_template``.
        apply_template (bool): Apply chat template. Defaults to ``False``.
        compute_reward (bool): Compute rewards. When true, a
            ``MATHRewardParser`` transform is appended. Defaults to ``True``.
        collate_fn: Custom collate function. When ``None``, the module-level
            ``_collate_fn`` is used.
        max_steps (int): Max steps per episode. Defaults to ``1``. A falsy
            value (e.g. ``0``) skips the ``StepCounter`` transform.
        input_mode: ``"history"``, ``"text"``, or ``"tokens"``.
            Defaults to ``"history"``.
        ray_backend (bool): Use Ray backend for data loading.
            Defaults to ``False``.
        dataloader_actor_name (str): Ray actor name for data loading. When
            ``ray_backend`` is true and no name is given,
            ``"math_dataloader"`` is used.

    Examples:
        >>> import transformers
        >>> from torchrl.envs.llm.datasets.math import MATHEnv
        >>> tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
        >>> env = MATHEnv(tokenizer=tokenizer, apply_template=True)
        >>> r = env.reset()
        >>> assert "history" in r

    """

    # NOTE(review): SYSTEM_PROMPT is not passed explicitly to the parent
    # constructor below; presumably DatasetChatEnv reads this class attribute
    # -- confirm against the base class.
    SYSTEM_PROMPT = (
        "A conversation between User and Assistant. The user asks a math problem, "
        "and the Assistant solves it.\n"
        "The assistant first thinks about the reasoning process in the mind and "
        "then provides the user with the answer.\n"
        "The reasoning process and answer are enclosed within <think></think> and "
        "<answer></answer> tags, respectively, i.e.,\n"
        "<think>reasoning process here</think> <answer>answer here</answer>.\n"
        "The answer should be a mathematical expression (use LaTeX if needed)."
    )

    def __init__(
        self,
        *,
        dataset: str = "DigitalLearningGmbH/MATH-lighteval",
        shuffle: bool = True,
        num_envs: int = 1,
        repeats: int | None = None,
        batch_size_dl: int = 1,
        seed: int | None = None,
        group_repeats: bool = False,
        tokenizer: transformers.AutoTokenizer | None = None,  # noqa
        device: torch.device | None = None,
        template_kwargs: dict[str, Any] | None = None,
        apply_template: bool | None = False,
        compute_reward: bool = True,
        collate_fn: Callable | None = None,
        max_steps: int = 1,
        input_mode: Literal["history", "text", "tokens"] = "history",
        ray_backend: bool = False,
        dataloader_actor_name: str | None = None,
    ):
        # Give the Ray dataloader actor a default name when none is provided.
        if ray_backend and dataloader_actor_name is None:
            dataloader_actor_name = "math_dataloader"
        # Fall back to the module-level collate function, which renames the
        # "problem"/"solution" entries to "query"/"answer".
        if collate_fn is None:
            collate_fn = _collate_fn
        super().__init__(
            dataset=dataset,
            shuffle=shuffle,
            num_envs=num_envs,
            repeats=repeats,
            batch_size_dl=batch_size_dl,
            seed=seed,
            group_repeats=group_repeats,
            tokenizer=tokenizer,
            device=device,
            template_kwargs=template_kwargs,
            apply_template=apply_template,
            collate_fn=collate_fn,
            input_mode=input_mode,
            ray_backend=ray_backend,
            dataloader_actor_name=dataloader_actor_name,
        )
        # Transforms are appended in this order: step counter first, then the
        # reward parser.
        if max_steps:
            self.append_transform(StepCounter(max_steps=max_steps))
        if compute_reward:
            self.append_transform(MATHRewardParser(tokenizer=tokenizer))

Docs

Lorem ipsum dolor sit amet, consectetur

View Docs

Tutorials

Lorem ipsum dolor sit amet, consectetur

View Tutorials

Resources

Lorem ipsum dolor sit amet, consectetur

View Resources