Shortcuts

Source code for torchtune.modules.transforms.tokenizers._utils

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
from typing import Any, Dict, List, Optional, Protocol, Tuple

from torchtune.data._messages import Message
from torchtune.data._utils import truncate


class BaseTokenizer(Protocol):
    """
    Interface for a plain token-encoding model: anything that can turn text
    into token ids (``encode``) and token ids back into text (``decode``).

    See :class:`~torchtune.modules.transforms.tokenizers.SentencePieceBaseTokenizer`
    and :class:`~torchtune.modules.transforms.tokenizers.TikTokenBaseTokenizer`
    for example implementations of this protocol.
    """

    def encode(self, text: str, **kwargs: Dict[str, Any]) -> List[int]:
        """
        Encode a string into a list of token ids.

        Args:
            text (str): The text to encode.
            **kwargs (Dict[str, Any]): kwargs.

        Returns:
            List[int]: The encoded list of token ids.
        """
        ...

    def decode(self, token_ids: List[int], **kwargs: Dict[str, Any]) -> str:
        """
        Decode a list of token ids back into text, optionally including
        special tokens.

        Args:
            token_ids (List[int]): The list of token ids to decode.
            **kwargs (Dict[str, Any]): kwargs.

        Returns:
            str: The decoded text.
        """
        ...
class ModelTokenizer(Protocol):
    """
    Interface for model-specific tokenizers: implementations encapsulate the
    model's special-token logic inside ``tokenize_messages``.

    See :class:`~torchtune.models.llama3.Llama3Tokenizer` for an example
    implementation of this protocol.
    """

    # Mapping from special token string to its token id.
    special_tokens: Dict[str, int]
    # Maximum sequence length, or None for no limit.
    max_seq_len: Optional[int]

    def tokenize_messages(
        self, messages: List[Message], **kwargs: Dict[str, Any]
    ) -> Tuple[List[int], List[bool]]:
        """
        Tokenize a list of messages into a single concatenated, formatted
        sequence, returning the token ids together with a parallel mask list.

        Args:
            messages (List[Message]): The list of messages to tokenize.
            **kwargs (Dict[str, Any]): kwargs.

        Returns:
            Tuple[List[int], List[bool]]: The list of token ids and the list of masks.
        """
        ...
def tokenize_messages_no_special_tokens(
    tokenizer: ModelTokenizer,
    messages: List[Message],
    *,
    bos_id: Optional[int] = None,
    eos_id: Optional[int] = None,
) -> Tuple[List[int], List[bool]]:
    r"""Tokenize a list of messages one at a time then concatenate them,
    returning a list of tokens and a list of masks. Does not add any special
    tokens except for BOS and EOS (if provided). This serves as a common
    starting point for model tokenizers that do not rely heavily on special tokens.

    Examples:
        >>> messages = [
        ...     Message(role="system", content="system message\n", masked=True),
        ...     Message(role="user", content="user prompt\n", masked=True),
        ...     Message(role="assistant", content="assistant response\n"),
        ... ]
        # tokenize_messages encodes messages separately and concats
        >>> tokens = tokenize_messages_no_special_tokens(
        ...     tokenizer,
        ...     messages,
        ...     bos_id=tokenizer.bos_id,
        ...     eos_id=tokenizer.eos_id,
        ... )[0]
        >>> print(tokens)
        [1, 1788, 2643, 13, 1792, 9508, 13, 465, 22137, 2933, 2]
        # Same result as encoding the full string in one go
        >>> print(tokenizer.encode(''.join([message.content for message in messages])))
        [1, 1788, 2643, 13, 1792, 9508, 13, 465, 22137, 2933, 2]

    Args:
        tokenizer (ModelTokenizer): Tokenizer to encode messages with.
        messages (List[Message]): A list of messages, each containing role, content,
            and masked attributes.
        bos_id (Optional[int]): Beginning-of-sequence token id. If None, no BOS token will
            be added. Default None.
        eos_id (Optional[int]): End-of-sequence token id. If None, no EOS token will
            be added. Default None.

    Returns:
        Tuple[List[int], List[bool]]: The tokenized messages.

    Raises:
        RuntimeError: if any message in ``messages`` does not satisfy
            ``message['type'] == 'text'``.
    """
    # Empty input: nothing to tokenize. Returning early also avoids a NameError
    # in the truncation step below, which references the loop variable
    # ``message`` (unbound when the loop never runs).
    if not messages:
        return [], []

    start_of_turn = True
    end_of_turn = False
    prev_ends_with_space = False
    max_seq_len = tokenizer.max_seq_len  # We define this on ModelTokenizer
    tokenized_messages = []
    mask = []
    for message in messages:
        # If assistant message, this is the end of a turn
        end_of_turn = message.role == "assistant"

        # Prepend BOS on start of new turns
        if start_of_turn and bos_id is not None:
            tokenized_messages.append(bos_id)
            mask.append(message.masked)

        # We want to trim leading whitespace on the next message when
        # (a) it is a continuation of the turn (i.e. not the first message)
        # (b) the vocabulary explicitly encodes whitespace characters (checked
        #     inside the base tokenizer's encode method), and
        # (c) the previous message did not end with a space
        trim_leading_whitespace = (not start_of_turn) and not prev_ends_with_space

        # Tokenize current message, append with masks
        tokens = []
        for item in message.content:
            if item["type"] == "text":
                tokens = tokens + tokenizer.encode(
                    item["content"].rstrip(" "),
                    add_bos=False,
                    add_eos=False,
                    trim_leading_whitespace=trim_leading_whitespace,
                )
            else:
                raise RuntimeError(f"Unsupported message content type: {item['type']}")
            prev_ends_with_space = item["content"].endswith(" ")
        tokenized_messages.extend(tokens)
        mask.extend([message.masked] * len(tokens))

        # If assistant message, append EOS at end
        if end_of_turn:
            if eos_id is not None:
                tokenized_messages.append(eos_id)
                mask.append(message.masked)
            end_of_turn = False
            start_of_turn = True
        else:
            start_of_turn = False

        # Break out early if we reach max_seq_len
        if max_seq_len is not None and len(tokenized_messages) >= max_seq_len:
            break

    # Finally, truncate if necessary
    if max_seq_len is not None:
        tokenized_messages = truncate(tokenized_messages, max_seq_len, eos_id)
        mask = truncate(
            mask, max_seq_len, message.masked if eos_id is not None else None
        )

    return tokenized_messages, mask
[docs]def parse_hf_tokenizer_json(tokenizer_json_path: str) -> Dict[str, int]: """ Parse the ``tokenizer.json`` file from a Hugging Face model to extract the special token str to id mapping. Args: tokenizer_json_path (str): Path to the ``tokenizer.json`` file. Returns: Dict[str, int]: The special token str to id mapping. """ with open(tokenizer_json_path, "r") as f: tokenizer_json = json.load(f) return {token["content"]: token["id"] for token in tokenizer_json["added_tokens"]}

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources