Using Custom Kernels with NVRTC in TensorRT AOT Plugins
This example demonstrates how to use the NVIDIA Runtime Compilation (NVRTC) library to compile custom CUDA kernels at runtime and integrate them into a TensorRT Ahead-Of-Time (AOT) plugin.
This approach is powerful because it allows you to:
1. Write raw CUDA C++ code for maximum performance.
2. Compile it on-the-fly, adapting to the specific GPU architecture.
3. Wrap it in a TensorRT plugin without writing a separate C++ plugin library.
4. Integrate it seamlessly into Torch-TensorRT’s compilation flow.
The example performs a simple pointwise Sigmoid operation: f(x) = 1 / (1 + exp(-x)).
[ ]:
from typing import List, Tuple, Union
import torch
import torch_tensorrt
# ============================================================================
# 1. Define the CUDA Kernel Source
# ============================================================================
# We define the CUDA kernel source code as a Python string.
# This code will be compiled by NVRTC.
# Note that we use extern "C" to avoid name mangling, making it easier to
# retrieve the kernel function by name later.
# Kernel contract relied on by the launch sites in this file: the parameter
# order is (input pointer, element count, output pointer). The eager launch
# passes X.data_ptr(), N, Y.data_ptr() in exactly that order.
# Each thread handles at most one element and is guarded by `idx < size`, so
# the grid may be rounded up past the element count without out-of-range writes.
# The kernel name in the string must stay in sync with the name passed to
# `compile(..., name_expressions=...)`, `get_kernel(...)` and the AOT impl.
cu_code = """
// Simple pointwise Sigmoid kernel: f(x) = 1 / (1 + exp(-x))
extern "C" __global__ void pointwise_sigmoid_kernel_nvrtc(const float* __restrict__ input,
const int size,
float* __restrict__ output) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < size) {
const float x = input[idx];
// use fast device intrinsic to avoid headers
output[idx] = 1.0f / (1.0f + __expf(-x));
}
}
"""
# ============================================================================
# 2. Compile the Kernel using NVRTC (for eager mode)
# ============================================================================
# Before defining the Torch custom op, we compile the kernel so we can run it
# in standard PyTorch (eager mode) for verification and testing.
# We use the cuda-python library's NVRTC bindings.
from cuda.core.experimental import Device as _CudaDevice
from cuda.core.experimental import LaunchConfig as _LaunchConfig
from cuda.core.experimental import Program as _CudaProgram
from cuda.core.experimental import ProgramOptions as _CudaProgramOptions
from cuda.core.experimental import launch as _cuda_launch
import os

# Initialize the CUDA device and make it current for this thread.
# `_cuda_stream` is a standalone stream on that device; note that the eager op
# defined later launches on PyTorch's current stream rather than on this one.
_cuda_device = _CudaDevice()
_cuda_device.set_current()
_cuda_stream = _cuda_device.create_stream()

# Locate the CUDA toolkit headers for NVRTC.
# A hard-coded "/usr/local/cuda/include" only matches a default Linux install,
# so directories derived from CUDA_HOME / CUDA_PATH (when set) are searched
# first. The original path is always kept as the final fallback so existing
# setups behave exactly as before. `dict.fromkeys` de-duplicates while
# preserving order.
_env_include_dirs = [
    os.path.join(os.environ[_var], "include")
    for _var in ("CUDA_HOME", "CUDA_PATH")
    if os.environ.get(_var)
]
_include_dirs = list(
    dict.fromkeys([*_env_include_dirs, "/usr/local/cuda/include"])
)

# Configure compilation options
_program_options = _CudaProgramOptions(
    std="c++17",
    arch=f"sm_{_cuda_device.arch}",  # Target the current GPU architecture
    include_path=_include_dirs,
)

