• Docs >
  • Compiling LLM models from Huggingface
Shortcuts

Compiling LLM models from Huggingface

This tutorial walks you through how to compile LLM models from Huggingface using Torch-TensorRT. We also introduce KV caching in Torch-TensorRT which can greatly improve the performance of LLM inference. The code is available in the tools/llm directory. We use the run_llm.py script to compile the model, generate outputs, and measure the performance.

Note

This is an experimental release and APIs may change in future versions.

Note

The compilation scripts and tutorials for Llama-2-7b-chat-hf and gpt2 models have been consolidated into the unified run_llm.py script located in the tools/llm directory.

Overview of tools/llm Directory

The tools/llm directory provides the following tools to compile LLM models from Huggingface:

  • run_llm.py: Main entry point for model compilation, generating outputs, and benchmarking

  • Static Cache Utilities: static_cache_v1.py and static_cache_v2.py for KV cache optimization

  • SDPA Attention: sdpa_converter.py and register_sdpa.py for registering scaled dot-product attention converter and lowering pass.

  • Testing Components: Model-specific test files for validation

  • Utility Functions: utils.py and cache_utils.py for common operations

Supported Models

We have officially verified support for the following LLM families:

Model Series

HuggingFace Model Card

Precision

KV Cache Support ?

GPT-2

gpt2

FP16, FP32

Yes

LLaMA 2

meta-llama/Llama-2-7b-chat-hf

FP16, FP32

Yes

LLaMA 3.1

meta-llama/Llama-3.1-8B-Instruct

FP16, FP32

Yes

LLaMA 3.2

meta-llama/Llama-3.2-1B-Instruct
meta-llama/Llama-3.2-3B-Instruct

FP16, FP32

Yes

Qwen 2.5

Qwen/Qwen2.5-0.5B-Instruct
Qwen/Qwen2.5-1.5B-Instruct
Qwen/Qwen2.5-3B-Instruct
Qwen/Qwen2.5-7B-Instruct

FP16, FP32

Yes

Getting Started with run_llm.py

The main entry point is run_llm.py, which provides a complete workflow for model compilation and benchmarking.

Basic Usage

python tools/llm/run_llm.py \
  --model meta-llama/Llama-3.2-1B-Instruct \
  --prompt "What is parallel programming?" \
  --precision FP16 \
  --num_tokens 128 \
  --cache static_v2 \
  --benchmark

Key Arguments

  • --model: Name or path of the HuggingFace LLM

  • --tokenizer: (Optional) Tokenizer name; defaults to model name

  • --prompt: Input prompt for text generation

  • --precision: Precision mode (FP16, FP32)

  • --num_tokens: Number of output tokens to generate

  • --cache: KV cache type (static_v1, static_v2, or empty for no KV caching)

  • --benchmark: Enable benchmarking mode for performance comparison

  • --enable_pytorch_run: Also run and compare PyTorch baseline

Other Usage Examples

# Compare different models performance
python tools/llm/run_llm.py --model gpt2 --benchmark --enable_pytorch_run
python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --benchmark --enable_pytorch_run

# Generate the outputs (disable benchmarking) by specifying the number of tokens to generate. Default = 128
python tools/llm/run_llm.py --model gpt2 --prompt "What is parallel programming?" --num_tokens 128
python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is parallel programming?" --num_tokens 128

# Test different caching approaches
python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --cache static_v1
python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --cache static_v2

# Compare FP16 vs FP32 performance
python tools/llm/run_llm.py --model Qwen/Qwen2.5-1.5B-Instruct --precision FP16 --benchmark
python tools/llm/run_llm.py --model Qwen/Qwen2.5-1.5B-Instruct --precision FP32 --benchmark

KV Caching in Torch-TensorRT

We provide two versions of static KV caching: static_cache_v1 and static_cache_v2. In both implementations, we add static KV cache tensors as model inputs/outputs without storing them as external memory. The length of KV cache = input sequence length + output sequence length (specified by --num_tokens). The number of heads and head dimension are determined by the model config.

Static Cache v1

The static_cache_v1.py implements KV cache in the model graph as follows:

class StaticCacheV1Model(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True):
        # Concatenate new key/value pairs with existing cache
        new_key_cache = torch.cat((key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2)
        new_value_cache = torch.cat((value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2)

        # Compute attention using the updated cache
        attn_output = torch._C._nn.scaled_dot_product_attention(
            q,
            new_key_cache[:, :, :end_idx, :],
            new_value_cache[:, :, :end_idx, :],
            dropout_p=0.0,
            is_causal=is_causal
        )

        return attn_output, new_key_cache, new_value_cache

In the above code, we concatenate the new key/value pairs with the existing cache and update it. To compute the attention, we use the updated cache and gather the corresponding keys/values from the cache up until and including the current token index. The above code is actually implemented as a FX graph transformation pass. We register it as a Torch-TensorRT lowering pass using the decorator @_aten_lowering_pass when we import the static_cache_v1.py module.

Note

The start_idx and end_idx are the start and end indices of the current token in the cache. For prefill phase, start_idx is 0 and end_idx is the input sequence length. For decode phase, start_idx begins at the input sequence length and end_idx equals start_idx + 1. The start_idx is incremented by 1 until the end of the sequence or we reach the maximum number of tokens to generate.

Static Cache v2

The static_cache_v2.py is similar to static_cache_v1.py but it uses less number of slice operations. It implements KV cache in the model graph as follows:

class StaticCacheV2Model(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True):
        concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)
        concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2)
        new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2)
        new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2)
        attn_output = torch._C._nn.scaled_dot_product_attention(
              q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal
        )

        return attn_output, new_key_cache, new_value_cache

In the above code, we concatenate the existing key/value cache with current key/value of the token. We use this to directly compute the attention and update the key/value cache inserting the current key/value. The above code is actually implemented as a FX graph transformation pass. We register it as a Torch-TensorRT lowering pass using the decorator @_aten_lowering_pass when we import the static_cache_v1.py module. The definitons of start_idx and end_idx are the same as static_cache_v1.py.

After the model is compiled with static KV cache, the input signature of the model is changed. The new input signature is (input_ids, position_ids, key_cache_0, value_cache_0, ..., start_idx, end_idx). The number of key/value cache tensors is equal to the number of attention heads in the model. We can use the generate_with_static_cache function to generate the outputs.

Generating Outputs

We use custom generate function to generate the outputs. This function performs standard autoregressive decoding without KV caching. There is also a generate_with_static_cache function that performs autoregressive decoding with KV caching.

The generate_with_static_cache function takes care of preparing the inputs to the model compiled with static KV cache. The model inputs are input_ids, position_ids, key_cache_0, value_cache_0, …., start_idx, end_idx. We initialize the key/value cache tensors with zeros and for every token generated, the new key/value cache tensors are the outputs of the model.

SDPA Converter (sdpa_converter.py)

  • Converts scaled dot-product attention operation using TRT Python API.

  • Supports causal and standard self-attention.

SDPA Registration (register_sdpa.py)

  • This is a Torch-TensorRT lowering pass that replaces variants of SDPA with torch.nn.functional.scaled_dot_product_attention.

  • Registers the SDPA converter which is used for converting torch.nn.functional.scaled_dot_product_attention operation.

Limitations and Known Issues

  • Sliding window attention (used in Gemma3 and Qwen 3 models) is not yet supported

  • Some model architectures (e.g. Phi-4) have issues with exporting the torch model.

Requirements

  • Torch-TensorRT 2.8.0 or later

  • Transformers v4.52.3

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources