.. _llama2_flashinfer_rmsnorm:

Automatically generate a TensorRT plugin for the RMSNorm module and apply it in Llama2

This example showcases how to optimize inference for a LLaMA2 model by replacing its RMSNorm layers with FlashInfer’s high-performance implementation. It demonstrates the use of Torch-TensorRT’s automatic plugin feature, which dynamically generates and integrates custom TensorRT plugins during compilation.

Key features:

- Leverages automatic plugin registration for FlashInfer RMSNorm ops.
- Applies a custom TorchDynamo lowering pass to replace standard RMSNorm ops.
- Compiles the modified model using Torch-TensorRT’s Dynamo path.
- Benchmarks inference performance with and without FlashInfer RMSNorm.

This example illustrates advanced extensibility in Torch-TensorRT through automatic plugin generation and operator lowering customization.

[ ]:
from typing import Callable, Optional, Sequence, Union

import flashinfer
import torch
import torch_tensorrt
from torch.fx.passes.shape_prop import TensorMetadata
from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
    _aten_lowering_pass,
)
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
)
from transformers import LlamaConfig, LlamaForCausalLM


@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=())  # type: ignore[misc]
def flashinfer_rmsnorm(
    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    """Expose FlashInfer's RMSNorm kernel as the ``flashinfer::rmsnorm`` custom op.

    Args:
        input: Tensor to normalize. The lowering pass in this file calls the op
            with a 2-D ``(batch * seq, hidden)`` tensor.
        weight: RMSNorm scale tensor.
        eps: Epsilon used in the normalization. Defaults to ``1e-6``.

    Returns:
        The tensor produced by ``flashinfer.norm.rmsnorm``.
    """
    # Forward eps explicitly: the lowering pass passes the model's own epsilon,
    # which would otherwise be silently replaced by FlashInfer's default.
    return flashinfer.norm.rmsnorm(input, weight, eps=eps)


@torch.library.register_fake("flashinfer::rmsnorm")
def _(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Fake (meta) kernel for ``flashinfer::rmsnorm``.

    Only describes the output: a tensor with the same shape, dtype and device
    as ``input``. No normalization is computed here.

    The parameter is named ``eps`` to match the real op's signature, so a call
    that passes the epsilon by keyword binds correctly.
    """
    # Return a fresh tensor rather than ``input`` itself: the real op does not
    # alias its input, and the fake kernel should describe the same contract.
    return torch.empty_like(input)


# Hand "flashinfer::rmsnorm" to Torch-TensorRT's automatic plugin feature so
# that a TensorRT plugin is generated for the op during compilation. Dynamic
# input shapes are enabled for it via ``supports_dynamic_shapes=True``.
torch_tensorrt.dynamo.conversion.plugins.custom_op(
    "flashinfer::rmsnorm", supports_dynamic_shapes=True
)


def _sole_user(
    node: torch.fx.Node, target: object = None
) -> Optional[torch.fx.Node]:
    """Return the single consumer of ``node``, or ``None``.

    ``None`` is returned when ``node`` has zero or several users, or when
    ``target`` is given and the single user's target differs from it.
    """
    if len(node.users) != 1:
        return None
    (user,) = node.users
    if target is not None and user.target != target:
        return None
    return user


def _match_rmsnorm(node: torch.fx.Node) -> Optional[tuple]:
    """Match the decomposed RMSNorm subgraph rooted at ``node``.

    The expected chain is::

        x32 = _to_copy(x, dtype=float32)           # ``node``
        pow -> mean -> add(eps) -> sqrt -> div     # consumes x32
        mul(x32, div) -> _to_copy -> mul(weight, .)

    Returns:
        ``None`` if the chain does not match. Otherwise a tuple
        ``(weight_mul_node, weight, eps, tensor_meta, stale_nodes)`` where
        ``stale_nodes`` lists the matched nodes consumers-first, i.e. in an
        order in which they can be erased.
    """
    aten = torch.ops.aten
    if (
        node.target != aten._to_copy.default
        or node.kwargs.get("dtype") is not torch.float32
        or len(node.users) != 2
    ):
        return None

    # ``users`` is an insertion-ordered dict: the pow is expected first, the
    # final normalisation multiply second.
    pow_node, mul_node = node.users
    if (
        pow_node.target != aten.pow.Tensor_Scalar
        or mul_node.target != aten.mul.Tensor
    ):
        return None

    # Walk the single-consumer chain pow -> mean -> add -> sqrt -> div.
    chain = [pow_node]
    for expected in (
        aten.mean.dim,
        aten.add.Tensor,
        aten.sqrt.default,
        aten.div.Tensor,
    ):
        successor = _sole_user(chain[-1], expected)
        if successor is None:
            return None
        chain.append(successor)
    _, mean_node, add_node, sqrt_node, div_node = chain

    # The reciprocal must feed only the multiply that also consumes ``node``.
    if _sole_user(div_node) is not mul_node:
        return None

    copy_node = _sole_user(mul_node, aten._to_copy.default)
    if copy_node is None:
        return None
    weight_mul_node = _sole_user(copy_node, aten.mul.Tensor)
    if weight_mul_node is None:
        return None

    # The weight is whichever operand of the last multiply is not the cast.
    weights = [arg for arg in weight_mul_node.args if arg is not copy_node]
    if len(weights) != 1 or not isinstance(weights[0], torch.fx.Node):
        return None

    # Shape metadata is needed to describe the flattened tensor. Without it
    # (or for a rank other than 3) the rewrite cannot be built.
    tensor_meta = weight_mul_node.meta.get("tensor_meta")
    if tensor_meta is None or len(tensor_meta.shape) != 3:
        return None

    stale_nodes = [
        weight_mul_node,
        copy_node,
        mul_node,
        div_node,
        sqrt_node,
        add_node,
        mean_node,
        pow_node,
        node,
    ]
    return weight_mul_node, weights[0], add_node.args[1], tensor_meta, stale_nodes


def _insert_call(
    graph: torch.fx.Graph, anchor: torch.fx.Node, target: object, args: tuple
) -> torch.fx.Node:
    """Create a ``call_function`` node placed immediately after ``anchor``."""
    with graph.inserting_after(anchor):
        return graph.create_node(op="call_function", target=target, args=args)


def _rewrite_rmsnorm(
    graph: torch.fx.Graph, node: torch.fx.Node, match: tuple
) -> None:
    """Replace one matched RMSNorm subgraph with ``flashinfer::rmsnorm``.

    The 3-D input is flattened to ``(b * s, d)``, normalized by the FlashInfer
    op, reshaped back to ``(b, s, d)``, and the matched nodes are erased.
    """
    aten = torch.ops.aten
    weight_mul_node, weight, eps, original_meta, stale_nodes = match
    source = node.args[0]
    memory_format = original_meta.memory_format

    # Symbolic sizes of the three input dimensions. Each node is inserted
    # directly after ``weight_mul_node``, so graph order ends up d, s, b.
    b = _insert_call(graph, weight_mul_node, aten.sym_size.int, (source, 0))
    b.meta["tensor_meta"] = TensorMetadata(
        shape=torch.Size([1]),
        dtype=torch.int64,
        requires_grad=False,
        stride=None,
        memory_format=memory_format,
        is_quantized=False,
        qparams={},
    )
    s = _insert_call(graph, weight_mul_node, aten.sym_size.int, (source, 1))
    s.meta.update(b.meta)
    d = _insert_call(graph, weight_mul_node, aten.sym_size.int, (source, 2))
    d.meta.update(b.meta)

    # From here on every node follows the previous one, after b/s/d.
    new_first_dim = _insert_call(graph, b, aten.mul.Scalar, (b, s))
    new_first_dim.meta.update(b.meta)

    reshape_node = _insert_call(
        graph,
        new_first_dim,
        aten.reshape.default,
        (source, [new_first_dim, d]),
    )
    b_val = original_meta.shape[0]
    s_val = original_meta.shape[1]
    d_val = original_meta.shape[2]
    reshape_node.meta["tensor_meta"] = TensorMetadata(
        shape=torch.Size([b_val * s_val, d_val]),
        dtype=original_meta.dtype,
        requires_grad=True,
        stride=None,
        memory_format=memory_format,
        is_quantized=False,
        qparams={},
    )

    flashinfer_rmsnorm_node = _insert_call(
        graph,
        reshape_node,
        torch.ops.flashinfer.rmsnorm.default,
        (reshape_node, weight, eps),
    )
    flashinfer_rmsnorm_node.meta.update(reshape_node.meta)

    reshape_back_node = _insert_call(
        graph,
        flashinfer_rmsnorm_node,
        aten.reshape.default,
        (flashinfer_rmsnorm_node, [b, s, d]),
    )

    weight_mul_node.replace_all_uses_with(reshape_back_node)
    reshape_back_node.meta.update(weight_mul_node.meta)

    # Consumers first, so no node is erased while it still has users.
    for stale in stale_nodes:
        graph.erase_node(stale)


@_aten_lowering_pass
def replace_rmsnorm(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    """Lowering pass that swaps decomposed RMSNorm subgraphs for FlashInfer.

    Every ``float32 cast -> pow -> mean -> add -> sqrt -> div -> mul -> cast ->
    weight mul`` chain found in ``gm`` is replaced by a reshape to 2-D, one
    ``flashinfer::rmsnorm`` call, and a reshape back to 3-D.

    Args:
        gm: Graph module to rewrite in place.
        sample_inputs: Accepted for the lowering-pass signature; not used here.

    Returns:
        The graph module, cleaned up with
        ``clean_up_graph_after_modifications`` if at least one subgraph was
        replaced, otherwise unchanged.
    """
    # Must be initialised up front: a graph without any RMSNorm pattern would
    # otherwise hit an unbound local at the check below.
    modified_graph = False

    # Iterate over a snapshot because the rewrite inserts and erases nodes.
    # Erased nodes have no users left and therefore never match again.
    for node in list(gm.graph.nodes):
        match = _match_rmsnorm(node)
        if match is None:
            continue
        _rewrite_rmsnorm(gm.graph, node, match)
        modified_graph = True

    if modified_graph:
        gm = clean_up_graph_after_modifications(gm)

    return gm


# Single device shared by the model, the example inputs and the TensorRT
# engine. Defined first so that everything below is created on it.
DEVICE = torch.device("cuda:0")

# 1. Create a custom config with 1 layer
config = LlamaConfig(
    vocab_size=32000,
    hidden_size=4096,  # LLaMA2-7B dimensions
    # LLaMA2-7B FFN width: 2/3 * 4 * 4096, rounded up to a multiple of 256
    intermediate_size=11008,
    num_hidden_layers=1,  # Only 1 decoder layer
    num_attention_heads=32,
    max_position_embeddings=4096,
    use_cache=False,  # Disable KV caching for export
)

# 2. Initialize model (random weights)
with torch.no_grad():
    model = LlamaForCausalLM(config).to(DEVICE).half().eval()

# 3. Export with static shapes
# The token ids must live on the same device as the model weights; CPU ids
# against a CUDA model fail with a device-mismatch error.
input_ids = torch.randint(
    0, config.vocab_size, (1, 64), device=DEVICE
)  # Static [batch=1, seq=64]
exported = torch.export.export(
    model,
    (input_ids,),
    dynamic_shapes=None,  # Fully static
)

# Test forward pass (eager PyTorch). Inference only, so no autograd graph.
input_ids = torch.randint(0, config.vocab_size, (1, 64), device=DEVICE)
with torch.no_grad():
    output = model(input_ids)
print(output)

# 4. Compile the exported program with Torch-TensorRT
with torch_tensorrt.logging.errors():
    trt_model = torch_tensorrt.dynamo.compile(
        exported,
        inputs=[input_ids],
        enabled_precisions={torch.float32, torch.float16},
        truncate_double=True,
        device=DEVICE,
        disable_tf32=True,
        use_explicit_typing=False,
        use_fp32_acc=True,
    )

# 5. Run the compiled module on the same inputs (already on DEVICE)
with torch.no_grad():
    res = trt_model.forward(input_ids)
print(res)