{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\nAutomatically Generate a TensorRT AOT Plugin\n===================================================================\n\nThis example builds on `auto_generate_plugins` by showing the *AOT* (Ahead-of-Time)\nplugin path. Instead of calling back into Python while the TRT engine runs (JIT), the\nTriton kernel is compiled to PTX at *engine build time* and the binary is\nembedded in the TRT engine. This removes the Python callback, and its overhead, from\nthe plugin's inference path.\n\n**The key difference between JIT and AOT:**\n\n* **JIT plugin** (``generate_plugin`` default): TRT holds a Python callback. At runtime\n it converts TRT tensor handles to ``torch.Tensor``, calls your op, and copies the results\n back. Simple, but it adds Python overhead to every inference call.\n\n* **AOT plugin** (this example): while TRT builds the engine, the function registered with\n ``@trtp.aot_impl`` compiles the Triton kernel to PTX/CUBIN and returns the binary plus\n kernel launch parameters. TRT embeds that binary in the engine. At runtime TRT launches\n the kernel directly, with no Python, no tensor conversion, and no copying. AOT is also\n required for serialized engines that must run without a Python environment (e.g. C++\n deployment).\n\n**When to use AOT:**\n\n* Performance-critical inference paths.\n* Engines that must be serialized and loaded in C++.\n* Any case where you pass ``use_aot_if_available=True`` and want a guarantee that\n the AOT path is actually taken.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import argparse\nfrom typing import Tuple, Union\n\nimport tensorrt as trt\nimport tensorrt.plugin as trtp\nimport torch\nimport torch_tensorrt\nimport triton\nimport triton.language as tl\n\ntrt_logger = trt.Logger(trt.Logger.VERBOSE)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Step 1: Define the Triton Kernel\n\nSame as the JIT example: the kernel is pure Triton. 
The difference is how it\ngets compiled: in the JIT path ``add_one_kernel[grid](...)`` is launched from Python at runtime;\nin the AOT path it is compiled to PTX inside the ``@trtp.aot_impl`` function below, which TRT\ncalls while building the engine.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@triton.jit\ndef add_one_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n output = x + 1\n tl.store(y_ptr + offsets, output, mask=mask)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Step 2: Register the PyTorch op\n\nIdentical to the JIT example. The fake kernel (``register_fake``) is still needed:\nTRT uses the shape descriptor from ``@trtp.register`` (below) for shape inference while\nbuilding the network, but Dynamo's tracing and Torch-TensorRT's partitioner still need\nthe fake kernel.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@torch.library.custom_op(\"my::add_one\", mutates_args=()) # type: ignore[misc]\ndef add_one(X: torch.Tensor) -> torch.Tensor:\n assert X.is_cuda\n Y = torch.empty_like(X)\n BLOCK_SIZE = 256\n grid = lambda meta: (triton.cdiv(X.numel(), meta[\"BLOCK_SIZE\"]),)\n add_one_kernel[grid](X, X.numel(), Y, BLOCK_SIZE=BLOCK_SIZE)\n return Y\n\n\n@torch.library.register_fake(\"my::add_one\")\ndef _(X: torch.Tensor) -> torch.Tensor:\n # Only shape, dtype and device matter here; return a new tensor instead of aliasing X.\n return torch.empty_like(X)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Step 3: Register the QDP Shape Descriptor\n\n``@trtp.register`` manually registers the plugin *shape descriptor* in TensorRT's\n``QDP_REGISTRY`` under the key ``\"my::add_one\"``. 
This is different from\n``generate_plugin``, which auto-generates the descriptor from the fake kernel.\n\nThe function receives ``trtp.TensorDesc`` objects describing the input shapes and dtypes,\nand must return the ``trtp.TensorDesc`` of each output (a single descriptor or a tuple of them).\n``X.like()`` means \"same shape and dtype as X\", a convenient shorthand for elementwise ops.\n\nRegistering manually here (instead of calling ``generate_plugin``) is required for\nAOT plugins because we need to associate our own ``@trtp.aot_impl`` with the plugin\nentry. ``generate_plugin`` would create its own JIT impl and close the entry.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@trtp.register(\"my::add_one\")\ndef add_plugin_desc(X: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:\n # Output has the same shape and dtype as the input.\n return X.like()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Step 4: Register the AOT Implementation\n\nThe function decorated with ``@trtp.aot_impl`` is called by TensorRT **while it builds the engine**\n(not at inference time). It must compile the kernel to a binary and return everything TRT needs\nto launch it:\n\n* ``compiled_kernel.metadata.name``: the kernel function name in the PTX/CUBIN.\n* ``compiled_kernel.asm[\"ptx\"]``: the PTX source string (CUBIN bytes also work).\n TRT embeds this binary in the serialized engine.\n* ``launch_params``: grid/block dims and shared memory. These can be symbolic\n (built from shape expressions such as ``X.shape_expr``) so they adapt to the actual input shape.\n* ``extra_args``: additional scalar arguments passed at launch. 
Here ``N`` (number\n of elements) is a ``SymInt32`` that TRT evaluates from the actual input shape at\n runtime.\n\nRegistering the function stores it (not a compiled binary) in\n``QDP_REGISTRY[\"my::add_one\"].aot_impl_func``. When ``generate_plugin_converter`` is later\ncalled with ``use_aot_if_available=True``, it detects ``aot_impl_func is not None`` and sets\n``aot=True`` on the plugin layer, so TRT calls this function during the build and uses the\nresulting binary instead of a Python callback.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@trtp.aot_impl(\"my::add_one\")\ndef add_plugin_aot_impl(\n X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int\n) -> Tuple[\n Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs\n]:\n # Choose the pointer type based on the input dtype.\n type_str = \"fp32\" if X.dtype == trt.float32 else \"fp16\"\n\n block_size = 256\n # Compile the Triton kernel to PTX here; TRT invokes this function at engine build time.\n # ``ASTSource`` describes the kernel's input types and constexprs without\n # running it; Triton then compiles it to architecture-specific PTX/CUBIN.\n src = triton.compiler.ASTSource(\n fn=add_one_kernel,\n signature={\n \"x_ptr\": f\"*{type_str}\",\n \"n_elements\": \"i32\",\n \"y_ptr\": f\"*{type_str}\",\n },\n constexprs={\n \"BLOCK_SIZE\": block_size,\n },\n )\n compiled_kernel = triton.compile(src)\n\n # Build symbolic launch parameters.\n # ``X.shape_expr.numel()`` is a symbolic expression for the total number of\n # elements; TRT evaluates it to a concrete integer at engine runtime.\n N = X.shape_expr.numel()\n launch_params = trtp.KernelLaunchParams()\n launch_params.grid_x = trtp.cdiv(N, block_size) # number of thread blocks\n launch_params.block_x = compiled_kernel.metadata.num_warps * 32 # threads per block\n launch_params.shared_mem = compiled_kernel.metadata.shared # bytes of shared mem\n\n # ``extra_args`` are scalar arguments passed to the kernel at\n # 
launch. Here ``n_elements`` is passed as a 32-bit symbolic integer so TRT\n # evaluates it from the actual tensor size at runtime.\n extra_args = trtp.SymIntExprs(1)\n extra_args[0] = trtp.SymInt32(N)\n\n return (\n compiled_kernel.metadata.name, # kernel function name in PTX\n compiled_kernel.asm[\"ptx\"], # PTX source \u2014 embedded in TRT engine\n launch_params,\n extra_args,\n )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Step 5: Generate the Converter\n\nUnlike the JIT example, we do **not** call ``generate_plugin`` here \u2014 the shape\ndescriptor and AOT impl are already registered manually above.\nWe only need the converter that bridges the Torch op to the TRT network layer.\n\n``generate_plugin_converter`` finds ``\"my::add_one\"`` in ``QDP_REGISTRY``, sees\nthat ``aot_impl_func is not None``, and creates a converter that calls\n``ctx.net.add_plugin(plugin, aot=True)``. The ``aot=True`` flag instructs TRT to\nuse the pre-compiled PTX rather than a Python JIT callback at runtime.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(\n \"my::add_one\",\n supports_dynamic_shapes=False,\n requires_output_allocator=False,\n use_aot_if_available=True,\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Step 6: Compile and Run\n\nCompilation is identical to the JIT example. 
The difference is what happens at\ninference time: TRT launches the pre-compiled PTX kernel directly on the GPU with\nno Python involvement.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class MyModel(torch.nn.Module):\n def __init__(self):\n super().__init__()\n\n def forward(self, X: torch.Tensor) -> torch.Tensor:\n res = torch.ops.my.add_one.default(X)\n return res\n\n\nif __name__ == \"__main__\":\n parser = argparse.ArgumentParser()\n parser.add_argument(\n \"--aot\", action=\"store_true\", help=\"Try to use AOT compilation\", default=False\n )\n args = parser.parse_args()\n\n my_model = MyModel().to(\"cuda\").eval()\n m = torch.full((64, 64), 2, device=\"cuda\", dtype=torch.float)\n\n assert my_model(X=m)[0][0] == 3.0\n\n with torch_tensorrt.logging.debug():\n trt_inputs = [m]\n model_trt = torch_tensorrt.compile(\n my_model,\n inputs=trt_inputs,\n min_block_size=1,\n )\n print(\"Model compiled successfully!\")\n print(\"Running inference with compiled model...\")\n with torch.no_grad():\n for i in range(10):\n res = model_trt(m)\n assert torch.allclose(res, my_model(m)), \"Results do not match!\"\n\n print(\"Inference successful!\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.14" } }, "nbformat": 4, "nbformat_minor": 0 }