Source code for torchtune.data._messages
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from pathlib import Path
from typing import Any, Dict, List, Literal, Mapping, Optional, Union
from torchtune.data._utils import format_content_with_images, load_image
from torchtune.modules.transforms import Transform
Role = Literal[
    "system",  # Origin is system prompt
    "user",  # Origin is user
    "assistant",  # Origin is the model output
    "ipython",  # Origin is return from a tool call
]


class Message:
    """
    This class represents individual messages in a fine-tuning dataset. It supports
    text-only content, text with interleaved images, and tool calls. The
    :class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` will tokenize
    the content of the message using ``tokenize_messages`` and attach the appropriate
    special tokens based on the flags set in this class.

    Args:
        role (Role): role of the message writer. Can be "system" for system prompts,
            "user" for human prompts, "assistant" for model responses, or "ipython"
            for tool call returns.
        content (Union[str, List[Dict[str, Any]]]): content of the message. If it is text only content,
            you can pass in a string. If it is multimodal content, pass in a list of dictionaries formatted
            as follows::

                [
                    {"type": "image", "content": <PIL.Image.Image>},
                    {"type": "text", "content": "What is in this image?"},
                ]

        masked (bool): whether the message is masked in the sample. If True, do not use
            in loss calculation. Default: False
        ipython (bool): whether the message is a tool call. Default: False
        eot (bool): whether the message corresponds to the end of a turn, where control is handed over
            to the assistant from the user or the user from the assistant. Default: True. Should be true
            in most cases except for:

            - For multiple consecutive assistant messages (i.e., tool calls
              by assistant), only the last assistant message will have ``eot=True``
            - All ipython messages (tool call returns) should set ``eot=False``.

    Raises:
        AssertionError: If ``content`` is neither a string nor a list.
        ValueError: If ``ipython`` is True and the message contains an image, or
            if ``ipython`` is True and ``role`` is not "assistant".

    Note:
        Message class expects any image content to be in
        `PIL Image format <https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image>`_.
    """

    def __init__(
        self,
        role: Role,
        content: Union[str, List[Dict[str, Any]]],
        masked: bool = False,
        ipython: bool = False,
        eot: bool = True,
    ):
        self.role = role
        # Content is always stored in the list-of-dicts form, even for plain strings.
        self.content = self._convert_to_list_of_dict(content)
        self.masked = masked
        self.ipython = ipython
        self.eot = eot

        self._validate_message()

    def _convert_to_list_of_dict(self, content) -> List[Dict[str, Any]]:
        """User is currently allowed to pass in a string for text-only content.
        This ensures that the content is formatted as a list of dictionaries.

        A string is wrapped into a single ``{"type": "text", ...}`` entry; a list is
        returned unchanged (the same object, not a copy)."""
        if isinstance(content, str):
            return [{"type": "text", "content": content}]

        if not isinstance(content, list):
            # Raised explicitly instead of using an ``assert`` statement so the check
            # is not stripped when Python runs with ``-O``. AssertionError is kept so
            # callers observe the same exception type as before.
            raise AssertionError(
                f"content must be of type List[Dict[str, Any]], got {content}"
            )

        return content

    @classmethod
    def from_dict(cls, d: dict) -> "Message":
        """
        Construct a Message from a dictionary.

        ``role`` and ``content`` are required keys; ``masked``, ``ipython`` and ``eot``
        are optional and fall back to the constructor defaults.

        Args:
            d (dict): dictionary containing the fields of the Message.

        Returns:
            Message: constructed Message.
        """
        return cls(
            role=d["role"],
            content=d["content"],
            masked=d.get("masked", False),
            ipython=d.get("ipython", False),
            eot=d.get("eot", True),
        )

    def get_media(self) -> List["PIL.Image.Image"]:
        """
        Returns media content of the message, i.e. the ``"content"`` value of every
        entry whose ``"type"`` is ``"image"``, in order.
        """
        return [
            content["content"] for content in self.content if content["type"] == "image"
        ]

    @property
    def contains_media(self) -> bool:
        """
        Returns whether the message contains media (at least one ``"image"`` entry).
        """
        return any(content["type"] == "image" for content in self.content)

    @property
    def text_content(self) -> str:
        """
        Returns text-only content of the message: all ``"text"`` entries concatenated
        in order with no separator.
        """
        return "".join(
            content["content"] for content in self.content if content["type"] == "text"
        )

    def _validate_message(self) -> None:
        """Reject flag combinations that are not supported for tool calls."""
        if self.ipython and self.contains_media:
            raise ValueError(
                f"Media tokens in tool calls are not supported. Both are set in message: {self.text_content}"
            )
        if self.ipython and self.role != "assistant":
            raise ValueError(
                f"Only assistant messages can be tool calls. Found role {self.role} in message: {self.text_content}"
            )

    def __repr__(self) -> str:
        # Only the payloads are shown; the "type" tags and the flags are omitted.
        content_only = [content["content"] for content in self.content]
        return f"Message(role='{self.role}', content={content_only!r})"
