Source code for torchtune.models.gemma._tokenizer
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Tuple

from torchtune.data import Message
from torchtune.modules.tokenizers import (
    ModelTokenizer,
    SentencePieceBaseTokenizer,
    tokenize_messages_no_special_tokens,
)

WHITESPACE_CHARS = [" ", "\n", "\t", "\r", "\v"]


class GemmaTokenizer(ModelTokenizer):
    """
    Gemma's implementation of the SentencePiece tokenizer.

    Args:
        path (str): Path to the pretrained tokenizer file.

    Examples:
        >>> tokenizer = GemmaTokenizer("/path/to/spm_model")
        >>> tokenized_text = tokenizer.encode("Hello world!", add_bos=True, add_eos=True)
        >>> print(tokenized_text)
        [1, 31587, 29644, 102, 2]
    """

    def __init__(
        self,
        path: str,
    ):
        self._spm_model = SentencePieceBaseTokenizer(path)

        # The original tokenizer has no pad_id, which causes indexing errors
        # during batched training.
        self._spm_model.pad_id = 0

        # During generation, stop when eos_id is encountered.
        self.stop_tokens = [self.eos_id]

    @property
    def eos_id(self):
        return self._spm_model.eos_id

    @property
    def bos_id(self):
        return self._spm_model.bos_id

    @property
    def pad_id(self):
        return self._spm_model.pad_id

    @property
    def vocab_size(self):
        return self._spm_model.vocab_size

    def encode(
        self,
        text: str,
        add_bos: bool = True,
        add_eos: bool = True,
        trim_leading_whitespace: bool = False,
    ) -> List[int]:
        # Delegate directly to the underlying SentencePiece tokenizer.
        return self._spm_model.encode(
            text,
            add_bos=add_bos,
            add_eos=add_eos,
            trim_leading_whitespace=trim_leading_whitespace,
        )

    def decode(
        self,
        token_ids: List[int],
    ) -> str:
        return self._spm_model.decode(token_ids)

    def tokenize_messages(
        self, messages: List[Message], max_seq_len: Optional[int] = None
    ) -> Tuple[List[int], List[bool]]:
        r"""Tokenize a list of messages one at a time, then concatenate them,
        returning a list of tokens and a list of masks.

        Example:
            >>> tokenizer = GemmaTokenizer(tokenizer_path)
            >>> messages = [
            ...     Message(role="system", content="system message\n", masked=True),
            ...     Message(role="user", content="user prompt\n", masked=True),
            ...     Message(role="assistant", content="assistant response\n"),
            ... ]
            >>> # tokenize_messages encodes messages separately, then concatenates them
            >>> tokenizer.tokenize_messages(messages, max_seq_len)[0]
            [1, 1788, 2643, 13, 1792, 9508, 13, 465, 22137, 2933, 2]
            >>> # Same result as encoding the full string in one go
            >>> tokenizer.encode(''.join([message.content for message in messages]))
            [1, 1788, 2643, 13, 1792, 9508, 13, 465, 22137, 2933, 2]

        Args:
            messages (List[Message]): A list of messages, each containing role, content,
                and masked attributes.
            max_seq_len (Optional[int]): A max sequence length to truncate tokens to.
                Default: None

        Returns:
            Tuple[List[int], List[bool]]: The concatenated token ids and a
                per-token mask mirroring each message's ``masked`` attribute.
        """
        return tokenize_messages_no_special_tokens(
            tokenizer=self,
            messages=messages,
            bos_id=self.bos_id,
            eos_id=self.eos_id,
            max_seq_len=max_seq_len,
        )
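

# -----------------------------------------------------------------------------
# Usage sketch: a minimal, illustrative example of the class above, not part of
# the original module. The tokenizer path, max_seq_len, and message contents
# are placeholder values.
if __name__ == "__main__":
    tokenizer = GemmaTokenizer("/path/to/spm_model")

    # Plain-text round trip: encode() adds BOS and EOS by default.
    token_ids = tokenizer.encode("Hello world!")
    text = tokenizer.decode(token_ids)

    # Chat-style tokenization: the returned mask has one bool per token and
    # mirrors each message's ``masked`` attribute.
    messages = [
        Message(role="user", content="user prompt\n", masked=True),
        Message(role="assistant", content="assistant response\n"),
    ]
    tokens, mask = tokenizer.tokenize_messages(messages, max_seq_len=512)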