# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from functools import partial
from typing import Any, Callable, Dict, Optional, Union

from torchtune.data._messages import AlpacaToMessages
from torchtune.datasets._packed import PackedDataset
from torchtune.datasets._sft import SFTDataset
from torchtune.modules.transforms.tokenizers import ModelTokenizer


def alpaca_dataset(
tokenizer: ModelTokenizer,
*,
source: str = "tatsu-lab/alpaca",
column_map: Optional[Dict[str, str]] = None,
train_on_input: bool = True,
packed: bool = False,
filter_fn: Optional[Callable] = None,
split: str = "train",
**load_dataset_kwargs: Dict[str, Any],
) -> Union[SFTDataset, PackedDataset]:
"""
Support for family of Alpaca-style datasets from Hugging Face Datasets using
the `data input format <https://huggingface.co/datasets/tatsu-lab/alpaca#data-instances>`_
and `prompt template <https://github.com/tatsu-lab/stanford_alpaca/blob/main/train.py#L31>`_
    from the original Alpaca codebase, where ``instruction``, ``input``, and ``output``
    are fields from the dataset. This template is applied automatically, independent
    of any prompt template configured in the tokenizer.

    Masking of the prompt during training is controlled by the ``train_on_input`` flag,
    which is set to ``True`` by `default <https://github.com/tloen/alpaca-lora/blob/main/finetune.py#L49>`_.

    - If ``train_on_input`` is True, the prompt is used during training and
      contributes to the loss.
    - If ``train_on_input`` is False, the prompt is masked out (tokens replaced with -100).

Args:
tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method.
source (str): path to dataset repository on Hugging Face. For local datasets,
define source as the data file type (e.g. "json", "csv", "text") and pass
in the filepath in ``data_files``. See `Hugging Face's
<https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path>`_
``load_dataset`` for more details. Default is ``tatsu-lab/alpaca``.
column_map (Optional[Dict[str, str]]): a mapping from the expected columns in the message transform
:class:`~torchtune.data.AlpacaToMessages` to the new column names in the dataset. Keys should be
"instruction", "input", and "output" and values should be the actual column names. If None, uses
            the default column names ``"instruction"``, ``"input"``, and ``"output"`` in ``tatsu-lab/alpaca``.
train_on_input (bool): Whether the model is trained on the prompt or not. Default is True.
packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False.
filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
details.
split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
of a given split, e.g. ``split="train[:10%]"``. Default is "train".
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``. See Hugging
Face's `API ref <https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset>`_
for more details.
Returns:
Union[SFTDataset, PackedDataset]: dataset configured with source data and transform
Raises:
ValueError: If ``packed`` is True and ``max_seq_len`` is not set on the tokenizer.
Example:
        >>> alpaca_ds = alpaca_dataset(tokenizer=tokenizer)
        >>> for batch in DataLoader(alpaca_ds, batch_size=8):
        >>>     print(f"Batch size: {len(batch)}")
        >>> Batch size: 8
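
        For local Alpaca-format data, set ``source`` to the file type and pass the
        file path via ``data_files``; the path and renamed columns below are
        illustrative placeholders, not values the library requires:

        >>> local_ds = alpaca_dataset(
        >>>     tokenizer=tokenizer,
        >>>     source="json",
        >>>     data_files="data/my_alpaca.json",
        >>>     column_map={"instruction": "prompt", "input": "context", "output": "response"},
        >>>     split="train",
        >>> )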
"""
message_transform = AlpacaToMessages(
train_on_input=train_on_input, column_map=column_map
)
ds = SFTDataset(
source=source,
message_transform=message_transform,
model_transform=tokenizer,
filter_fn=filter_fn,
split=split,
**load_dataset_kwargs,
)
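    # Packing concatenates multiple samples into fixed-length sequences of
    # max_seq_len for training efficiency; it requires the tokenizer to expose
    # a max_seq_len.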
if packed:
if tokenizer.max_seq_len is None:
raise ValueError(
"PackedDataset requires a max_seq_len to be set on the tokenizer."
)
return PackedDataset(ds, max_seq_len=tokenizer.max_seq_len)
return ds


alpaca_cleaned_dataset = partial(alpaca_dataset, source="yahma/alpaca-cleaned")
alpaca_cleaned_dataset.__doc__ = """
Builder for a variant of Alpaca-style datasets with the cleaned version of the
original Alpaca dataset, `yahma/alpaca-cleaned <https://huggingface.co/datasets/yahma/alpaca-cleaned>`_.
See the dataset page and :func:`~torchtune.datasets.alpaca_dataset` for more details.
"""