Multiple Optimization Profiles (Prefill / Decode)

TensorRT tunes an engine's kernels per optimization profile, where a profile is a [min, opt, max] range for every dynamic input dimension. Kernels are tuned at the opt point, so a single profile can be optimal for only one shape.

Many models, however, run in several distinct shape regimes that share the same weights. The canonical case is an autoregressive LLM:

  • prefill – the prompt is processed in one shot, so seq is large, and

  • decode – tokens are generated one at a time, so seq == 1.

With a single dynamic range seq in [1, max] you must pick one opt. Tuning for the long prefill length leaves decode – the latency-critical, most frequently executed phase – running on kernels chosen for a sequence length it never sees.

Torch-TensorRT lets you declare multiple optimization profiles on a single input and select the active one at runtime. The engine is built once and each profile is tuned independently.

Declaring profiles

Pass an ordered list of {"min_shape", "opt_shape", "max_shape"} dicts to torch_tensorrt.Input through its profiles argument. A profile's position in the list is the optimization-profile index you select at runtime.

import torch
import torch_tensorrt

DECODE_IDX, PREFILL_IDX = 0, 1

profiled_input = torch_tensorrt.Input(
    dtype=torch.int64,
    profiles=[
        # index 0 -> decode: seq pinned to 1 (a fully static profile)
        {"min_shape": (1, 1), "opt_shape": (1, 1), "max_shape": (1, 1)},
        # index 1 -> prefill: seq in [1, 512], tuned at 256
        {"min_shape": (1, 1), "opt_shape": (1, 256), "max_shape": (1, 512)},
    ],
)

When using profiles, do not also pass the single-range min_shape / opt_shape / max_shape arguments, or the static shape argument, for the same input. profiles already contains the shape ranges for that input.
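Before you hand a profile list to Input, it can help to sanity-check it. The sketch below uses a hypothetical helper, check_profiles, which is not part of Torch-TensorRT. It enforces two constraints: TensorRT's own requirement that every profile satisfies min <= opt <= max in each dimension, and the requirement that all profiles for an input share the same rank.

```python
from typing import Dict, List, Sequence

Profile = Dict[str, Sequence[int]]
REQUIRED_KEYS = ("min_shape", "opt_shape", "max_shape")


def check_profiles(profiles: List[Profile]) -> None:
    """Raise ValueError if any profile is malformed (hypothetical helper)."""
    if not profiles:
        raise ValueError("profiles must be a non-empty list")
    # Every profile must carry all three shape keys.
    for idx, prof in enumerate(profiles):
        missing = [k for k in REQUIRED_KEYS if k not in prof]
        if missing:
            raise ValueError(f"profile {idx} is missing {missing}")
    rank = len(profiles[0]["min_shape"])
    for idx, prof in enumerate(profiles):
        lo, opt, hi = (tuple(prof[k]) for k in REQUIRED_KEYS)
        # All shapes of all profiles must have the same number of dimensions.
        if not (len(lo) == len(opt) == len(hi) == rank):
            raise ValueError(f"profile {idx} has an inconsistent rank")
        # Per dimension, TensorRT requires min <= opt <= max.
        for dim, (a, b, c) in enumerate(zip(lo, opt, hi)):
            if not (a <= b <= c):
                raise ValueError(
                    f"profile {idx}, dim {dim}: need min <= opt <= max, got {a}, {b}, {c}"
                )


profiles = [
    {"min_shape": (1, 1), "opt_shape": (1, 1), "max_shape": (1, 1)},      # decode
    {"min_shape": (1, 1), "opt_shape": (1, 256), "max_shape": (1, 512)},  # prefill
]
check_profiles(profiles)  # passes silently
```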

The union envelope

torch.export accepts only one range per dynamic dimension, so Input automatically derives the union envelope of all profiles. The envelope is the elementwise minimum of every min_shape and the elementwise maximum of every max_shape, with opt_shape taken from the first profile. Every declared profile lies inside this envelope. You export over the envelope, and the individual profiles become the per-profile TensorRT tunings:

print(profiled_input.shape["min_shape"])  # (1, 1)
print(profiled_input.shape["max_shape"])  # (1, 512)
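Input computes the envelope for you. The sketch below reproduces the rule described above in plain Python so you can predict min_shape, opt_shape and max_shape before compiling. The union_envelope helper is a hypothetical name used only for illustration.

```python
from typing import Dict, List, Sequence, Tuple

Shape = Tuple[int, ...]


def union_envelope(profiles: List[Dict[str, Sequence[int]]]) -> Dict[str, Shape]:
    """Elementwise min of min_shapes, elementwise max of max_shapes, opt from profile 0."""
    mins = [tuple(p["min_shape"]) for p in profiles]
    maxs = [tuple(p["max_shape"]) for p in profiles]
    return {
        # zip(*mins) groups the i-th dimension of every profile together.
        "min_shape": tuple(min(dims) for dims in zip(*mins)),
        "opt_shape": tuple(profiles[0]["opt_shape"]),
        "max_shape": tuple(max(dims) for dims in zip(*maxs)),
    }


profiles = [
    {"min_shape": (1, 1), "opt_shape": (1, 1), "max_shape": (1, 1)},      # decode
    {"min_shape": (1, 1), "opt_shape": (1, 256), "max_shape": (1, 512)},  # prefill
]
env = union_envelope(profiles)
print(env["min_shape"])  # (1, 1)
print(env["max_shape"])  # (1, 512)
```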

Compile

Export once over the union range, then compile as usual. Every input that declares profiles must declare the same number of profiles; static inputs (or dynamic inputs without profiles) reuse their single shape in every profile.

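# Export once over the union envelope: seq in [1, 512] matches profiled_input.
# `model` and `example_ids` (an int64 tensor of shape (1, seq)) are assumed to exist.
seq = torch.export.Dim("seq", min=1, max=512)
exported = torch.export.export(model, (example_ids,), dynamic_shapes=({1: seq},))

trt_model = torch_tensorrt.dynamo.compile(
    exported,
    arg_inputs=[profiled_input],  # the multi-profile Input declared above
    enabled_precisions={torch.float16},
    min_block_size=1,
)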
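The rule for multiple inputs works like this: every profiled input must declare the same number of profiles N, and every other input repeats its single shape N times. The sketch below shows that expansion in plain Python. align_profiles and the static "scale" input are hypothetical and only illustrate the rule. They are not Torch-TensorRT internals.

```python
from typing import Dict, List, Sequence, Union

Profile = Dict[str, Sequence[int]]
# A list means one profile per entry, a dict is a single range, a tuple is a static shape.
InputSpec = Union[List[Profile], Profile, tuple]


def _as_profile(spec: Union[Profile, tuple]) -> Profile:
    """Turn a single range or a static shape into one profile dict."""
    if isinstance(spec, dict):
        return spec
    shape = tuple(spec)
    return {"min_shape": shape, "opt_shape": shape, "max_shape": shape}


def align_profiles(inputs: Dict[str, InputSpec]) -> List[Dict[str, Profile]]:
    """Expand per-input specs into one {input_name: profile} dict per profile index."""
    counts = {name: len(spec) for name, spec in inputs.items() if isinstance(spec, list)}
    if len(set(counts.values())) > 1:
        raise ValueError(f"inputs declare different numbers of profiles: {counts}")
    n = next(iter(counts.values()), 1)  # no profiled inputs -> a single profile
    return [
        {name: spec[i] if isinstance(spec, list) else _as_profile(spec)
         for name, spec in inputs.items()}
        for i in range(n)
    ]


decode = {"min_shape": (1, 1), "opt_shape": (1, 1), "max_shape": (1, 1)}
prefill = {"min_shape": (1, 1), "opt_shape": (1, 256), "max_shape": (1, 512)}
PROFILES = [decode, prefill]

aligned = align_profiles({
    "input_ids": PROFILES,
    "position_ids": PROFILES,
    "scale": (1,),  # hypothetical static input, reused in every profile
})
print(len(aligned))                      # 2
print(aligned[1]["scale"]["max_shape"])  # (1,)
```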

Selecting a profile at runtime

Selection is manual by default. Use the torch_tensorrt.runtime.optimization_profile() context manager to pin a profile by index for the duration of a with block; the prior state is saved on enter and restored on exit, so blocks nest cleanly.

from torch_tensorrt.runtime import optimization_profile

with optimization_profile(trt_model, DECODE_IDX):
    logits = trt_model(decode_ids)      # seq == 1

with optimization_profile(trt_model, PREFILL_IDX):
    logits = trt_model(prefill_ids)     # seq == 256
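Nesting is safe because each block saves the prior state when it enters and restores it when it exits. The sketch below imitates that pattern with contextlib. FakeProfiledModule, active_profile and pinned_profile are hypothetical names chosen for illustration. The real optimization_profile context manager is not shown here.

```python
from contextlib import contextmanager
from typing import Iterator, Optional, Union


class FakeProfiledModule:
    """Stand-in for a compiled module; it only tracks the active profile."""

    def __init__(self) -> None:
        self.active_profile: Optional[Union[int, str]] = None  # None = default state


@contextmanager
def pinned_profile(module: FakeProfiledModule, profile: Union[int, str]) -> Iterator[None]:
    prior = module.active_profile   # save the state on enter
    module.active_profile = profile
    try:
        yield
    finally:
        module.active_profile = prior  # restore on exit, even if the body raises


m = FakeProfiledModule()
with pinned_profile(m, 1):
    with pinned_profile(m, 0):
        print(m.active_profile)  # 0
    print(m.active_profile)      # 1
print(m.active_profile)          # None
```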

Pass "auto" to let Torch-TensorRT choose a profile from the input shapes. Auto-selection is first-match: it scans profiles in index order and uses the first one whose [min, max] range contains the input shapes. Order therefore matters when profiles overlap. Declaring decode first lets it win the seq == 1 overlap:

with optimization_profile(trt_model, "auto"):
    trt_model(decode_ids)    # seq == 1   -> index 0 (decode) accepts -> decode
    trt_model(prefill_ids)   # seq == 256 -> index 0 rejects -> index 1 (prefill)
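The first-match rule is easy to reason about in isolation. The sketch below is a plain-Python model of it. The names fits and select_first_fit are hypothetical, and the real selection logic inside Torch-TensorRT may be structured differently. A profile here maps each input name to its range, and the first profile that contains every input shape wins.

```python
from typing import Dict, List, Sequence

Profile = Dict[str, Sequence[int]]  # {"min_shape", "opt_shape", "max_shape"}


def fits(shape: Sequence[int], profile: Profile) -> bool:
    """True if every dimension of `shape` lies within [min, max] of `profile`."""
    lo, hi = profile["min_shape"], profile["max_shape"]
    return len(shape) == len(lo) and all(a <= s <= b for s, a, b in zip(shape, lo, hi))


def select_first_fit(shapes: Dict[str, Sequence[int]],
                     profiles: List[Dict[str, Profile]]) -> int:
    """Return the first profile index whose ranges contain all input shapes."""
    for idx, per_input in enumerate(profiles):
        if all(fits(shapes[name], per_input[name]) for name in shapes):
            return idx
    raise ValueError(f"no profile accepts shapes {shapes}")


decode = {"min_shape": (1, 1), "opt_shape": (1, 1), "max_shape": (1, 1)}
prefill = {"min_shape": (1, 1), "opt_shape": (1, 256), "max_shape": (1, 512)}
profiles = [{"input_ids": decode}, {"input_ids": prefill}]

print(select_first_fit({"input_ids": (1, 1)}, profiles))    # 0: decode wins the overlap
print(select_first_fit({"input_ids": (1, 256)}, profiles))  # 1: only prefill fits
# With the order reversed, prefill (now index 0) also captures seq == 1.
print(select_first_fit({"input_ids": (1, 1)}, profiles[::-1]))  # 0
```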

Profiles, graph breaks, and serialization

  • Graph breaks: when a model is partitioned into several TensorRT engines, every engine carries the same number of profiles. Torch-TensorRT propagates the per-profile bounds across the break, evaluating any derived dynamic dimension (e.g. a reshape that turns seq into 16 * seq) through to the downstream engine, so runtime selection stays consistent for the whole module.

  • Serialization / runtimes: profile state is reconstructed from the TensorRT API on load (getNbOptimizationProfiles / getProfileShape), so a serialized engine keeps its profiles with no extra metadata. The same optimization_profile API drives both the C++ and Python runtimes, which remain interchangeable.
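Consider the graph-break case, where a derived dimension such as 16 * seq feeds a downstream engine. The sketch below shows the simplest way per-profile bounds can be carried across the break: evaluate the size expression at each profile's min, opt and max. This is a hypothetical illustration, not the Torch-TensorRT implementation. Evaluating only at the endpoints is valid only for expressions that never decrease as seq grows.

```python
from typing import Callable, List, Tuple

Range = Tuple[int, int, int]  # (min, opt, max) for one dynamic dimension


def propagate(ranges: List[Range], expr: Callable[[int], int]) -> List[Range]:
    """Map each profile's (min, opt, max) through a non-decreasing size expression."""
    return [(expr(lo), expr(opt), expr(hi)) for lo, opt, hi in ranges]


# seq ranges for the decode and prefill profiles declared earlier.
seq_ranges = [(1, 1, 1), (1, 256, 512)]
print(propagate(seq_ranges, lambda s: 16 * s))
# [(16, 16, 16), (16, 4096, 8192)]
```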

Why it helps: a worked latency example

The example multi_optimization_profiles compiles google/gemma-3-1b-it twice – once with a single profile (tuned at the prefill length) and once with separate decode/prefill profiles – then compares per-call latency. The multi-profile engine dedicates a static profile (seq pinned to 1) to decode, letting TensorRT specialize that path (measured on an NVIDIA A40, FP16):

Per-call latency (ms), batch=1
regime                single-profile   multi-profile   speedup
--------------------------------------------------------------
decode (seq=1)                 5.232           4.597     1.14x
prefill (seq=128)              7.152           7.534     0.95x

Prefill latency is roughly unchanged (within about 5%) because both engines tune it at the same opt shape. Decode, the regime executed once per generated token, is about 14% faster. Exact numbers depend on the model and GPU. The takeaway is that one engine can be tuned well for both regimes instead of compromising on a single opt shape.
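To reproduce a comparison like the table above on your own model, you need a per-call latency measurement and a speedup ratio. The sketch below shows one way to get both. per_call_latency_ms and speedup are hypothetical helpers. Taking the median over iterations is a choice made here, and the multi_optimization_profiles example may time calls differently. For GPU work, pass sync=torch.cuda.synchronize so each sample includes the kernel's completion. Speedup is single-profile latency divided by multi-profile latency, which reproduces the table's speedup column.

```python
import statistics
import time
from typing import Callable


def per_call_latency_ms(fn: Callable[[], object], warmup: int = 10, iters: int = 100,
                        sync: Callable[[], None] = lambda: None) -> float:
    """Median wall-clock latency of fn() in milliseconds."""
    for _ in range(warmup):  # warm up caches, allocators and lazy initialisation
        fn()
    sync()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        sync()  # e.g. torch.cuda.synchronize, so async GPU work is included
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


def speedup(baseline_ms: float, candidate_ms: float) -> float:
    """Values > 1 mean the candidate (multi-profile) engine is faster."""
    return baseline_ms / candidate_ms


print(f"{speedup(5.232, 4.597):.2f}x")  # 1.14x (decode)
print(f"{speedup(7.152, 7.534):.2f}x")  # 0.95x (prefill)
```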

Note

Because the model has two dynamic inputs (input_ids and position_ids), the example passes one profiled Input for each, and both declare the same profiles. The Hugging Face attention path also needs a TensorRT-friendly SDPA lowering (tools/llm/torchtrt_ext/register_sdpa), and gemma-3-1b-it is a gated model that requires Hugging Face authentication.

See also

  • Runnable example: multi_optimization_profiles

  • torch_tensorrt.Input

  • torch_tensorrt.runtime.optimization_profile()