{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n\n# Using Custom Kernels with NVRTC in TensorRT AOT Plugins\n\nThis example demonstrates how to use the NVIDIA Runtime Compilation (NVRTC) library\nto compile custom CUDA kernels at runtime and integrate them into a TensorRT\nAhead-Of-Time (AOT) plugin.\n\nThis approach is powerful because it allows you to:\n1. Write raw CUDA C++ code for maximum performance.\n2. Compile it on-the-fly, adapting to the specific GPU architecture.\n3. Wrap it in a TensorRT plugin without writing a separate C++ plugin library.\n4. Integrate it seamlessly into Torch-TensorRT's compilation flow.\n\nThe example performs a simple pointwise Sigmoid operation: f(x) = 1 / (1 + exp(-x)).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from typing import List, Tuple, Union\n\nimport torch\n\nimport torch_tensorrt\n\n# ============================================================================\n# 1. Define the CUDA Kernel Source\n# ============================================================================\n# The CUDA kernel source is kept in a Python string and compiled at runtime by\n# NVRTC (section 2). It is declared extern \"C\" so the symbol is not C++\n# name-mangled: both the eager launcher (section 2) and the TensorRT AOT plugin\n# (section 5) refer to the kernel by the plain name \"pointwise_sigmoid_kernel_nvrtc\".\n#\n# Kernel contract:\n#   * parameters, in order: input pointer, element count, output pointer\n#   * one thread per element; threads whose idx >= size do nothing\n#   * idx and size are 32-bit ints, so one launch covers at most 2**31 - 1\n#     elements\n\ncu_code = \"\"\"\n// Simple pointwise Sigmoid kernel: f(x) = 1 / (1 + exp(-x))\nextern \"C\" __global__ void pointwise_sigmoid_kernel_nvrtc(const float* __restrict__ input,\n const int size,\n float* __restrict__ output) {\n const int idx = blockIdx.x * blockDim.x + threadIdx.x;\n\n if (idx < size) {\n const float x = input[idx];\n // use fast device intrinsic to avoid headers\n output[idx] = 1.0f / (1.0f + __expf(-x));\n }\n}\n\"\"\"\n\n# ============================================================================\n# 2.
Compile the Kernel using NVRTC (for eager mode)\n# ============================================================================\n# Before defining the Torch custom op, we compile the kernel so we can run it\n# in standard PyTorch (eager mode) for verification and testing.\n# Compilation uses cuda.core (part of the cuda-python project), whose Program\n# class wraps NVRTC. The resulting PTX is also handed to the TensorRT AOT\n# plugin in section 5 (via _module.code), so the kernel is compiled only once.\n\nfrom cuda.core.experimental import Device as _CudaDevice\nfrom cuda.core.experimental import LaunchConfig as _LaunchConfig\nfrom cuda.core.experimental import Program as _CudaProgram\nfrom cuda.core.experimental import ProgramOptions as _CudaProgramOptions\nfrom cuda.core.experimental import launch as _cuda_launch\n\n# Initialize CUDA device and stream.\n# NOTE(review): Device() is created without an explicit id; this assumes it is\n# the same GPU the PyTorch tensors live on -- confirm on multi-GPU machines.\n_cuda_device = _CudaDevice()\n_cuda_device.set_current()\n# NOTE(review): _cuda_stream is not used by the launch path in this notebook;\n# the eager op launches on PyTorch's current stream instead (section 3).\n_cuda_stream = _cuda_device.create_stream()\n\n# Configure compilation options\n_program_options = _CudaProgramOptions(\n    std=\"c++17\",\n    arch=f\"sm_{_cuda_device.arch}\",  # Target the current GPU architecture\n    # cu_code has no #include directives, so this path is not needed to build\n    # this kernel; it assumes a default CUDA install location.\n    include_path=[\"/usr/local/cuda/include\"],\n)\n\n# Create the program and compile it to PTX. name_expressions lists the kernel\n# symbol that is looked up right after compilation.\n_program = _CudaProgram(cu_code, code_type=\"c++\", options=_program_options)\n_module = _program.compile(\"ptx\", name_expressions=(\"pointwise_sigmoid_kernel_nvrtc\",))\n_kernel = _module.get_kernel(\"pointwise_sigmoid_kernel_nvrtc\")\n\n\n# ============================================================================\n# 3.
Register Custom Op in PyTorch\n# ============================================================================\n# We register the custom operation with PyTorch so it can be used in models.\n# The 'mutates_args=()' argument tells PyTorch this op is functional (doesn't modify inputs in-place).\n\n\nclass _PyTorchStreamWrapper:\n    \"\"\"Expose a ``torch.cuda.Stream`` through the ``__cuda_stream__`` protocol.\n\n    An instance is passed to ``_cuda_device.create_stream`` so the NVRTC kernel\n    is enqueued on PyTorch's stream instead of a separate cuda.core stream.\n    Defined once at module level rather than on every op call.\n    \"\"\"\n\n    def __init__(self, pt_stream):\n        self.pt_stream = pt_stream\n\n    def __cuda_stream__(self):\n        # (protocol version, raw CUDA stream handle as an int)\n        stream_id = self.pt_stream.cuda_stream\n        return (0, stream_id)\n\n\n@torch.library.custom_op(\"pointwise_sigmoid_ops::pointwise_sigmoid\", mutates_args=())  # type: ignore[misc]\ndef pointwise_sigmoid(X: torch.Tensor) -> torch.Tensor:\n    \"\"\"Eager implementation: launch the pre-compiled NVRTC sigmoid kernel.\n\n    Args:\n        X: float32 CUDA tensor of any shape (contiguous or not).\n\n    Returns:\n        A new tensor allocated with ``torch.empty_like(X)`` holding\n        ``1 / (1 + exp(-X))`` elementwise.\n    \"\"\"\n    assert X.is_cuda, \"Tensor must be on CUDA device.\"\n    assert X.dtype == torch.float32, \"For this test, expected float32 input\"\n\n    Y = torch.empty_like(X)\n    N = int(X.numel())\n    if N == 0:\n        return Y  # nothing to compute; skip the launch\n    # The kernel's `size` argument (and its thread index) is a 32-bit int.\n    assert N <= 2**31 - 1, \"Input too large for the kernel's 32-bit size argument\"\n\n    # The kernel treats both buffers as flat arrays starting at data_ptr(), so\n    # the input must share Y's dense memory layout. empty_like keeps X's strides\n    # when X is dense, so a copy is needed only for non-dense inputs (e.g. a\n    # slice with a step, or an expanded tensor), which would otherwise be read\n    # from the wrong memory locations.\n    X_src = X if X.stride() == Y.stride() else torch.empty_like(Y).copy_(X)\n\n    block = 256\n    grid_x = (N + block - 1) // block  # ceil(N / block), >= 1 since N > 0\n    config = _LaunchConfig(grid=grid_x, block=block)\n\n    # Launch on PyTorch's current stream for X's device so the kernel is\n    # ordered with the surrounding PyTorch work (including the copy above).\n    pt_stream = torch.cuda.current_stream(X.device)\n    s = _cuda_device.create_stream(_PyTorchStreamWrapper(pt_stream))\n\n    # Kernel arguments, in signature order: input ptr, element count, output ptr.\n    _cuda_launch(s, config, _kernel, X_src.data_ptr(), N, Y.data_ptr())\n\n    return Y\n\n\n# ============================================================================\n# 4. Register Fake Implementation (Meta Kernel)\n# ============================================================================\n# The fake implementation is crucial for TorchDynamo.
It tells the compiler\n# about the output shape and data type without actually running the kernel.\n# This is used during the tracing phase.\n\n\n@torch.library.register_fake(\"pointwise_sigmoid_ops::pointwise_sigmoid\")\ndef _(input: torch.Tensor) -> torch.Tensor:\n    \"\"\"Fake (meta) implementation used while tracing; never runs the kernel.\n\n    Returns an uninitialized tensor created with ``torch.empty_like(input)``\n    (same shape, dtype and device as ``input``), mirroring how the eager\n    implementation allocates its output.\n    \"\"\"\n    return torch.empty_like(input)\n\n\n# ============================================================================\n# 5. Define TensorRT AOT Plugin\n# ============================================================================\n# Now we define how this operation should be handled within TensorRT.\n# We use the torch_tensorrt plugin auto generation feature and the AOT implementation using NVRTC.\n\nimport tensorrt.plugin as trtp\n# NOTE(review): Device, LaunchConfig, Program and ProgramOptions are not\n# referenced anywhere in this notebook (the eager path uses the aliased imports\n# from section 2); this import looks removable.\nfrom cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions\n\n# Generate the TensorRT plugin definition for the custom op.\n# NOTE(review): presumably derived from the op's schema and the fake\n# implementation registered above -- confirm in the torch_tensorrt docs.\ntorch_tensorrt.dynamo.conversion.plugins.generate_plugin(\n    \"pointwise_sigmoid_ops::pointwise_sigmoid\"\n)\n\n\n# This is where the magic happens. We provide the compiled PTX code and\n# launch parameters to TensorRT.
This code runs during engine building.\n@trtp.aot_impl(\"pointwise_sigmoid_ops::pointwise_sigmoid\")\ndef sigmoid_aot_nvrtc_impl(\n    X: trtp.TensorDesc,\n    outputs: Tuple[trtp.TensorDesc],\n    tactic: int,\n) -> Tuple[\n    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs\n]:\n    \"\"\"Describe the kernel launch that TensorRT embeds in the engine.\n\n    Args:\n        X: Descriptor of the plugin's single input tensor.\n        outputs: Output descriptors (unused: the launch depends only on X).\n        tactic: Tactic id chosen by TensorRT (unused: there is one kernel).\n\n    Returns:\n        ``(kernel_name, ptx_source, launch_params, extra_kernel_args)``.\n    \"\"\"\n    threads_per_block = 256  # same block size as the eager launcher\n    # Symbolic element count (a shape expression, not a concrete int).\n    num_elements = X.shape_expr.numel()\n\n    # One thread per element: ceil(num_elements / threads_per_block) blocks.\n    params = trtp.KernelLaunchParams()\n    params.block_x = threads_per_block\n    params.grid_x = trtp.cdiv(num_elements, threads_per_block)\n    params.shared_mem = 0  # the kernel declares no shared memory\n\n    # The element count becomes the kernel's 32-bit `size` argument.\n    scalar_args = trtp.SymIntExprs(1)\n    scalar_args[0] = trtp.SymInt32(num_elements)\n\n    # PTX produced by the NVRTC compilation in section 2.\n    ptx_source = _module.code.decode(\"utf-8\")\n\n    return (\"pointwise_sigmoid_kernel_nvrtc\", ptx_source, params, scalar_args)\n\n\n# ============================================================================\n# 6. Generate Plugin Converter\n# ============================================================================\n# This registers the mapping between the PyTorch custom op and the TensorRT plugin.\n# It tells Torch-TensorRT: \"When you see 'pointwise_sigmoid_ops::pointwise_sigmoid',\n# replace it with the TensorRT plugin we just defined.\"\n\ntorch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(\n    \"pointwise_sigmoid_ops::pointwise_sigmoid\",\n    supports_dynamic_shapes=True,\n    requires_output_allocator=False,\n)\n\n\n# ============================================================================\n# 7.
Test the Model\n# ============================================================================\n# Check the custom op against torch.sigmoid in eager mode, then compile the\n# model with Torch-TensorRT (which swaps the op for the AOT plugin) and check\n# that the compiled model matches the eager result.\n\n\nclass PointwiseSigmoidModel_WithTRTWrapper(torch.nn.Module):\n    \"\"\"\n    Test model that uses the TRT wrapper with custom_op() registration.\n    \"\"\"\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        \"\"\"Apply the custom NVRTC sigmoid op to ``input``.\"\"\"\n        return torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid(input)\n\n\nif __name__ == \"__main__\":\n    model = PointwiseSigmoidModel_WithTRTWrapper().to(\"cuda\").eval()\n    x = torch.randn(1, 1024, device=\"cuda\", dtype=torch.float32)\n\n    with torch.no_grad():\n        expected = torch.sigmoid(x)\n        eager_out = model(x)\n\n    print(\"PyTorch baseline result:\")\n    print(expected)\n\n    print(\"Custom Op eager result:\")\n    print(eager_out)\n    # The kernel uses the fast __expf intrinsic, hence a loose tolerance.\n    assert torch.allclose(\n        eager_out, expected, rtol=1e-2, atol=1e-2\n    ), \"Custom op eager result does not match torch.sigmoid!\"\n\n    print(\"\\nCompiling with Torch-TensorRT...\")\n    with torch_tensorrt.logging.debug():\n        model_trt = torch_tensorrt.compile(\n            model,\n            inputs=[x],\n            enabled_precisions={torch.float32},\n            min_block_size=1,  # the graph has a single op; allow 1-op TRT blocks\n        )\n    print(\"Model compiled successfully!\")\n\n    print(\"Running inference with compiled model...\")\n    with torch.no_grad():\n        for _ in range(10):\n            res = model_trt(x)\n            assert torch.allclose(\n                res, eager_out, rtol=1e-2, atol=1e-2\n            ), \"Results do not match!\"\n\n    print(\"Inference successful!\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.15" } }, "nbformat": 4, "nbformat_minor": 0 }