{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n.._llama2_flashinfer_rmsnorm:\n\n# Automatically generate a TensorRT Plugin for RMSNorm module and apply it in Llama2\n\nThis example showcases how to optimize inference for a LLaMA2 model by replacing its RMSNorm layers with FlashInfer's high-performance implementation. It demonstrates the use of Torch-TensorRT's automatic plugin feature, which dynamically generates and integrates custom TensorRT plugins during compilation.\n\nKey features:\n- Leverages automatic plugin registration for FlashInfer RMSNorm ops.\n- Applies a custom TorchDynamo lowering pass to replace standard RMSNorm ops.\n- Compiles the modified model using Torch-TensorRT's Dynamo path.\n- Benchmarks inference performance with and without FlashInfer RMSNorm.\n\nThis example illustrates advanced extensibility in Torch-TensorRT through automatic plugin generation and operator lowering customization.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from typing import Callable, Optional, Sequence, Union\n\nimport flashinfer\nimport torch\nimport torch_tensorrt\nfrom torch.fx.passes.shape_prop import TensorMetadata\nfrom torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (\n _aten_lowering_pass,\n)\nfrom torch_tensorrt.dynamo.lowering.passes.pass_utils import (\n clean_up_graph_after_modifications,\n)\nfrom transformers import LlamaConfig, LlamaForCausalLM\n\n\n@torch.library.custom_op(\"flashinfer::rmsnorm\", mutates_args=()) # type: ignore[misc]\ndef flashinfer_rmsnorm(\n input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6\n) -> torch.Tensor:\n return flashinfer.norm.rmsnorm(input, weight)\n\n\n@torch.library.register_fake(\"flashinfer::rmsnorm\")\ndef _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tensor:\n return input\n\n\ntorch_tensorrt.dynamo.conversion.plugins.custom_op(\n \"flashinfer::rmsnorm\", supports_dynamic_shapes=True\n)\n\n\n@_aten_lowering_pass\ndef replace_rmsnorm(\n gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]\n) -> torch.fx.GraphModule:\n for node in gm.graph.nodes:\n if (\n node.target == torch.ops.aten._to_copy.default\n and node.kwargs.get(\"dtype\") is torch.float32\n and len(node.users) == 2\n ):\n if (\n list(node.users)[0].target == torch.ops.aten.pow.Tensor_Scalar\n and list(node.users)[1].target == torch.ops.aten.mul.Tensor\n ):\n pow_node = list(node.users)[0]\n if (\n len(pow_node.users) == 1\n and list(pow_node.users)[0].target == torch.ops.aten.mean.dim\n ):\n mean_node = list(pow_node.users)[0]\n if (\n len(mean_node.users) == 1\n and list(mean_node.users)[0].target == torch.ops.aten.add.Tensor\n ):\n add_node = list(mean_node.users)[0]\n if (\n len(add_node.users) == 1\n and list(add_node.users)[0].target\n == torch.ops.aten.sqrt.default\n ):\n sqrt_node = list(add_node.users)[0]\n if (\n len(sqrt_node.users) == 1\n and list(sqrt_node.users)[0].target\n == torch.ops.aten.div.Tensor\n ):\n div_node = list(sqrt_node.users)[0]\n if list(div_node.users)[0] == list(node.users)[1]:\n mul_node = list(div_node.users)[0]\n copy_node = list(mul_node.users)[0]\n weight_mul_node = list(copy_node.users)[0]\n\n weight = weight_mul_node.args[0]\n\n original_meta = weight_mul_node.meta.get(\n \"tensor_meta\", {}\n )\n memory_format = original_meta.memory_format\n\n with gm.graph.inserting_after(weight_mul_node):\n b = gm.graph.create_node(\n op=\"call_function\",\n target=torch.ops.aten.sym_size.int,\n args=(node.args[0], 0),\n )\n b.meta[\"tensor_meta\"] = TensorMetadata(\n shape=torch.Size([1]),\n dtype=torch.int64,\n requires_grad=False,\n stride=None,\n memory_format=memory_format,\n is_quantized=False,\n qparams={},\n )\n s = gm.graph.create_node(\n op=\"call_function\",\n target=torch.ops.aten.sym_size.int,\n args=(node.args[0], 1),\n )\n s.meta.update(b.meta)\n\n d = gm.graph.create_node(\n op=\"call_function\",\n target=torch.ops.aten.sym_size.int,\n args=(node.args[0], 2),\n )\n d.meta.update(b.meta)\n\n with gm.graph.inserting_after(b):\n new_first_dim = gm.graph.create_node(\n op=\"call_function\",\n target=torch.ops.aten.mul.Scalar,\n args=(b, s),\n )\n new_first_dim.meta.update(b.meta)\n\n with gm.graph.inserting_after(new_first_dim):\n # with gm.graph.inserting_after(weight_mul_node):\n reshape_node = gm.graph.create_node(\n op=\"call_function\",\n target=torch.ops.aten.reshape.default,\n args=(node.args[0], [new_first_dim, d]),\n )\n b_val = original_meta.shape[0]\n s_val = original_meta.shape[1]\n d_val = original_meta.shape[2]\n\n reshape_node.meta[\"tensor_meta\"] = (\n TensorMetadata(\n shape=torch.Size(\n [b_val * s_val, d_val]\n ),\n dtype=original_meta.dtype,\n requires_grad=True,\n stride=None,\n memory_format=memory_format,\n is_quantized=False,\n qparams={},\n )\n )\n\n with gm.graph.inserting_after(reshape_node):\n flashinfer_rmsnorm_node = gm.graph.create_node(\n op=\"call_function\",\n target=torch.ops.flashinfer.rmsnorm.default,\n args=(\n reshape_node,\n weight,\n add_node.args[1],\n ),\n )\n flashinfer_rmsnorm_node.meta.update(\n reshape_node.meta\n )\n\n with gm.graph.inserting_after(\n flashinfer_rmsnorm_node\n ):\n reshapback_node = gm.graph.create_node(\n op=\"call_function\",\n target=torch.ops.aten.reshape.default,\n args=(\n flashinfer_rmsnorm_node,\n [b, s, d],\n ),\n )\n\n weight_mul_node.replace_all_uses_with(\n reshapback_node\n )\n reshapback_node.meta.update(weight_mul_node.meta)\n\n modified_graph = True\n\n gm.graph.erase_node(weight_mul_node)\n gm.graph.erase_node(copy_node)\n gm.graph.erase_node(mul_node)\n gm.graph.erase_node(div_node)\n gm.graph.erase_node(sqrt_node)\n gm.graph.erase_node(add_node)\n gm.graph.erase_node(mean_node)\n gm.graph.erase_node(pow_node)\n gm.graph.erase_node(node)\n\n if modified_graph:\n gm = clean_up_graph_after_modifications(gm)\n\n return gm\n\n\n# 1. Create a custom config with 1 layer\nconfig = LlamaConfig(\n vocab_size=32000,\n hidden_size=4096, # LLaMA2-7B dimensions\n intermediate_size=11008, # FFN hidden_dim = 4 * 4096 * 0.7 (SwiGLU scaling)\n num_hidden_layers=1, # Only 1 decoder layer\n num_attention_heads=32,\n max_position_embeddings=4096,\n use_cache=False, # Disable KV caching for export\n)\n\n# 2. Initialize model (random weights)\nwith torch.no_grad():\n model = LlamaForCausalLM(config).cuda().half().eval()\n\n# 3. Export with static shapes\ninput_ids = torch.randint(0, 32000, (1, 64)) # Static [batch=1, seq=64]\nexported = torch.export.export(\n model,\n (input_ids,),\n dynamic_shapes=None, # Fully static\n)\n\n# Test forward pass\ninput_ids = torch.randint(0, 32000, (1, 64))\noutput = model(input_ids)\nprint(output)\n\n# Export validation\n\nDEVICE = torch.device(\"cuda:0\")\n\nwith torch_tensorrt.logging.errors():\n trt_model = torch_tensorrt.dynamo.compile(\n exported,\n inputs=[input_ids],\n enabled_precisions={torch.float32, torch.float16},\n truncate_double=True,\n device=DEVICE,\n disable_tf32=True,\n use_explicit_typing=False,\n use_fp32_acc=True,\n )\n\ninput_ids = input_ids.to(DEVICE)\n\nwith torch.no_grad():\n res = trt_model.forward(input_ids)\nprint(res)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.14" } }, "nbformat": 4, "nbformat_minor": 0 }