Using Custom Kernels with NVRTC in TensorRT AOT Plugins#

This example demonstrates how to use the NVIDIA Runtime Compilation (NVRTC) library to compile custom CUDA kernels at runtime and integrate them into a TensorRT Ahead-Of-Time (AOT) plugin.

This approach is powerful because it allows you to:

1. Write raw CUDA C++ code for maximum performance.
2. Compile it on-the-fly, adapting to the specific GPU architecture.
3. Wrap it in a TensorRT plugin without writing a separate C++ plugin library.
4. Integrate it seamlessly into Torch-TensorRT’s compilation flow.

The example performs a simple pointwise Sigmoid operation: f(x) = 1 / (1 + exp(-x)).

from typing import Tuple, Union

import torch

import torch_tensorrt

# ============================================================================
# 1. Define the CUDA Kernel Source
# ============================================================================
# We define the CUDA kernel source code as a Python string.
# This code will be compiled by NVRTC.
# Note that we use extern "C" to avoid name mangling, making it easier to
# retrieve the kernel function by name later.

cu_code = """
// Simple pointwise Sigmoid kernel: f(x) = 1 / (1 + exp(-x))
extern "C" __global__ void pointwise_sigmoid_kernel_nvrtc(const float* __restrict__ input,
                                                          const int size,
                                                          float* __restrict__ output) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx < size) {
        const float x = input[idx];
        // use fast device intrinsic to avoid headers
        output[idx] = 1.0f / (1.0f + __expf(-x));
    }
}
"""

# ============================================================================
# 2. Compile the Kernel using NVRTC (for eager mode)
# ============================================================================
# Before defining the Torch custom op, we compile the kernel so we can run it
# in standard PyTorch (eager mode) for verification and testing.
# We use the cuda.core.experimental API (from NVIDIA's cuda-python project),
# which wraps NVRTC.

from cuda.core.experimental import Device as _CudaDevice
from cuda.core.experimental import LaunchConfig as _LaunchConfig
from cuda.core.experimental import Program as _CudaProgram
from cuda.core.experimental import ProgramOptions as _CudaProgramOptions
from cuda.core.experimental import launch as _cuda_launch

# Initialize and activate the CUDA device
_cuda_device = _CudaDevice()
_cuda_device.set_current()

# Configure compilation options
_program_options = _CudaProgramOptions(
    std="c++17",
    arch=f"sm_{_cuda_device.arch}",  # Target the current GPU architecture
    include_path=["/usr/local/cuda/include"],
)
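
# For reference, _cuda_device.arch is the compute capability digits
# ("80" on an A100), so the option above resolves to e.g. arch="sm_80".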

# Create and compile the program
_program = _CudaProgram(cu_code, code_type="c++", options=_program_options)
_module = _program.compile("ptx", name_expressions=("pointwise_sigmoid_kernel_nvrtc",))
_kernel = _module.get_kernel("pointwise_sigmoid_kernel_nvrtc")
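
# Sanity check: the PTX (exposed as bytes on the compiled module) should
# contain the extern "C" entry point name verbatim.
assert b"pointwise_sigmoid_kernel_nvrtc" in _module.code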


# ============================================================================
# 3. Register Custom Op in PyTorch
# ============================================================================
# We register the custom operation with PyTorch so it can be used in models.
# The 'mutates_args=()' argument tells PyTorch this op is functional (doesn't modify inputs in-place).


@torch.library.custom_op("pointwise_sigmoid_ops::pointwise_sigmoid", mutates_args=())  # type: ignore[misc]
def pointwise_sigmoid(X: torch.Tensor) -> torch.Tensor:
    """
    Implementation of the custom op for PyTorch eager execution.
    This function launches the pre-compiled NVRTC kernel.
    """
    assert X.is_cuda, "Tensor must be on CUDA device."
    assert X.dtype == torch.float32, "This example expects float32 input"

    Y = torch.empty_like(X)
    N = int(X.numel())

    block = 256
    grid_x = max(1, (N + block - 1) // block)
    config = _LaunchConfig(grid=grid_x, block=block)

    # Helper class to wrap PyTorch's stream for cuda-python
    class _PyTorchStreamWrapper:
        def __init__(self, pt_stream):
            self.pt_stream = pt_stream

        def __cuda_stream__(self):
            # The __cuda_stream__ protocol returns (protocol version, raw handle)
            stream_id = self.pt_stream.cuda_stream
            return (0, stream_id)

    pt_stream = torch.cuda.current_stream()
    s = _cuda_device.create_stream(_PyTorchStreamWrapper(pt_stream))

    # Launch kernel with raw pointers
    _cuda_launch(
        s,
        config,
        _kernel,
        X.data_ptr(),
        N,
        Y.data_ptr(),
    )

    return Y
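
# Eager sanity check (a quick sketch): __expf is a fast approximation, so we
# compare against torch.sigmoid with a slightly relaxed tolerance.
_x_check = torch.randn(8, device="cuda", dtype=torch.float32)
assert torch.allclose(pointwise_sigmoid(_x_check), torch.sigmoid(_x_check), atol=1e-4)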


# ============================================================================
# 4. Register Fake Implementation (Meta Kernel)
# ============================================================================
# The fake implementation is crucial for TorchDynamo. It tells the compiler
# about the output shape and data type without actually running the kernel.
# This is used during the tracing phase.


@torch.library.register_fake("pointwise_sigmoid_ops::pointwise_sigmoid")
def _(input: torch.Tensor) -> torch.Tensor:
    """Fake implementation for TorchDynamo tracing of base operation."""
    return torch.empty_like(input)
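
# Optionally, recent PyTorch releases can validate that the real and fake
# kernels agree (a sketch; uncomment to run):
# torch.library.opcheck(
#     torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid,
#     (torch.randn(16, device="cuda", dtype=torch.float32),),
# )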


# ============================================================================
# 5. Define TensorRT AOT Plugin
# ============================================================================
# Now we define how this operation should be handled within TensorRT.
# We use Torch-TensorRT's automatic plugin generation together with a TensorRT
# AOT implementation that supplies the NVRTC-compiled PTX.

import tensorrt.plugin as trtp

torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
    "pointwise_sigmoid_ops::pointwise_sigmoid"
)


# This is where the magic happens. We provide the compiled PTX code and
# launch parameters to TensorRT. This code runs during engine building.
@trtp.aot_impl("pointwise_sigmoid_ops::pointwise_sigmoid")
def sigmoid_aot_nvrtc_impl(
    X: trtp.TensorDesc,
    outputs: Tuple[trtp.TensorDesc],
    tactic: int,
) -> Tuple[
    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
]:
    # Get the PTX code from our pre-compiled module
    compiled_kernel = _module.code.decode("utf-8")

    # Calculate grid and block dimensions based on input shape
    N = X.shape_expr.numel()
    launch_params = trtp.KernelLaunchParams()
    block = 256
    launch_params.grid_x = trtp.cdiv(N, block)
    launch_params.block_x = block
    launch_params.shared_mem = 0
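    # Example: for N = 1024 and block = 256, cdiv gives grid_x = 4, so the
    # 4 * 256 = 1024 launched threads cover every element exactly once.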

    # Pass the number of elements (N) as an extra argument; at runtime it
    # fills the kernel's `size` parameter.
    extra_args = trtp.SymIntExprs(1)
    extra_args[0] = trtp.SymInt32(N)

    # Return: kernel name, PTX code, launch params, kernel arguments
    return (
        "pointwise_sigmoid_kernel_nvrtc",
        compiled_kernel,
        launch_params,
        extra_args,
    )


# ============================================================================
# 6. Generate Plugin Converter
# ============================================================================
# This registers the mapping between the PyTorch custom op and the TensorRT plugin.
# It tells Torch-TensorRT: "When you see 'pointwise_sigmoid_ops::pointwise_sigmoid',
# replace it with the TensorRT plugin we just defined."

torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "pointwise_sigmoid_ops::pointwise_sigmoid",
    supports_dynamic_shapes=True,
    requires_output_allocator=False,
)
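
# Because the converter is registered with supports_dynamic_shapes=True, the
# model could also be compiled against a dynamic input spec (a sketch; the
# shape bounds are illustrative):
# dyn_input = torch_tensorrt.Input(
#     min_shape=(1, 256),
#     opt_shape=(1, 1024),
#     max_shape=(1, 4096),
#     dtype=torch.float32,
# )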


# ============================================================================
# 7. Test the Model
# ============================================================================


class PointwiseSigmoidModel_WithTRTWrapper(torch.nn.Module):
    """
    Test model that uses the TRT wrapper with custom_op() registration.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        z = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid(input)
        return z


if __name__ == "__main__":
    model = PointwiseSigmoidModel_WithTRTWrapper().to("cuda").eval()
    input = torch.randn(1, 1024, device="cuda", dtype=torch.float32)

    print("PyTorch baseline result:")
    print(torch.sigmoid(input))

    print("Custom Op eager result:")
    print(model(input))

    print("\nCompiling with Torch-TensorRT...")
    with torch_tensorrt.logging.debug():
        trt_inputs = [input]
        model_trt = torch_tensorrt.compile(
            model,
            inputs=trt_inputs,
            enabled_precisions={torch.float32},
            min_block_size=1,
        )
        print("Model compiled successfully!")

        print("Running inference with compiled model...")
        with torch.no_grad():
            for i in range(10):
                res = model_trt(input)
                assert torch.allclose(
                    res, model(input), rtol=1e-2, atol=1e-2
                ), "Results do not match!"

    print("Inference successful!")