add_chat_template

torchrl.data.llm.add_chat_template(template_name: str, template: str, inverse_parser: callable | None = None, model_family_keywords: list[str] | None = None)[source]

Add a custom chat template to the global template dictionary.

This function allows you to add custom chat templates for new model families that support assistant token masking via the {% generation %} keyword.

Parameters:
  • template_name (str) – The name of the template (e.g., “llama”, “mistral”). This name will be used in the chat_template_name parameter of History.apply_chat_template() and History.from_text().

  • template (str) – The Jinja2 template string. Must include {% generation %} blocks around assistant message content to enable token masking.

  • inverse_parser (callable, optional) – A function that parses formatted text back into a History object. Should have signature (text: str) -> History. If None, a basic parser will be used.

  • model_family_keywords (list[str], optional) – Keywords to detect this model family in the auto-detection logic. For example, [“llama”, “meta-llama”] for Llama models. If provided, the template will be automatically selected for models containing these keywords in their name.

Example

>>> from torchrl.data.llm.chat import add_chat_template, History
>>> from tensordict import lazy_stack
>>> from transformers import AutoTokenizer
>>>
>>> # Add a custom template for Llama models
>>> llama_template = '''
... {% for message in messages %}
... {%- if message['role'] == 'user' %}
... {{ '<s>[INST] ' + message['content'] + ' [/INST]' }}
... {%- elif message['role'] == 'assistant' %}
... {% generation %}{{ message['content'] + '</s>' }}{% endgeneration %}
... {%- endif %}
... {% endfor %}
... {%- if add_generation_prompt %}
... {% generation %}{{ ' ' }}{% endgeneration %}
... {%- endif %}
... '''
>>>
>>> def parse_llama_text(text: str) -> History:
...     # Custom parser for Llama format
...     import re
...     pattern = r'<s>\[INST\]\s*(.*?)\s*\[/INST\]\s*(.*?)</s>'
...     matches = re.findall(pattern, text, re.DOTALL)
...     messages = []
...     for user_content, assistant_content in matches:
...         messages.append(History(role="user", content=user_content.strip()))
...         messages.append(History(role="assistant", content=assistant_content.strip()))
...     return lazy_stack(messages)
>>>
>>> # Add the template with auto-detection
>>> add_chat_template(
...     template_name="llama",
...     template=llama_template,
...     inverse_parser=parse_llama_text,
...     model_family_keywords=["llama", "meta-llama"]
... )
>>>
>>> # Now you can use it with auto-detection
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> history = History.from_chats([[
...     {"role": "user", "content": "Hello"},
...     {"role": "assistant", "content": "Hi there!"}
... ]])
>>>
>>> # Auto-detection will use the llama template
>>> result = history.apply_chat_template(
...     tokenizer=tokenizer,
...     add_generation_prompt=False,
...     return_dict=True,
...     return_assistant_tokens_mask=True,
... )
>>>
>>> # Or use it explicitly
>>> result = history.apply_chat_template(
...     tokenizer=tokenizer,
...     chat_template_name="llama",
...     add_generation_prompt=False,
...     return_dict=True,
...     return_assistant_tokens_mask=True,
... )
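>>>
>>> # The registered inverse parser can be exercised through History.from_text()
>>> # by passing the same template name (see the chat_template_name parameter above).
>>> # The mask lookup below is a sketch: the "assistant_masks" key is assumed here,
>>> # mirroring the key returned by transformers' apply_chat_template, and may
>>> # differ depending on your tokenizer output.
>>> formatted = "<s>[INST] Hello [/INST] Hi there!</s>"
>>> parsed = History.from_text(formatted, chat_template_name="llama")
>>>
>>> assistant_mask = result["assistant_masks"]
>>> num_assistant_tokens = assistant_mask.sum()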