class InputOutputToMessages(Transform):
    """
    Message transform class that converts a single sample with "input" and "output" fields,
    (or equivalent fields specified in column_map) to user and assistant messages,
    respectively. This is useful for datasets that have two columns, one containing
    the user prompt string and the other containing the model response string::

        |  input          |  output          |
        |-----------------|------------------|
        | "user prompt"   | "model response" |

    Args:
        train_on_input (bool): Whether the model is trained on the user prompt or not.
            Default is False.
        column_map (Optional[Dict[str, str]]): a mapping to change the expected "input"
            and "output" column names to the actual column names in the dataset. Keys should
            be "input" and "output" and values should be the actual column names. Default is None,
            keeping the default "input" and "output" column names.
        new_system_prompt (Optional[str]): if specified, prepend a system message. This can
            serve as instructions to guide the model response. Default is None.
        image_dir (Optional[Path]): path to the directory containing the images that is prepended to all image
            paths in the dataset. For example, if ``image_dir="/home/user/dataset/"`` and the sample image path
            was ``"images/1.jpg"``, the final image path that will be loaded is ``"/home/user/dataset/images/1.jpg"``.
            If None, assume images are available in current working directory or are located
            on a remote url. For text-only, leave as None. Default is None.

    Raises:
        ValueError:
            If ``column_map`` is provided and ``input`` not in ``column_map``, or
            ``output`` not in ``column_map``, **or**
            if ``image_dir`` is provided but ``image`` not in ``column_map``.
    """

    def __init__(
        self,
        train_on_input: bool = False,
        column_map: Optional[Dict[str, str]] = None,
        new_system_prompt: Optional[str] = None,
        image_dir: Optional[Path] = None,
    ):
        self.train_on_input = train_on_input
        self.new_system_prompt = new_system_prompt

        self.column_map = column_map

        if self.column_map is not None:
            if "input" not in self.column_map:
                raise ValueError(
                    f"Expected a key of 'input' in column_map but found {self.column_map.keys()}."
                )
            if "output" not in self.column_map:
                raise ValueError(
                    f"Expected a key of 'output' in column_map but found {self.column_map.keys()}."
                )
        else:
            self.column_map = {"input": "input", "output": "output", "image": "image"}

        # Ensure that if a user seems to want to construct a multimodal transform, they provide
        # a proper column_mapping
        if "image" not in self.column_map.keys() and image_dir is not None:
            raise ValueError(
                f"image_dir is specified as {image_dir} but 'image' is not in column_map. "
                "Please specify an 'image' key in column_map."
            )

        self.image_dir = image_dir

    def _resolve_image_column(self, sample: Mapping[str, Any]) -> Optional[str]:
        """Return the name of the sample column holding the image, or None for text-only samples."""
        mapped_column = self.column_map.get("image")
        if mapped_column is not None and mapped_column in sample:
            return mapped_column
        # Fall back to the default column name. This also covers a user-supplied
        # column_map without an "image" key, which previously raised KeyError
        # whenever the sample happened to contain an "image" column.
        if "image" in sample:
            return "image"
        return None

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        """
        Convert one sample into a user message and an assistant message.

        Args:
            sample (Mapping[str, Any]): a single data sample containing the (mapped)
                input and output columns and, optionally, an image column.

        Returns:
            Mapping[str, Any]: a dict with a single "messages" key holding the list of
            :class:`Message` objects (optional system prompt, user, assistant).
        """
        image_column = self._resolve_image_column(sample)

        if image_column is not None:
            image_path = sample[image_column]
            if isinstance(image_path, str):
                # Convert image_path to Path obj
                image_path = Path(image_path)

                # If image_dir is not None, prepend image_dir to image_path
                if self.image_dir is not None:
                    image_path = self.image_dir / image_path

                # Load if not loaded
                pil_image = load_image(image_path)
            else:
                # Anything that is not a string is used as the image as-is.
                pil_image = image_path
            content = [
                {"type": "image", "content": pil_image},
                {"type": "text", "content": sample[self.column_map["input"]]},
            ]
        else:
            content = [{"type": "text", "content": sample[self.column_map["input"]]}]

        output_content = [
            {"type": "text", "content": sample[self.column_map["output"]]}
        ]

        messages = [
            Message(
                role="user",
                content=content,
                masked=not self.train_on_input,
                eot=True,
            ),
            Message(
                role="assistant",
                content=output_content,
                masked=False,
                eot=True,
            ),
        ]
        if self.new_system_prompt is not None:
            # The system prompt is always masked out of the loss.
            messages = [
                Message(
                    role="system", content=self.new_system_prompt, masked=True, eot=True
                )
            ] + messages
        return {"messages": messages}
