{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n\n# Automatically Generate a Plugin for a Custom Kernel\n\nThis example demonstrates how to register a custom Triton kernel as a TensorRT plugin\nusing the TensorRT 10.7+ Quick Deployable Plugin (QDP) system, and how Torch-TensorRT\nautomatically generates the converter that wires the two together.\n\nWithout a plugin, a custom op would fall back to PyTorch at runtime, causing a graph\nbreak between two TRT subgraphs. The plugin approach runs the custom kernel *inside*\nthe TRT engine, avoiding that overhead entirely.\n\n**What \"automatically generate\" means here:**\n\n``generate_plugin`` uses PyTorch's FakeTensor/symbolic-shape machinery to introspect\nyour op's schema at registration time. It synthesizes:\n\n* A *shape descriptor* function (``_generic_plugin_desc``) that computes output shapes\n from symbolic input dimensions using ``lambdify`` expressions \u2014 this is how TRT knows\n output shapes without running the kernel.\n* A *JIT implementation* function (``_generic_plugin_impl``) that, at TRT engine\n runtime, converts TRT tensors back to PyTorch tensors, calls your op directly on the\n CUDA stream TRT provides, and copies results to the output buffers.\n\nBoth are registered in TensorRT's ``QDP_REGISTRY`` under ``\"torchtrt_ex::elementwise_scale_mul\"``.\n\n``generate_plugin_converter`` then creates and registers a\n``@dynamo_tensorrt_converter`` for ``torch.ops.torchtrt_ex.elementwise_scale_mul.default``\nin Torch-TensorRT's ``DYNAMO_CONVERTERS`` table. When the compiler encounters that op\nin the FX graph it calls this converter, which instantiates the QDP plugin and adds a\nplugin layer to the TRT ``INetworkDefinition``.\n\n**JIT vs AOT:** The plugin generated here is JIT \u2014 at TRT engine runtime, TRT calls\nback into Python to execute the Triton kernel via PyTorch. 
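\n\nAs an aside, the ``lambdify`` step mentioned above can be sketched standalone (a hypothetical illustration, not Torch-TensorRT's actual code; the real expressions come from your meta-kernel's ``SymInt`` shapes):\n\n```python\nimport sympy\n\n# Hypothetical symbolic dims, standing in for the SymInts generate_plugin derives\ns0, s1 = sympy.symbols(\"s0 s1\", positive=True, integer=True)\nshape_expr = s0 * s1  # e.g. a flattened output dimension\n\n# lambdify turns the expression into a plain Python callable that a\n# shape-descriptor function can evaluate with concrete input dims\nshape_fn = sympy.lambdify((s0, s1), shape_expr)\nprint(shape_fn(64, 64))  # 4096\n```\n\n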
For a pre-compiled binary\nthat avoids the Python overhead see the `aot_plugin` example.\n\nSee also `custom_kernel_plugins` for the lower-level\n``IPluginV2DynamicExt`` approach that predates TRT 10.7.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 1: Define the Triton Kernel\n\nThe kernel itself is pure Triton \u2014 no TRT-specific code at this stage.\n``generate_plugin`` will later wrap it in a JIT implementation that TRT\ncan call at runtime.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import tensorrt_bindings.plugin as trtp\nimport torch\nimport torch_tensorrt\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef elementwise_scale_mul_kernel(X, Y, Z, a, b, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(0)\n # Compute the range of elements that this thread block will work on\n block_start = pid * BLOCK_SIZE\n # Range of indices this block will handle; no out-of-bounds mask is applied,\n # so X.numel() must be a multiple of BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n # Load elements from the X and Y tensors\n x_vals = tl.load(X + offsets)\n y_vals = tl.load(Y + offsets)\n # Compute x * y * a + b element-wise\n z_vals = x_vals * y_vals * a + b\n # Store the result in Z\n tl.store(Z + offsets, z_vals)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 2: Register the Op with PyTorch\n\n``@torch.library.custom_op`` registers the kernel as a first-class PyTorch op.\nThis is what lets you call it as ``torch.ops.torchtrt_ex.elementwise_scale_mul``\nin model forward passes and have ``torch.export`` trace through it.\n\n``@torch.library.register_fake`` registers the *meta-kernel* (also called a fake\nkernel or abstract impl). 
This function runs on ``FakeTensor`` objects \u2014 it must\nreturn a tensor of the correct *shape and dtype* without doing any actual compute.\nThree systems depend on it:\n\n* ``torch.export`` / Dynamo \u2014 for tracing shape propagation.\n* ``generate_plugin`` \u2014 it runs your meta-kernel symbolically with ``FakeTensorMode``\n to derive the output-shape expressions it embeds in the QDP shape descriptor.\n* Torch-TensorRT's partitioner \u2014 to decide whether the op can be included in a TRT\n subgraph.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@torch.library.custom_op(\"torchtrt_ex::elementwise_scale_mul\", mutates_args=()) # type: ignore[misc]\ndef elementwise_scale_mul(\n X: torch.Tensor, Y: torch.Tensor, b: float = 0.2, a: int = 2\n) -> torch.Tensor:\n assert X.is_cuda and Y.is_cuda, \"Tensors must be on CUDA device.\"\n assert X.shape == Y.shape, \"Tensors must have the same shape.\"\n\n Z = torch.empty_like(X)\n BLOCK_SIZE = 1024\n # Launch one program per BLOCK_SIZE elements; floor division matches the\n # kernel's no-mask assumption (numel must be divisible by BLOCK_SIZE)\n grid = lambda meta: (X.numel() // meta[\"BLOCK_SIZE\"],)\n elementwise_scale_mul_kernel[grid](X, Y, Z, a, b, BLOCK_SIZE=BLOCK_SIZE)\n return Z\n\n\n@torch.library.register_fake(\"torchtrt_ex::elementwise_scale_mul\")\ndef _(x: torch.Tensor, y: torch.Tensor, b: float = 0.2, a: int = 2) -> torch.Tensor:\n # Elementwise \u2014 output has the same shape and dtype as the first input.\n # Return a fresh fake tensor rather than x itself: the real op allocates a\n # new output, and a fake impl should not alias its inputs.\n return torch.empty_like(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 3: Auto-Generate the TensorRT QDP Plugin\n\n``generate_plugin`` does the following internally:\n\n1. Calls your ``register_fake`` function with ``FakeTensor`` objects carrying\n symbolic ``SymInt`` shapes (via ``ShapeEnv``). This produces symbolic output-shape\n expressions like ``s0 * s1``.\n2. Turns those expressions into Python lambda functions with ``lambdify``, and\n builds a ``_generic_plugin_desc`` that computes TRT ``TensorDesc`` output shapes\n at graph-construction time.\n3. 
Builds a ``_generic_plugin_impl`` that TRT calls at engine *runtime*:\n it converts each TRT tensor handle to a ``torch.Tensor``, runs\n ``torch.ops.torchtrt_ex.elementwise_scale_mul`` on the provided CUDA stream,\n then copies results back to TRT's output buffers.\n4. Registers both under ``\"torchtrt_ex::elementwise_scale_mul\"`` in TensorRT's\n global ``QDP_REGISTRY``.\n\nAfter this call, ``trtp.op.torchtrt_ex.elementwise_scale_mul`` exists and TRT\nknows how to compute output shapes and execute the kernel.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "torch_tensorrt.dynamo.conversion.plugins.generate_plugin(\n \"torchtrt_ex::elementwise_scale_mul\"\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 4: Auto-Generate the Torch-TensorRT Converter\n\n``generate_plugin_converter`` does the following internally:\n\n1. Looks up ``\"torchtrt_ex::elementwise_scale_mul\"`` in ``QDP_REGISTRY`` and checks\n whether an AOT implementation is registered (``desc.aot_impl_func``). Here there\n is none, so it uses the JIT path.\n2. Defines a converter function that, when called during TRT graph construction:\n a. Splits ``args`` into tensor inputs (converted to ``trt.ITensor`` via\n ``get_trt_tensor``) and non-tensor attributes (scalars, passed as plugin attrs).\n b. Instantiates the QDP plugin via ``trtp.op.torchtrt_ex.elementwise_scale_mul(...)``.\n c. Calls ``ctx.net.add_plugin(plugin, aot=False)`` to add a plugin layer to the\n TRT ``INetworkDefinition``.\n3. 
Registers the converter for ``torch.ops.torchtrt_ex.elementwise_scale_mul.default``\n in Torch-TensorRT's ``DYNAMO_CONVERTERS`` table via the\n ``@dynamo_tensorrt_converter`` decorator.\n\nFrom this point, whenever the compiler encounters that op in the FX graph, it will\ncall this converter and emit a plugin layer instead of a PyTorch fallback.\n\n``supports_dynamic_shapes=True`` tells the registry that this converter can handle\nsymbolic batch dimensions. ``requires_output_allocator=False`` means TRT knows the\noutput size at engine-build time (not data-dependent).\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(\n \"torchtrt_ex::elementwise_scale_mul\",\n supports_dynamic_shapes=True,\n requires_output_allocator=False,\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The two calls above can be combined into one:\n\n```python\ntorch_tensorrt.dynamo.conversion.plugins.custom_op(\n \"torchtrt_ex::elementwise_scale_mul\",\n supports_dynamic_shapes=True,\n requires_output_allocator=False,\n)\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 5: Compile and Run\n\nFrom here, compilation is identical to any other Torch-TensorRT model.\n``torch_tensorrt.compile`` will:\n\n* Export the model with ``torch.export``.\n* Partition the FX graph \u2014 the custom op node lands in a TRT subgraph because its\n converter is registered.\n* During TRT graph construction the converter is called, adding a plugin layer.\n* At inference time, TRT calls ``_generic_plugin_impl``, which invokes the Triton\n kernel on TRT's CUDA stream.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class MyModel(torch.nn.Module): # type: ignore[misc]\n def __init__(self):\n super().__init__()\n\n def forward(self, x: torch.Tensor, y: torch.Tensor) -> 
torch.Tensor:\n z = torch.add(x, y)\n res = torch.ops.torchtrt_ex.elementwise_scale_mul.default(x, z, b=0.5)\n return res\n\n\nmy_model = MyModel().to(\"cuda\").eval()\nm = torch.randint(0, 5, (64, 64), device=\"cuda\", dtype=torch.float)\nn = torch.randint(0, 5, (64, 64), device=\"cuda\", dtype=torch.float)\n\nwith torch_tensorrt.logging.errors():\n model_trt = torch_tensorrt.compile(my_model, inputs=[m, n], min_block_size=1)\n with torch.no_grad():\n for i in range(300):\n res = model_trt(m, n)\n assert torch.allclose(res, my_model(m, n))\n\nprint(\"Ran with custom plugin!\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.15" } }, "nbformat": 4, "nbformat_minor": 0 }