Source code for torchtune.datasets._slimorca
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from torchtune.data import Llama2ChatFormat, sharegpt_to_llama2_messages
from torchtune.datasets._chat import ChatDataset
from torchtune.modules import Tokenizer
[docs]def slimorca_dataset(
tokenizer: Tokenizer, max_seq_len: int = 1024, train_on_input: bool = False
) -> ChatDataset:
"""
PyTorch Representation of the SlimOrca Dataset
https://huggingface.co/datasets/Open-Orca/SlimOrca-Dedup
from Hugging Face.
The data is formatted to adhere to Llama2 Chat Format.
This format is required if the base model is Llama2 Chat Model.
The base Llama2 Model doesn't prescribe a particular format.
The returned data is a tuple of input token id list and label token id
list. If `max_seq_len` keyword argument is provided, the returned
input token id list is ensured (by truncation if necessary) to be within
that length.
Data input format: https://huggingface.co/datasets/Open-Orca/SlimOrca-Dedup#dataset-format
Args:
tokenizer (Tokenizer): Tokenizer used to encode data. Tokenize must implement an `encode` and `decode` method.
max_seq_len (int): Maximum number of tokens in the returned input and label token id lists.
This value needs to be at least 4 though it is generally set to max sequence length accepted by the model.
Default is 1024.
train_on_input (bool): Whether the model is trained on the prompt or not. Default is False.
Raises:
ValueError: If `max_seq_len` is less than 4.
Returns:
ChatDataset: dataset configured with SlimOrca source data and LLaMA2 chat template
Example:
>>> ds = slimorca_dataset(tokenizer=tokenizer, max_seq_len=10)
>>> for input, label in ds:
>>> print(input)
>>> print(label)
>>>
>>> Sample Output:
>>> [1, 351, 82, 391, 221, 220, 193, 12, 471, ..., 2]
>>> [-100, -100, -100, -100, -100, -100, -100, -100, 471, ..., 2]
"""
if max_seq_len < 4:
# Input token needs to have 1 bos, 1 eos,
# and 1 token from prompt, 1 from label
raise ValueError("max_seq_len must be at least 4")
return ChatDataset(
tokenizer=tokenizer,
source="Open-Orca/SlimOrca-Dedup",
convert_to_messages=sharegpt_to_llama2_messages,
chat_format=Llama2ChatFormat,
max_seq_len=max_seq_len,
train_on_input=train_on_input,
split="train",
)