# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from collections.abc import Sequence
from typing import Any, TYPE_CHECKING
import torch
from tensordict import NonTensorData, NonTensorStack, TensorDictBase
from tensordict.nn import dispatch
from tensordict.utils import _zip_strict, NestedKey
from torch import Tensor
from torchrl._utils import _replace_last
from torchrl.data.tensor_specs import Bounded, Composite, TensorSpec
from torchrl.envs import Transform, UnaryTransform
from torchrl.envs.transforms.utils import _set_missing_tolerance
if TYPE_CHECKING:
import transformers
class Tokenizer(UnaryTransform):
r"""Applies a tokenization operation on the specified inputs.
Args:
in_keys (sequence of NestedKey): the keys of inputs to the tokenization operation.
out_keys (sequence of NestedKey): the keys of the outputs of the tokenization operation.
in_keys_inv (sequence of NestedKey, optional): the keys of inputs to the tokenization operation during inverse call.
out_keys_inv (sequence of NestedKey, optional): the keys of the outputs of the tokenization operation during inverse call.
Keyword Args:
tokenizer (transformers.PreTrainedTokenizerBase or str, optional): the tokenizer to use. If ``None``,
``"google-bert/bert-base-uncased"`` is used by default. If a string is provided, it should be the name of a
pre-trained tokenizer.
use_raw_nontensor (bool, optional): if ``False``, data is extracted from
:class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before the tokenization
function is called on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
inputs are given directly to the tokenization function, which must support those inputs. Default is ``False``.
additional_tokens (list of str, optional): list of additional tokens to add to the tokenizer's vocabulary.
skip_special_tokens (bool, optional): whether special tokens are skipped when decoding back to strings. Default is ``True``.
add_special_tokens (bool, optional): whether special tokens are added when encoding. Default is ``False``.
padding (bool, optional): whether batched inputs are padded to the longest sequence. Default is ``True``.
max_length (int, optional): if provided, inputs are padded to this length. Default is ``None``.
return_attention_mask (bool, optional): if ``True``, an ``"attention_mask"`` entry is written next to each tokenized output. Default is ``True``.
missing_tolerance (bool, optional): if ``True``, missing input keys are skipped instead of raising a ``KeyError``. Default is ``True``.
call_before_reset (bool, optional): if ``True``, the transform is applied to the incoming tensordict before the environment reset instead of to the reset output. Default is ``False``.
.. note:: This transform can be used both to transform output strings into tokens and to transform tokenized
actions or states back into strings. If the environment has a string state-spec, the transformed version will have
a tokenized state-spec. Likewise, a string action spec results in a tokenized action spec.
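Examples:
A minimal usage sketch, assuming the default ``transformers`` tokenizer can be downloaded;
the exact token ids depend on the tokenizer and are not shown here:

>>> from tensordict import TensorDict
>>> t = Tokenizer(in_keys=["text"], out_keys=["tokens"])
>>> td = TensorDict({"text": "a simple prompt"}, batch_size=[])
>>> td = t(td)
>>> tokens = td["tokens"]  # integer token ids
>>> mask = td["attention_mask"]  # written because return_attention_mask=True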
"""
def __init__(
self,
in_keys: Sequence[NestedKey] | None = None,
out_keys: Sequence[NestedKey] | None = None,
in_keys_inv: Sequence[NestedKey] | None = None,
out_keys_inv: Sequence[NestedKey] | None = None,
*,
tokenizer: transformers.PreTrainedTokenizerBase | str | None = None,  # noqa: F821
use_raw_nontensor: bool = False,
additional_tokens: list[str] | None = None,
skip_special_tokens: bool = True,
add_special_tokens: bool = False,
padding: bool = True,
max_length: int | None = None,
return_attention_mask: bool = True,
missing_tolerance: bool = True,
call_before_reset: bool = False,
):
if tokenizer is None:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
elif isinstance(tokenizer, str):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(tokenizer)
self.tokenizer = tokenizer
self.add_special_tokens = add_special_tokens
self.skip_special_tokens = skip_special_tokens
self.padding = padding
self.max_length = max_length
self.return_attention_mask = return_attention_mask
self.call_before_reset = call_before_reset
if additional_tokens:
self.tokenizer.add_tokens(additional_tokens)
super().__init__(
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
out_keys_inv=out_keys_inv,
fn=self.call_tokenizer_fn,
inv_fn=self.call_tokenizer_inv_fn,
use_raw_nontensor=use_raw_nontensor,
)
self._missing_tolerance = missing_tolerance
@property
def device(self):
if "_device" in self.__dict__:
return self._device
parent = self.parent
if parent is None:
return None
device = parent.device
self._device = device
return device
def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
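"""Tokenize ``in_keys`` into ``out_keys`` on the post-step tensordict, writing an
``attention_mask`` entry next to each output when ``return_attention_mask=True``."""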
# Specialized for attention mask
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
value = next_tensordict.get(in_key, default=None)
if value is not None:
observation = self._apply_transform(value)
if self.return_attention_mask:
observation, attention_mask = observation
next_tensordict.set(
_replace_last(out_key, "attention_mask"),
attention_mask,
)
next_tensordict.set(
out_key,
observation,
)
elif (
self.missing_tolerance
and self.return_attention_mask
and out_key in next_tensordict.keys(True)
):
attention_key = _replace_last(out_key, "attention_mask")
if attention_key not in next_tensordict:
next_tensordict[attention_key] = torch.ones_like(
next_tensordict.get(out_key)
)
elif not self.missing_tolerance:
raise KeyError(
f"{self}: '{in_key}' not found in tensordict {next_tensordict}"
)
return next_tensordict
@dispatch(source="in_keys", dest="out_keys")
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
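"""Tokenize ``in_keys`` into ``out_keys``, optionally writing an ``attention_mask``
entry alongside each tokenized output."""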
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
data = tensordict.get(in_key, None)
if data is not None:
data = self._apply_transform(data)
if self.return_attention_mask:
data, attention_mask = data
tensordict.set(
_replace_last(out_key, "attention_mask"),
attention_mask,
)
tensordict.set(out_key, data)
elif not self.missing_tolerance:
raise KeyError(f"'{in_key}' not found in tensordict {tensordict}")
return tensordict
def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:
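"""Tokenize the input tensordict before the environment reset when ``call_before_reset=True``."""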
if self.call_before_reset:
with _set_missing_tolerance(self, True):
tensordict = self._call(tensordict)
return tensordict
def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
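"""Skip the post-reset transform when tokenization already ran in ``_reset_env_preprocess``."""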
if self.call_before_reset:
return tensordict_reset
return super()._reset(tensordict, tensordict_reset)
def call_tokenizer_fn(self, value: str | list[str]):
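"""Encode a string or list of strings into token ids.

Returns a ``(tokens, attention_mask)`` tuple when ``return_attention_mask=True``,
otherwise the token tensor alone.
"""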
device = self.device
kwargs = {"add_special_tokens": self.add_special_tokens}
if self.max_length is not None:
kwargs["padding"] = "max_length"
kwargs["max_length"] = self.max_length
if isinstance(value, str):
out = self.tokenizer.encode(value, return_tensors="pt", **kwargs)[0]
# TODO: incorporate attention mask
if self.return_attention_mask:
attention_mask = torch.ones_like(out, dtype=torch.int64)
else:
kwargs["padding"] = (
self.padding if self.max_length is None else "max_length"
)
kwargs["return_attention_mask"] = self.return_attention_mask
# kwargs["return_token_type_ids"] = False
out = self.tokenizer.batch_encode_plus(value, return_tensors="pt", **kwargs)
if self.return_attention_mask:
attention_mask = out["attention_mask"]
out = out["input_ids"]
if device is not None and out.device != device:
out = out.to(device)
if self.return_attention_mask:
attention_mask = attention_mask.to(device)
if self.return_attention_mask:
return out, attention_mask
return out
def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
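"""Decode ``out_keys_inv`` entries back into strings stored under ``in_keys_inv``.

Ragged token tensors are read as padded tensors before decoding.
"""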
# Override _inv_call to account for ragged dims
if not self.in_keys_inv:
return tensordict
for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
data = tensordict.get(out_key, None, as_padded_tensor=True)
if data is not None:
item = self._inv_apply_transform(data)
tensordict.set(in_key, item)
elif not self.missing_tolerance:
raise KeyError(f"'{out_key}' not found in tensordict {tensordict}")
return tensordict
def call_tokenizer_inv_fn(self, value: Tensor):
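"""Decode a 1D or 2D tensor of token ids into a :class:`~tensordict.NonTensorData`
string or a :class:`~tensordict.NonTensorStack` of strings."""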
if value.ndim == 1:
out = self.tokenizer.decode(
value.int(), skip_special_tokens=self.skip_special_tokens
)
else:
out = self.tokenizer.batch_decode(
value.int(), skip_special_tokens=self.skip_special_tokens
)
device = self._str_device
if isinstance(out, list):
result = NonTensorStack(*out)
if device:
result = result.to(device)
return result
return NonTensorData(out, device=device)
@property
def _str_device(self):
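"""Device of the string spec matching the first in-key, if the parent env exposes one."""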
parent = self.parent
if parent is None:
return None
if self.in_keys:
in_key = self.in_keys[0]
elif self.in_keys_inv:
in_key = self.in_keys_inv[0]
else:
return None
if in_key in parent.observation_keys:
return parent.full_observation_spec[in_key].device
if in_key in parent.action_keys:
return parent.full_action_spec[in_key].device
if in_key in parent.state_keys:
return parent.full_state_spec[in_key].device
return None
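# Use the base Transform implementations for these spec transforms.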
transform_output_spec = Transform.transform_output_spec
transform_reward_spec = Transform.transform_reward_spec
transform_done_spec = Transform.transform_done_spec
class IncrementalTokenizer(Transform):
"""Maintains tokens synchronized with history for token-first LLM inference.
This transform keeps ``tokens.prompt`` in sync with ``history.prompt``, enabling
LLM wrappers to use existing tokens directly instead of re-tokenizing. This
ensures KV cache consistency across multi-turn conversations.
**How it works:**
- **On reset**: Tokenizes ``history.prompt`` and stores in ``tokens.prompt``.
- **On step**: Reuses ``tokens.full`` from the LLM wrapper output as the new
``tokens.prompt``. Since ``ChatEnv`` sets ``next.history.prompt = history.full``,
and the LLM wrapper already produced ``tokens.full`` (tokenized ``history.full``),
no re-tokenization is needed; we simply copy ``tokens.full`` to ``next.tokens.prompt``.
This approach is efficient (it avoids redundant tokenization) and guarantees
token consistency (the exact same tokens are reused).
Args:
tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for encoding.
Keyword Args:
history_key (NestedKey): Key for the history in the tensordict.
Defaults to ``("history", "prompt")``.
tokens_key (NestedKey): Key for storing tokens in the tensordict.
Defaults to ``("tokens", "prompt")``.
chat_template_name (str, optional): Name of the chat template to use.
Defaults to ``None`` (uses tokenizer's default).
chat_template (str, optional): Custom chat template string.
Defaults to ``None``.
add_generation_prompt (bool): Whether to add generation prompt when tokenizing.
Defaults to ``True``.
Example:
>>> from torchrl.envs.llm import ChatEnv
>>> from torchrl.envs.llm.transforms import IncrementalTokenizer
>>> from torchrl.envs import TransformedEnv
>>> from transformers import AutoTokenizer
>>>
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
>>> env = ChatEnv(batch_size=(1,), tokenizer=tokenizer)
>>> env = TransformedEnv(env, IncrementalTokenizer(tokenizer))
>>>
>>> # After reset and step, tokens.prompt will be maintained
>>> td = env.reset(TensorDict({"query": "Hello"}, batch_size=(1,)))
>>> assert ("tokens", "prompt") in td.keys(True, True)
.. note::
This transform is automatically added by :class:`~torchrl.envs.llm.ChatEnv`
when ``with_tokenizer=True`` is passed to the constructor.
.. warning::
**TODO**: Add validation that tokens match history (hash or length check).
For now, we trust that tokens are kept in sync. If you manually modify the
history, clear the tokens field to trigger re-tokenization.
See Also:
:class:`~torchrl.modules.llm.policies.vLLMWrapper`: Uses ``prefer_tokens=True``
to leverage tokens maintained by this transform.
:class:`~torchrl.modules.llm.policies.TransformersWrapper`: Uses ``prefer_tokens=True``
to leverage tokens maintained by this transform.
"""
def __init__(
self,
tokenizer: transformers.PreTrainedTokenizer, # noqa: F821
*,
history_key: NestedKey = ("history", "prompt"),
tokens_key: NestedKey = ("tokens", "prompt"),
chat_template_name: str | None = None,
chat_template: str | None = None,
add_generation_prompt: bool = True,
):
super().__init__()
self.tokenizer = tokenizer
self.history_key = history_key
self.tokens_key = tokens_key
self.chat_template_name = chat_template_name
self.chat_template = chat_template
self.add_generation_prompt = add_generation_prompt
def _tokenize_history(
self,
history: Any, # History object
add_generation_prompt: bool | None = None,
) -> list[Tensor]:
"""Tokenize a history object and return list of token tensors.
Args:
history: The History object to tokenize.
add_generation_prompt: Whether to add generation prompt. If None, uses
the instance's default.
Returns:
List of token tensors (one per batch element).
"""
if add_generation_prompt is None:
add_generation_prompt = self.add_generation_prompt
tokenizer_kwargs = {
"tokenize": True,
"padding": False,
"return_dict": True,
"add_generation_prompt": add_generation_prompt,
}
if self.chat_template_name is not None:
tokenizer_kwargs["chat_template_name"] = self.chat_template_name
if self.chat_template is not None:
tokenizer_kwargs["chat_template"] = self.chat_template
result = history.apply_chat_template(
tokenizer=self.tokenizer,
**tokenizer_kwargs,
)
# Get input_ids as list of tensors
tokens_list = result.get("input_ids", as_list=True)
return tokens_list
def _get_history(self, tensordict: TensorDictBase) -> Any | None:
"""Get history from tensordict, handling both nested keys and tensorclass access."""
history_key = self.history_key
if isinstance(history_key, tuple) and len(history_key) == 2:
# Try to access via tensorclass attribute pattern (e.g., tensordict["history"].prompt)
container_key, attr_key = history_key
container = tensordict.get(container_key, None)
if container is not None and hasattr(container, attr_key):
return getattr(container, attr_key)
# Fall back to regular nested key access
return tensordict.get(history_key, None)
def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
"""Tokenize full history on reset."""
history = self._get_history(tensordict_reset)
if history is None:
# No history to tokenize
return tensordict_reset
# Tokenize full history
tokens_list = self._tokenize_history(history)
# Store tokens in tensordict - handle batched case
self._set_tokens(tensordict_reset, tokens_list)
return tensordict_reset
def _step(
self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
) -> TensorDictBase:
"""Set tokens.prompt on step, reusing tokens.full when available.
The LLM wrapper produces tokens.full (tokenized history.full).
Since ChatEnv sets next.history.prompt = history.full, we can
directly use tokens.full as next.tokens.prompt without re-tokenization.
Falls back to tokenizing next.history.prompt if tokens.full is not available.
"""
# Try to reuse tokens.full from the action tensordict
# Since next.history.prompt = history.full, tokens.full is already the correct tokenization
tokens_full_key = (
(self.tokens_key[0], "full")
if isinstance(self.tokens_key, tuple)
else "tokens_full"
)
existing_tokens_full = tensordict.get(tokens_full_key, None)
if existing_tokens_full is not None:
# Reuse tokens.full as the new tokens.prompt - no tokenization needed!
next_tensordict.set(self.tokens_key, existing_tokens_full)
return next_tensordict
# Fallback: tokenize next.history.prompt if tokens.full was not available
history = self._get_history(next_tensordict)
if history is None:
return next_tensordict
tokens_list = self._tokenize_history(history)
self._set_tokens(next_tensordict, tokens_list)
return next_tensordict
def _set_tokens(self, tensordict: TensorDictBase, tokens_list: list) -> None:
"""Set tokens in tensordict, handling variable-length tokens properly.
Uses nested tensors for storing variable-length token sequences,
which is compatible with as_list=True retrieval.
"""
# Store as nested tensor which handles variable-length sequences
tokens_nested = torch.nested.as_nested_tensor(tokens_list)
tensordict.set(self.tokens_key, tokens_nested)