{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n\n# Automatically Generate a Converter for a Custom Kernel\n\nWe are going to demonstrate how to automatically generate a converter for a custom kernel with Torch-TensorRT, using\nthe new Python-based plugin system in TensorRT 10.8.\n\nTorch-TensorRT supports falling back to PyTorch implementations of operations in the case that Torch-TensorRT\ndoes not know how to compile them in TensorRT. However, this comes at the cost of a graph break and will reduce the performance of the model.\nThe easiest way to fix lack of support for an op is to add either a decomposition (see:\n[Writing lowering passes for the Dynamo frontend](https://pytorch.org/TensorRT/contributors/writing_dynamo_aten_lowering_passes.html)), which defines the operator\nin terms of PyTorch ops that are supported in Torch-TensorRT, or a converter (see:\n[Writing converters for the Dynamo frontend](https://pytorch.org/TensorRT/contributors/dynamo_converters.html)), which defines the operator in terms of TensorRT operators.\n\nIn some cases there isn't a great way to do either of these, perhaps because the operator is a custom kernel that is not part of standard PyTorch or because\nTensorRT cannot support it natively.\n\nFor these cases, it is possible to use a TensorRT plugin to replace the operator **inside** the TensorRT engine, thereby avoiding\nthe performance and resource overhead from a graph break.\n\nPreviously this involved a complex process of not only building a performant kernel but also setting it up to run in TensorRT (see: [Using Custom Kernels within TensorRT Engines with Torch-TensorRT](https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/custom_kernel_plugins.html)).\nWith TensorRT 10.8, there is a new Python-native plugin system which greatly streamlines this process. 
This\nplugin system also allows Torch-TensorRT to automatically generate the necessary conversion code to convert the\noperation in PyTorch to TensorRT.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Writing Custom Operators in PyTorch\n\nPrevious tutorials already cover creating custom operators in PyTorch which later get used with Torch-TensorRT.\nHere we define a simple elementwise multiplication operator in Triton. This operator is then registered as a custom op in PyTorch\nwith its host launch code as well as a \"meta-kernel\". A meta-kernel is a function that describes the shape and data type\ntransformations that the operator will perform. This meta-kernel is used by Dynamo and Torch-TensorRT, so it\nmust be defined.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from typing import Tuple\n\nimport tensorrt.plugin as trtp\nimport torch\nimport torch_tensorrt\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef elementwise_mul_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr):\n # Program ID determines the block of data each program instance will process\n pid = tl.program_id(0)\n # Compute the start of the range of elements that this program instance will work on\n block_start = pid * BLOCK_SIZE\n # Range of indices this program instance will handle\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n # Load elements from the X and Y tensors\n # (no mask is applied, so the element count must be a multiple of BLOCK_SIZE)\n x_vals = tl.load(X + offsets)\n y_vals = tl.load(Y + offsets)\n # Perform the element-wise multiplication\n z_vals = x_vals * y_vals\n # Store the result in Z\n tl.store(Z + offsets, z_vals)\n\n\n@torch.library.custom_op(\"torchtrt_ex::elementwise_mul\", mutates_args=()) # type: ignore[misc]\ndef elementwise_mul(\n X: torch.Tensor, Y: torch.Tensor, b: float = 0.2, a: int = 2\n) -> torch.Tensor:\n # Ensure the tensors are on the GPU\n assert X.is_cuda and Y.is_cuda, \"Tensors must be on CUDA device.\"\n assert X.shape == Y.shape, \"Tensors must 
have the same shape.\"\n\n # Create output tensor\n Z = torch.empty_like(X)\n\n # Define block size\n BLOCK_SIZE = 1024\n\n # Grid of programs (floor division: assumes numel is a multiple of BLOCK_SIZE)\n grid = lambda meta: (X.numel() // meta[\"BLOCK_SIZE\"],)\n\n # Launch the kernel\n elementwise_mul_kernel[grid](X, Y, Z, BLOCK_SIZE=BLOCK_SIZE)\n\n return Z" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The meta kernel for an elementwise operation simply returns a tensor with the shape and dtype of one of the inputs, since the operation\nchanges neither.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@torch.library.register_fake(\"torchtrt_ex::elementwise_mul\")\ndef _(x: torch.Tensor, y: torch.Tensor, b: float = 0.2, a: int = 2) -> torch.Tensor:\n # Return a new tensor with the same shape and dtype rather than aliasing the input\n return torch.empty_like(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Writing Plugins for TensorRT using the Quick Deploy Plugin system\nThe quick deployment plugin system in TensorRT 10.8 allows for the creation of custom plugins in Python with significantly\nless boilerplate. 
It uses a system similar to PyTorch's, where you define a function that describes the shape and data type transformations\nthat the operator will perform and then define the code to launch the kernel given GPU memory handles.\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Just like the PyTorch meta kernel, there is no transformation in shape or data type between the input and output, so\nwe can just tell TensorRT to expect the same shape as we get in.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@trtp.register(\"torchtrt_ex::elementwise_mul\")\ndef _(\n    x: trtp.TensorDesc, y: trtp.TensorDesc, b: float, a: int\n) -> Tuple[trtp.TensorDesc]:\n    return x.like()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here we reuse host launch code similar to the PyTorch version, but we need to convert the TensorRT tensors into PyTorch tensors prior to launching the kernel.\nThese operations are also in-place, so the result must be put in the output tensors provided by TensorRT.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@trtp.impl(\"torchtrt_ex::elementwise_mul\")\ndef _(\n    x: trtp.Tensor,\n    y: trtp.Tensor,\n    b: float,\n    a: int,\n    outputs: Tuple[trtp.Tensor],\n    stream: int,\n):\n    # Define block size\n    BLOCK_SIZE = 1024\n\n    # Grid of programs\n    grid = lambda meta: (x.numel() // meta[\"BLOCK_SIZE\"],)\n\n    # Wrap the TensorRT-provided buffers as PyTorch tensors\n    x_t = torch.as_tensor(x, device=\"cuda\")\n    y_t = torch.as_tensor(y, device=\"cuda\")\n    z_t = torch.as_tensor(outputs[0], device=\"cuda\")\n    # Launch the kernel, writing the result into the TensorRT output buffer\n    elementwise_mul_kernel[grid](x_t, y_t, z_t, BLOCK_SIZE=BLOCK_SIZE)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generating the Converter\nGiven that we have defined the custom operator in PyTorch and TensorRT, we can now generate the converter for the operation.\nAs long as the namespace and names match, the following function 
will automatically generate the converter for the operation.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(\n    \"torchtrt_ex::elementwise_mul\", supports_dynamic_shapes=True\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using our converter with a model\n\nNow we can use our custom operator in a model and compile it with Torch-TensorRT.\nWe can see that the custom operator is used as one of the operations in the forward pass of the model.\nThe process of compiling the model at this point is identical to standard Torch-TensorRT usage.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class MyModel(torch.nn.Module):  # type: ignore[misc]\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n        z = torch.add(x, y)\n        res = torch.ops.torchtrt_ex.elementwise_mul.default(x, z, a=1)\n\n        return res\n\n\nmy_model = MyModel().to(\"cuda\").eval()\nm = torch.full((64, 64), 2, device=\"cuda\", dtype=torch.float)\nn = torch.full((64, 64), 3, device=\"cuda\", dtype=torch.float)\n\nwith torch_tensorrt.logging.errors():\n    model_trt = torch_tensorrt.compile(my_model, inputs=[m, n], min_block_size=1)\n    with torch.no_grad():\n        for i in range(300):\n            res = model_trt(m, n)\n            assert torch.allclose(res, my_model(m, n))\n\nprint(\"Ran with custom plugin!\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.15" } }, "nbformat": 4, "nbformat_minor": 0 }