.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/_rendered_examples/dynamo/nvrtc_aot_plugin.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials__rendered_examples_dynamo_nvrtc_aot_plugin.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials__rendered_examples_dynamo_nvrtc_aot_plugin.py:


Using Custom Kernels with NVRTC in TensorRT AOT Plugins
=======================================================

This example demonstrates how to use the NVIDIA Runtime Compilation (NVRTC)
library to compile custom CUDA kernels at runtime and integrate them into a
TensorRT Ahead-Of-Time (AOT) plugin. This approach is powerful because it
allows you to:

1. Write raw CUDA C++ code for maximum performance.
2. Compile it on the fly, targeting the specific GPU architecture.
3. Wrap it in a TensorRT plugin without writing a separate C++ plugin library.
4. Integrate it seamlessly into Torch-TensorRT's compilation flow.

The example implements a simple pointwise Sigmoid operation: f(x) = 1 / (1 + exp(-x)).

.. GENERATED FROM PYTHON SOURCE LINES 17-247

.. code-block:: python


    from typing import Tuple, Union

    import torch
    import torch_tensorrt

    # ============================================================================
    # 1. Define the CUDA Kernel Source
    # ============================================================================
    # We define the CUDA kernel source code as a Python string; it will be
    # compiled by NVRTC. Note that we use extern "C" to avoid name mangling,
    # making it easier to retrieve the kernel function by name later.

    cu_code = """
    // Simple pointwise Sigmoid kernel: f(x) = 1 / (1 + exp(-x))
    extern "C" __global__ void pointwise_sigmoid_kernel_nvrtc(const float* __restrict__ input,
                                                              const int size,
                                                              float* __restrict__ output) {
        const int idx = blockIdx.x * blockDim.x + threadIdx.x;
        if (idx < size) {
            const float x = input[idx];
            // Use the fast device intrinsic to avoid pulling in math headers.
            output[idx] = 1.0f / (1.0f + __expf(-x));
        }
    }
    """

    # ============================================================================
    # 2. Compile the Kernel using NVRTC (for eager mode)
    # ============================================================================
    # Before defining the Torch custom op, we compile the kernel so we can run it
    # in standard PyTorch (eager mode) for verification and testing.
    # We use the cuda-python library's NVRTC bindings (cuda.core).

    from cuda.core.experimental import Device as _CudaDevice
    from cuda.core.experimental import LaunchConfig as _LaunchConfig
    from cuda.core.experimental import Program as _CudaProgram
    from cuda.core.experimental import ProgramOptions as _CudaProgramOptions
    from cuda.core.experimental import launch as _cuda_launch

    # Initialize the CUDA device and a stream.
    _cuda_device = _CudaDevice()
    _cuda_device.set_current()
    _cuda_stream = _cuda_device.create_stream()

    # Configure compilation options.
    _program_options = _CudaProgramOptions(
        std="c++17",
        arch=f"sm_{_cuda_device.arch}",  # Target the current GPU architecture
        include_path=["/usr/local/cuda/include"],
    )

    # Create the program and compile it to PTX.
    _program = _CudaProgram(cu_code, code_type="c++", options=_program_options)
    _module = _program.compile("ptx", name_expressions=("pointwise_sigmoid_kernel_nvrtc",))
    _kernel = _module.get_kernel("pointwise_sigmoid_kernel_nvrtc")
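    # ----------------------------------------------------------------------------
    # Optional sanity check -- a minimal sketch, not part of the original example.
    # It launches the freshly compiled kernel directly through cuda.core before
    # we wrap it in a custom op, mirroring the launch pattern used below. The
    # helper name `_sanity_check_kernel` is our own.
    def _sanity_check_kernel() -> None:
        x = torch.randn(1024, device="cuda", dtype=torch.float32)
        y = torch.empty_like(x)
        n = int(x.numel())
        block = 256
        config = _LaunchConfig(grid=((n + block - 1) // block,), block=(block,))
        torch.cuda.synchronize()  # ensure x and y are materialized before launching
        _cuda_launch(_cuda_stream, config, _kernel, x.data_ptr(), n, y.data_ptr())
        _cuda_stream.sync()
        assert torch.allclose(y, torch.sigmoid(x), atol=1e-4)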
    # ============================================================================
    # 3. Register the Custom Op in PyTorch
    # ============================================================================
    # We register the custom operation with PyTorch so it can be used in models.
    # The 'mutates_args=()' argument tells PyTorch that this op is functional
    # (it does not modify its inputs in place).

    @torch.library.custom_op("pointwise_sigmoid_ops::pointwise_sigmoid", mutates_args=())  # type: ignore[misc]
    def pointwise_sigmoid(X: torch.Tensor) -> torch.Tensor:
        """
        Implementation of the custom op for PyTorch eager execution.

        This function launches the pre-compiled NVRTC kernel.
        """
        assert X.is_cuda, "Tensor must be on a CUDA device."
        assert X.dtype == torch.float32, "For this example, a float32 input is expected."

        Y = torch.empty_like(X)

        N = int(X.numel())
        block = 256
        grid_x = max(1, (N + block - 1) // block)
        config = _LaunchConfig(grid=(grid_x,), block=(block,))

        # Helper class that exposes PyTorch's current stream to cuda-python
        # via the __cuda_stream__ protocol.
        class _PyTorchStreamWrapper:
            def __init__(self, pt_stream):
                self.pt_stream = pt_stream

            def __cuda_stream__(self):
                stream_id = self.pt_stream.cuda_stream
                return (0, stream_id)

        pt_stream = torch.cuda.current_stream()
        s = _cuda_device.create_stream(_PyTorchStreamWrapper(pt_stream))

        # Launch the kernel with raw pointers.
        _cuda_launch(
            s,
            config,
            _kernel,
            X.data_ptr(),
            N,
            Y.data_ptr(),
        )

        return Y

    # ============================================================================
    # 4. Register a Fake Implementation (Meta Kernel)
    # ============================================================================
    # The fake implementation is crucial for TorchDynamo. It tells the compiler
    # the output shape and data type without actually running the kernel, and it
    # is used during the tracing phase.

    @torch.library.register_fake("pointwise_sigmoid_ops::pointwise_sigmoid")
    def _(input: torch.Tensor) -> torch.Tensor:
        """Fake implementation for TorchDynamo tracing of the base operation."""
        return torch.empty_like(input)
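    # ----------------------------------------------------------------------------
    # Optional sanity check -- a minimal sketch, not part of the original example.
    # Under FakeTensorMode the custom op dispatches to the fake implementation
    # above, so only shape/dtype metadata is produced and no CUDA kernel runs.
    # The helper name `_check_fake_impl` is our own.
    def _check_fake_impl() -> None:
        from torch._subclasses.fake_tensor import FakeTensorMode

        with FakeTensorMode():
            fake_in = torch.empty(4, 16, device="cuda")
            fake_out = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid(fake_in)
            assert fake_out.shape == fake_in.shape
            assert fake_out.dtype == fake_in.dtype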
    # ============================================================================
    # 5. Define the TensorRT AOT Plugin
    # ============================================================================
    # Now we define how this operation should be handled within TensorRT. We use
    # Torch-TensorRT's automatic plugin generation feature together with an AOT
    # implementation backed by NVRTC.

    import tensorrt.plugin as trtp

    torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
        "pointwise_sigmoid_ops::pointwise_sigmoid"
    )

    # This is where the magic happens. We provide the compiled PTX code and
    # launch parameters to TensorRT. This code runs during engine building.
    @trtp.aot_impl("pointwise_sigmoid_ops::pointwise_sigmoid")
    def sigmoid_aot_nvrtc_impl(
        X: trtp.TensorDesc,
        outputs: Tuple[trtp.TensorDesc],
        tactic: int,
    ) -> Tuple[
        Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
    ]:
        # Get the PTX code from our pre-compiled module.
        compiled_kernel = _module.code.decode("utf-8")

        # Calculate grid and block dimensions from the (symbolic) input shape.
        N = X.shape_expr.numel()
        launch_params = trtp.KernelLaunchParams()
        block = 256
        launch_params.grid_x = trtp.cdiv(N, block)
        launch_params.block_x = block
        launch_params.shared_mem = 0

        # Pass the number of elements (N) as an extra argument to the kernel.
        extra_args = trtp.SymIntExprs(1)
        extra_args[0] = trtp.SymInt32(N)

        # Return: kernel name, PTX code, launch params, extra kernel arguments.
        return (
            "pointwise_sigmoid_kernel_nvrtc",
            compiled_kernel,
            launch_params,
            extra_args,
        )

    # ============================================================================
    # 6. Generate the Plugin Converter
    # ============================================================================
    # This registers the mapping between the PyTorch custom op and the TensorRT
    # plugin. It tells Torch-TensorRT: "When you see
    # 'pointwise_sigmoid_ops::pointwise_sigmoid', replace it with the TensorRT
    # plugin we just defined."

    torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
        "pointwise_sigmoid_ops::pointwise_sigmoid",
        supports_dynamic_shapes=True,
        requires_output_allocator=False,
    )

    # ============================================================================
    # 7. Test the Model
    # ============================================================================

    class PointwiseSigmoidModel_WithTRTWrapper(torch.nn.Module):
        """
        Test model that uses the TRT wrapper with custom_op() registration.
        """

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            z = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid(input)
            return z


    if __name__ == "__main__":
        model = PointwiseSigmoidModel_WithTRTWrapper().to("cuda").eval()
        input = torch.randn(1, 1024, device="cuda", dtype=torch.float32)

        print("PyTorch baseline result:")
        print(torch.sigmoid(input))

        print("Custom Op eager result:")
        print(model(input))

        print("\nCompiling with Torch-TensorRT...")
        with torch_tensorrt.logging.debug():
            trt_inputs = [input]
            model_trt = torch_tensorrt.compile(
                model,
                inputs=trt_inputs,
                enabled_precisions={torch.float32},
                min_block_size=1,
            )
        print("Model compiled successfully!")

        print("Running inference with compiled model...")
        with torch.no_grad():
            for i in range(10):
                res = model_trt(input)
                assert torch.allclose(
                    res, model(input), rtol=1e-2, atol=1e-2
                ), "Results do not match!"
        print("Inference successful!")


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.000 seconds)


.. _sphx_glr_download_tutorials__rendered_examples_dynamo_nvrtc_aot_plugin.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: nvrtc_aot_plugin.py <nvrtc_aot_plugin.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: nvrtc_aot_plugin.ipynb <nvrtc_aot_plugin.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_