class ChosenRejectedToMessages(Transform):
    """
    Transform for converting a single sample from datasets with "chosen" and "rejected" columns
    containing conversations to a list of chosen and rejected messages. For example::

        |  chosen                                |  rejected                              |
        |----------------------------------------|----------------------------------------|
        | [{"role": "user", "content": Q1},      | [{"role": "user", "content": Q1},      |
        |  {"role": "assistant", "content": A1}] |  {"role": "assistant", "content": A2}] |

    will be converted to:

    .. code-block:: python

        chosen = [
            Message(role="user", content="Q1"),
            Message(role="assistant", content="A1"),
        ]
        rejected = [
            Message(role="user", content="Q1"),
            Message(role="assistant", content="A2"),
        ]

    A single sample typically consists of a single optional system prompt and one or multiple
    turns of user and assistant messages.

    Args:
        train_on_input (bool): Whether the model is trained on the user prompt or not.
            Default is False.
        column_map (Optional[Dict[str, str]]): a mapping to change the expected
            "chosen" and "rejected" column names to the actual column names in the dataset.
            Keys should be "chosen" and "rejected" and values should be the actual column names.
            Default is None, keeping the default column names.
        new_system_prompt (Optional[str]): if specified, prepend a system message. This can
            serve as instructions to guide the model response. Setting this will OVERRIDE any system
            messages already present in the dataset. Default is None.

    Raises:
        ValueError: If ``column_map`` is provided and ``chosen`` not in ``column_map``, or
            ``rejected`` not in ``column_map``.
    """

    def __init__(
        self,
        train_on_input: bool = False,
        column_map: Optional[Dict[str, str]] = None,
        new_system_prompt: Optional[str] = None,
    ):
        self.train_on_input = train_on_input
        self.new_system_prompt = new_system_prompt
        if column_map:
            if "chosen" not in column_map:
                raise ValueError(
                    f"Expected a key of 'chosen' in column_map but found {column_map.keys()}."
                )
            if "rejected" not in column_map:
                raise ValueError(
                    f"Expected a key of 'rejected' in column_map but found {column_map.keys()}."
                )
            self._column_map = column_map
        else:
            # None (or an empty mapping) keeps the default column names.
            self._column_map = {"chosen": "chosen", "rejected": "rejected"}

    def _build_messages(self, raw_messages: List[Dict[str, Any]]) -> List[Message]:
        """Convert one conversation (a list of message dicts) into a list of Messages."""
        messages = []
        if self.new_system_prompt is not None:
            messages.append(
                Message(
                    role="system", content=self.new_system_prompt, masked=True, eot=True
                )
            )
        for message in raw_messages:
            # A dataset-provided system message is dropped when it is being overridden.
            if message["role"] == "system" and self.new_system_prompt is not None:
                continue
            masked = (message["role"] != "assistant") and (not self.train_on_input)
            # Build a copy instead of writing "masked" into the caller's dict, so the
            # input sample is left untouched.
            messages.append(Message.from_dict({**message, "masked": masked}))
        return messages

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        """
        Convert the chosen and rejected conversations of one sample into Messages.

        Args:
            sample (Mapping[str, Any]): a single data sample whose (mapped) "chosen" and
                "rejected" columns each hold a list of message dicts.

        Returns:
            Mapping[str, Any]: a dict with "chosen" and "rejected" keys, each holding a
            list of :class:`Message` objects.
        """
        chosen_messages = self._build_messages(sample[self._column_map["chosen"]])
        rejected_messages = self._build_messages(sample[self._column_map["rejected"]])
        return {"chosen": chosen_messages, "rejected": rejected_messages}