# Create and compile the program to PTX, then fetch the kernel handle that the
# eager-mode op launches. The same `_module` later supplies the PTX text to the
# TensorRT AOT implementation.
_program = _CudaProgram(cu_code, code_type="c++", options=_program_options)
_module = _program.compile("ptx", name_expressions=("pointwise_sigmoid_kernel_nvrtc",))
_kernel = _module.get_kernel("pointwise_sigmoid_kernel_nvrtc")
# ============================================================================
# 3. Register Custom Op in PyTorch
# ============================================================================
class _PyTorchStreamWrapper:
    """Adapt a PyTorch CUDA stream to cuda.core's ``__cuda_stream__`` protocol.

    Defined once at module level so it is not re-created on every op call.
    """

    def __init__(self, pt_stream):
        self.pt_stream = pt_stream

    def __cuda_stream__(self):
        # Tuple of (protocol version, raw stream handle).
        return (0, self.pt_stream.cuda_stream)


# We register the custom operation with PyTorch so it can be used in models.
# The 'mutates_args=()' argument tells PyTorch this op is functional (doesn't modify inputs in-place).
@torch.library.custom_op("pointwise_sigmoid_ops::pointwise_sigmoid", mutates_args=())  # type: ignore[misc]
def pointwise_sigmoid(X: torch.Tensor) -> torch.Tensor:
    """
    Eager-mode implementation of the custom op: launch the pre-compiled NVRTC kernel.

    Args:
        X: float32 CUDA tensor of any shape. Non-contiguous inputs are supported.

    Returns:
        A new tensor laid out like ``torch.empty_like(X)`` holding sigmoid(X).

    Raises:
        AssertionError: if ``X`` is not a CUDA tensor or is not float32.
    """
    assert X.is_cuda, "Tensor must be on CUDA device."
    assert X.dtype == torch.float32, "For this test, expected float32 input"

    Y = torch.empty_like(X)

    # The kernel treats both pointers as flat arrays of numel() floats. That is
    # only valid for contiguous memory, so strided/sliced inputs are staged
    # through contiguous scratch buffers and copied into Y afterwards. Y keeps
    # the layout of empty_like(X), matching the registered fake implementation.
    if X.is_contiguous():
        src, dst = X, Y
    else:
        src = X.contiguous()
        dst = torch.empty_like(src)

    N = int(src.numel())
    block = 256
    # At least one block is launched even for empty tensors; the kernel's
    # bounds check makes that launch a no-op.
    grid_x = max(1, (N + block - 1) // block)
    config = _LaunchConfig(grid=grid_x, block=block)

    # Launch on PyTorch's current stream so the kernel is ordered with the
    # surrounding PyTorch work (including the copy below).
    pt_stream = torch.cuda.current_stream()
    s = _cuda_device.create_stream(_PyTorchStreamWrapper(pt_stream))

    # Launch kernel with raw pointers; argument order matches the kernel
    # signature (input, size, output).
    _cuda_launch(
        s,
        config,
        _kernel,
        src.data_ptr(),
        N,
        dst.data_ptr(),
    )

    if dst is not Y:
        Y.copy_(dst)
    return Y
# ============================================================================
# 4. Register Fake Implementation (Meta Kernel)
# ============================================================================
# The fake implementation is crucial for TorchDynamo. It tells the compiler
# about the output shape and data type without actually running the kernel.
# This is used during the tracing phase.
@torch.library.register_fake("pointwise_sigmoid_ops::pointwise_sigmoid")
def _(input: torch.Tensor) -> torch.Tensor:
    """Meta kernel for tracing: describe the output without running the CUDA kernel.

    The result is an uninitialized tensor created with ``torch.empty_like``, so
    it carries the same shape, dtype and device as ``input``.
    """
    placeholder = torch.empty_like(input)
    return placeholder
# ============================================================================
# 5. Define TensorRT AOT Plugin
# ============================================================================
# Now we define how this operation should be handled within TensorRT.
# We use the torch_tensorrt plugin auto generation feature and the AOT implementation using NVRTC.
import tensorrt.plugin as trtp
from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions
# Auto-generate the TensorRT plugin registration for the custom op, which is
# identified by its fully qualified "namespace::name" string.
# NOTE(review): presumably this has to run before the @trtp.aot_impl
# registration that follows — confirm against the Torch-TensorRT docs.
torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
"pointwise_sigmoid_ops::pointwise_sigmoid"
)
# This is where the magic happens. We provide the compiled PTX code and
# launch parameters to TensorRT. This code runs during engine building.
@trtp.aot_impl("pointwise_sigmoid_ops::pointwise_sigmoid")
def sigmoid_aot_nvrtc_impl(
    X: trtp.TensorDesc,
    outputs: Tuple[trtp.TensorDesc],
    tactic: int,
) -> Tuple[
    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
]:
    """Describe to TensorRT how to launch the pre-compiled sigmoid kernel.

    Args:
        X: descriptor of the plugin's input tensor; only its shape expression
            is used here.
        outputs: descriptors of the plugin outputs (unused).
        tactic: tactic index (unused).

    Returns:
        A 4-tuple of (kernel name, PTX source text, launch parameters,
        extra symbolic kernel arguments). The single extra argument is the
        element count of ``X``.
    """
    threads_per_block = 256

    # Element count expressed through the input's shape expression.
    num_elements = X.shape_expr.numel()

    # 1-D launch: ceil(num_elements / threads_per_block) blocks, no shared memory.
    params = trtp.KernelLaunchParams()
    params.shared_mem = 0
    params.block_x = threads_per_block
    params.grid_x = trtp.cdiv(num_elements, threads_per_block)

    # The kernel takes the element count as an additional 32-bit integer.
    scalar_args = trtp.SymIntExprs(1)
    scalar_args[0] = trtp.SymInt32(num_elements)

    # PTX text from the module compiled at import time.
    ptx_source = _module.code.decode("utf-8")

    return ("pointwise_sigmoid_kernel_nvrtc", ptx_source, params, scalar_args)
