
Source code for torchtune.modules.tokenizers._tiktoken

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Iterator, List

from tiktoken import Encoding
from tiktoken.load import load_tiktoken_bpe
from torchtune.modules.tokenizers._utils import BaseTokenizer

# Constants controlling encode logic
MAX_ENCODE_CHARS = 400_000
MAX_NO_WHITESPACE_CHARS = 25_000


class TikTokenBaseTokenizer(BaseTokenizer):
    """
    A lightweight wrapper around tiktoken's ``Encoding``. This class additionally
    handles breaking up the input text into substrings of a max length and
    splitting up long repetitions to improve encode speed.

    Args:
        path (str): Path to pretrained tokenizer checkpoint file.
        name (str): Name of the tokenizer (used by tiktoken for identification).
        pattern (str): Regex pattern used to split input text into chunks before
            passing to byte-pair encoding.
        bos_id (int): Beginning-of-sequence token id. This can be present or
            absent in ``special_tokens``.
        eos_id (int): End-of-sequence token id. This can be present or absent
            in ``special_tokens``.
        special_tokens (Dict[str, int]): Mapping of special tokens to their ids.

    Examples:
        >>> tokenizer = TikTokenBaseTokenizer("/path/to/tt_model")
        >>> tokenized_text = tokenizer.encode("Hello world!", add_bos=True, add_eos=True)
        >>> print(tokenized_text)
        [1, 31587, 29644, 102, 2]
    """

    def __init__(
        self,
        path: str,
        name: str,
        pattern: str,
        bos_id: int,
        eos_id: int,
        special_tokens: Dict[str, int],
    ):
        mergeable_ranks = load_tiktoken_bpe(path)
        self.tt_model = Encoding(
            name=name,
            pat_str=pattern,
            mergeable_ranks=mergeable_ranks,
            special_tokens=special_tokens,
        )
        # Vocab size without special tokens
        self.base_vocab_size = len(mergeable_ranks)
        # Vocab size with special tokens
        self.vocab_size = self.tt_model.n_vocab
        self.bos_id = bos_id
        self.eos_id = eos_id

    def _split_long_repetitions(
        self, s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Split the string ``s`` so that each substring contains no more than
        ``max_consecutive_slice_len`` consecutive whitespace or consecutive
        non-whitespace characters.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0
        for i in range(len(s)):
            is_now_space = s[i].isspace()
            if current_slice_is_space ^ is_now_space:
                # Character class flipped (space <-> non-space); start a new run.
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]
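Because ``_split_long_repetitions`` never reads ``self``, its chunking behavior can be sanity-checked without loading a real checkpoint. A minimal sketch (calling the method through the class with a dummy ``self`` is an illustration trick, not part of the API):

    from torchtune.modules.tokenizers._tiktoken import TikTokenBaseTokenizer

    # A run of 30 non-whitespace characters, split into slices of at most 12.
    chunks = list(TikTokenBaseTokenizer._split_long_repetitions(None, "a" * 30, 12))
    print([len(c) for c in chunks])  # [12, 12, 6]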
    def encode(
        self,
        text: str,
        add_bos: bool = True,
        add_eos: bool = True,
    ) -> List[int]:
        """
        Encode a string into a list of token ids. Assumes that the string
        contains no special tokens.

        Args:
            text (str): The string to encode.
            add_bos (bool): Whether to add the tokenizer's bos_id to the
                encoded string. Default True.
            add_eos (bool): Whether to add the tokenizer's eos_id to the
                encoded string. Default True.

        Returns:
            List[int]: The list of token ids.
        """
        if not text:
            return []
        substrs: List[str] = []
        tokens = []
        for i in range(0, len(text), MAX_ENCODE_CHARS):
            substr = text[i : i + MAX_ENCODE_CHARS]
            # See https://github.com/openai/tiktoken/issues/195
            sliced_substr = self._split_long_repetitions(
                substr, MAX_NO_WHITESPACE_CHARS
            )
            substrs.extend(sliced_substr)
        for substr in substrs:
            # allowed_special and disallowed_special are used by tiktoken to define
            # how special tokens are encoded. Our setting here is to encode any
            # special token as regular text and prevent tiktoken from raising errors.
            # This means we should only call encode on strings not containing special tokens.
            tokens.extend(
                self.tt_model.encode(
                    substr,
                    allowed_special=set(),
                    disallowed_special=(),
                )
            )
        if add_bos:
            tokens = [self.bos_id] + tokens
        if add_eos:
            tokens = tokens + [self.eos_id]
        return tokens
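A usage sketch, assuming a valid tiktoken BPE checkpoint on disk; the path, tokenizer name, pattern, and token ids below are placeholders, not values shipped with torchtune:

    from torchtune.modules.tokenizers._tiktoken import TikTokenBaseTokenizer

    tokenizer = TikTokenBaseTokenizer(
        path="/path/to/tokenizer.model",  # placeholder checkpoint path
        name="example_tokenizer",         # placeholder name
        pattern=r"\S+|\s+",               # toy pattern; real models define their own
        bos_id=1,                         # placeholder id
        eos_id=2,                         # placeholder id
        special_tokens={"<|custom_token|>": 100256},  # placeholder special token
    )
    ids = tokenizer.encode("Hello world!", add_bos=True, add_eos=True)
    # bos/eos are prepended/appended outside of tiktoken's own encoding.
    assert ids[0] == tokenizer.bos_id and ids[-1] == tokenizer.eos_id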
    def decode(
        self,
        token_ids: List[int],
        truncate_at_eos: bool = True,
    ) -> str:
        """
        Decode a list of token ids into a string.

        Args:
            token_ids (List[int]): The list of token ids.
            truncate_at_eos (bool): Whether to truncate the string at the
                end-of-sequence token. Default is True.

        Returns:
            str: The decoded string.
        """
        if truncate_at_eos:
            try:
                k = token_ids.index(self.eos_id)
            except ValueError:
                k = None
            # Compare against None explicitly: a bare ``if k`` would skip
            # truncation when eos_id is the very first token (index 0).
            if k is not None:
                token_ids = token_ids[:k]
        return self.tt_model.decode(token_ids)
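Continuing the sketch above, ``truncate_at_eos`` drops everything from the first ``eos_id`` onward, so trailing ids after an eos never reach tiktoken:

    ids = tokenizer.encode("Hello world!", add_bos=False, add_eos=False)
    padded = ids + [tokenizer.eos_id] + ids  # junk ids after the eos
    assert tokenizer.decode(padded) == tokenizer.decode(ids)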