class OpenAIToMessages(Transform):
    """
    Convert a single chat sample adhering to the `OpenAI chat completion <https://platform.openai.com/docs/api-reference/chat>`_
    JSON structure to torchtune's :class:`~torchtune.data.Message` structure. This supports both
    text and image messages.

    A single sample typically consists of a single optional system prompt and one or multiple
    turns of user and assistant messages.

    For example::

        {
            "messages": [
                {
                    "role": <system|user|assistant>,
                    "content": [
                        {
                            "type": "text",
                            "text": "What's in this image?",
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": <url>,
                            },
                        },
                    ],
                },
                ...
            ]
        }

    :class:`~torchtune.data.Message` follows::

        [
            {
                "role": <system|user|assistant>,
                "content": [
                    {
                        "type": "text",
                        "content": "What's in this image?",
                    },
                    {
                        "type": "image",
                        "content": <PIL.Image.Image>,
                    },
                ],
            },
            ...
        ]

    Args:
        train_on_input (bool): whether the prompt should remain unmasked. Default: False
        column_map (Optional[Dict[str, str]]): a mapping from the expected columns ("messages")
            to the new column names in the dataset. Key should be "messages" and value should be
            the new column name. If None, keep the default "messages".
            Default is None.
        new_system_prompt (Optional[str]): if specified, prepend a system message. This can
            serve as instructions to guide the model response. Setting this will OVERRIDE any system
            messages already present in the dataset. Default is None.

    Raises:
        ValueError: If ``column_map`` is provided and ``messages`` not in ``column_map``.
    """

    def __init__(
        self,
        train_on_input: bool = False,
        column_map: Optional[Dict[str, str]] = None,
        new_system_prompt: Optional[str] = None,
    ):
        self.train_on_input = train_on_input
        self.new_system_prompt = new_system_prompt
        if column_map:
            if "messages" not in column_map:
                raise ValueError(
                    f"Expected a key of 'messages' in column_map but found {column_map.keys()}."
                )
            self._column_map = column_map
        else:
            # None (or an empty mapping) keeps the default column name.
            self._column_map = {"messages": "messages"}

    def _convert_from_openai_content(
        self, content: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Converts a list of content dicts from the OpenAI format to the torchtune format.

        Only "text" and "image_url" entries are converted; entries of any other type
        are skipped. Images are loaded from their URL via ``load_image``.
        """
        converted_content = []
        for content_dict in content:
            if content_dict["type"] == "text":
                converted_content.append(
                    {"type": "text", "content": content_dict["text"]}
                )
            elif content_dict["type"] == "image_url":
                converted_content.append(
                    {
                        "type": "image",
                        "content": load_image(content_dict["image_url"]["url"]),
                    }
                )
        return converted_content

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        """
        Return a list of Message objects from the provided sample dict.

        Args:
            sample (Mapping[str, Any]): a single data sample with "messages" field pointing
                to a list of dict messages.

        Returns:
            Mapping[str, Any]: a dict with a single "messages" key holding the list of
            :class:`~torchtune.data.Message` objects.

        Raises:
            ValueError: If a message's "content" is neither a string nor a list.
        """
        updated_messages = []
        if self.new_system_prompt is not None:
            updated_messages.append(
                Message(
                    role="system", content=self.new_system_prompt, masked=True, eot=True
                )
            )
        for message in sample[self._column_map["messages"]]:
            # A dataset-provided system message is dropped when it is being overridden.
            if message["role"] == "system" and self.new_system_prompt is not None:
                continue
            masked = (message["role"] != "assistant") and (not self.train_on_input)
            raw_content = message["content"]
            if isinstance(raw_content, list):
                content = self._convert_from_openai_content(raw_content)
            elif isinstance(raw_content, str):
                content = raw_content
            else:
                # Without this branch ``content`` would be unbound on the first message,
                # or silently carry over the previous message's content on later ones.
                raise ValueError(
                    "Expected message content to be a string or a list of content dicts, "
                    f"but got {type(raw_content).__name__} for message with role {message['role']!r}."
                )
            updated_messages.append(
                Message(
                    role=message["role"],
                    content=content,
                    masked=masked,
                ),
            )

        return {"messages": updated_messages}
def validate_messages(
    messages: List[Message],
) -> None:
    """
    Check that a list of messages forms a valid back-and-forth conversation.

    The checks performed, in order, are:

    - The list holds at least two messages (min. one user-assistant turn)
    - An assistant message must directly follow a user message (so an assistant
      message can never come before the first user message)
    - A user message must not directly follow another user message
    - A system message may only appear as the very first message

    Args:
        messages (List[Message]): the messages to validate.

    Raises:
        ValueError: If the messages are invalid.
    """
    if len(messages) < 2:
        raise ValueError(
            f"Messages must be at least length 2, but got {len(messages)} messages"
        )

    # Pretend an assistant spoke just before the conversation starts; this makes a
    # leading assistant message fail the "must follow a user message" rule.
    previous_role = "assistant"
    for i, message in enumerate(messages):
        current_role = message.role
        if current_role == "assistant":
            if previous_role != "user":
                raise ValueError(
                    f"Assistant message before expected user message at index {i} in messages"
                )
        elif current_role == "user":
            if previous_role == "user":
                raise ValueError(
                    f"Two consecutive user messages at index {i} and {i - 1} in messages"
                )
        elif current_role == "system":
            if i > 0:
                raise ValueError(
                    f"System message at index {i} in messages, but system messages must come first"
                )
        previous_role = current_role
class AlpacaToMessages(Transform):
    """
    Message transform for Alpaca-style datasets that carry "instruction", "input", and
    "output" columns (or the equivalents named in ``column_map``). The user message is
    rendered from the instruction and, when present and non-empty, the input column; the
    assistant message is taken from the output column. Because the prompt template depends
    on whether an input is available, the templating lives in this transform rather than in
    a dedicated :class:`~torchtune.data.PromptTemplate` class.

    Args:
        train_on_input (bool): Whether the model is trained on the user prompt or not.
            Default is True.
        column_map (Optional[Dict[str, str]]): a mapping to change the expected "instruction", "input",
            and "output" column names to the actual column names in the dataset. Default is None,
            keeping the default column names.
    """

    def __init__(
        self, train_on_input: bool = True, column_map: Optional[Dict[str, str]] = None
    ):
        self.train_on_input = train_on_input
        self.column_map = column_map

        # Two prompt variants: one for samples with extra input context, one without.
        with_input = (
            "Below is an instruction that describes a task, paired with an input that provides further context. "
            "Write a response that appropriately completes the request.\n\n"
            "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
        )
        without_input = (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n\n"
            "### Instruction:\n{instruction}\n\n### Response:\n"
        )
        self.template = {
            "prompt_input": with_input,
            "prompt_no_input": without_input,
        }

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        """
        Render one Alpaca-style sample into a user message and an assistant message.

        Args:
            sample (Mapping[str, Any]): a single data sample with the (mapped)
                instruction and output columns and an optional input column.

        Returns:
            Mapping[str, Any]: a dict with a single "messages" key holding the user and
            assistant :class:`Message` objects.
        """
        names = self.column_map if self.column_map else {}
        # Any column missing from the mapping keeps its default name.
        input_col = names.get("input", "input")
        instruction_col = names.get("instruction", "instruction")
        output_col = names.get("output", "output")

        has_input = input_col in sample and bool(sample[input_col])
        if not has_input:
            user_prompt = self.template["prompt_no_input"].format(
                instruction=sample[instruction_col]
            )
        else:
            user_prompt = self.template["prompt_input"].format(
                instruction=sample[instruction_col], input=sample[input_col]
            )

        user_message = Message(
            role="user",
            content=user_prompt,
            masked=not self.train_on_input,
            eot=True,
        )
        assistant_message = Message(
            role="assistant",
            content=sample[output_col],
            masked=False,
            eot=True,
        )
        return {"messages": [user_message, assistant_message]}