# ============================================================================
# 6. Generate Plugin Converter
# ============================================================================
# This registers the mapping between the PyTorch custom op and the TensorRT plugin.
# It tells Torch-TensorRT: "When you see 'pointwise_sigmoid_ops::pointwise_sigmoid',
# replace it with the TensorRT plugin we just defined."
# Register the converter for the same fully qualified op name used by the
# custom op, the fake implementation and the AOT implementation above.
# NOTE(review): the two flags are read from their names only. Presumably the
# converter is declared dynamic-shape capable and as not needing a TensorRT
# output allocator (the fake implementation returns a tensor shaped like its
# input) — confirm against the Torch-TensorRT docs.
torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
"pointwise_sigmoid_ops::pointwise_sigmoid",
supports_dynamic_shapes=True,
requires_output_allocator=False,
)
# ============================================================================
# 7. Test the Model
# ============================================================================
class PointwiseSigmoidModel_WithTRTWrapper(torch.nn.Module):
    """
    Minimal test module whose forward pass is a single call to the custom op
    registered under ``pointwise_sigmoid_ops::pointwise_sigmoid``.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Resolve the op through torch.ops at call time, then apply it.
        sigmoid_op = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid
        return sigmoid_op(input)
if __name__ == "__main__":
    # End-to-end check: eager custom op vs. torch.sigmoid, then the
    # Torch-TensorRT compiled module vs. the eager custom op.
    model = PointwiseSigmoidModel_WithTRTWrapper().to("cuda").eval()
    sample = torch.randn(1, 1024, device="cuda", dtype=torch.float32)

    reference = torch.sigmoid(sample)
    print("PyTorch baseline result:")
    print(reference)

    eager_out = model(sample)
    print("Custom Op eager result:")
    print(eager_out)
    # Validate the kernel itself against PyTorch. Without this, the TRT-vs-eager
    # comparison below would pass even if the kernel were wrong, because both
    # paths run the same PTX. The tolerance matches the check below.
    assert torch.allclose(
        eager_out, reference, rtol=1e-2, atol=1e-2
    ), "Custom op eager result does not match torch.sigmoid!"

    print("\nCompiling with Torch-TensorRT...")
    with torch_tensorrt.logging.debug():
        trt_inputs = [sample]
        model_trt = torch_tensorrt.compile(
            model,
            inputs=trt_inputs,
            enabled_precisions={torch.float32},
            min_block_size=1,
        )
        print("Model compiled successfully!")
        print("Running inference with compiled model...")
        with torch.no_grad():
            # The eager reference does not change between iterations, so it is
            # computed once instead of on every pass.
            expected = model(sample)
            for _ in range(10):
                res = model_trt(sample)
                assert torch.allclose(
                    res, expected, rtol=1e-2, atol=1e-2
                ), "Results do not match!"
    print("Inference successful!")