• Docs >
  • Using Custom Kernels with NVRTC in TensorRT AOT Plugins
Shortcuts

Using Custom Kernels with NVRTC in TensorRT AOT Plugins

This example demonstrates how to use the NVIDIA Runtime Compilation (NVRTC) library to compile custom CUDA kernels at runtime and integrate them into a TensorRT Ahead-Of-Time (AOT) plugin.

This approach is powerful because it allows you to:

1. Write raw CUDA C++ code for maximum performance.
2. Compile it on the fly, adapting to the specific GPU architecture.
3. Wrap it in a TensorRT plugin without writing a separate C++ plugin library.
4. Integrate it seamlessly into Torch-TensorRT’s compilation flow.

The example performs a simple pointwise Sigmoid operation: f(x) = 1 / (1 + exp(-x)).

from typing import List, Tuple, Union

import torch

import torch_tensorrt

# ============================================================================
# 1. Define the CUDA Kernel Source
# ============================================================================
# The CUDA kernel source is kept in a Python string and compiled at runtime by
# NVRTC (once, at module import, for eager mode); the resulting PTX is then
# reused by the TensorRT AOT plugin further down.
# ``extern "C"`` disables C++ name mangling, so the kernel can be retrieved by
# its plain name, "pointwise_sigmoid_kernel_nvrtc".
#
# Kernel contract, as written below:
#   * One thread per element; threads whose global index is >= ``size`` do
#     nothing, so the grid may be rounded up to a whole number of blocks.
#   * Argument order is (input pointer, size, output pointer). The eager
#     launch passes (X ptr, N, Y ptr) in that order, and the AOT plugin
#     supplies N as its single extra argument.
#     NOTE(review): the AOT path relies on TensorRT placing extra arguments
#     between the input and output pointers; confirm against the TensorRT
#     plugin docs if the signature changes.
#   * ``size`` and the thread index are ``int``, so tensors with more than
#     2**31 - 1 elements are not supported by this kernel.

cu_code = """
// Simple pointwise Sigmoid kernel: f(x) = 1 / (1 + exp(-x))
extern "C" __global__ void pointwise_sigmoid_kernel_nvrtc(const float* __restrict__ input,
                                                          const int size,
                                                          float* __restrict__ output) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx < size) {
        const float x = input[idx];
        // use fast device intrinsic to avoid headers
        output[idx] = 1.0f / (1.0f + __expf(-x));
    }
}
"""

# ============================================================================
# 2. Compile the Kernel using NVRTC (for eager mode)
# ============================================================================
# Before defining the Torch custom op, we compile the kernel so we can run it
# in standard PyTorch (eager mode) for verification and testing.
# We use the cuda-python library's NVRTC bindings.

from cuda.core.experimental import Device as _CudaDevice
from cuda.core.experimental import LaunchConfig as _LaunchConfig
from cuda.core.experimental import Program as _CudaProgram
from cuda.core.experimental import ProgramOptions as _CudaProgramOptions
from cuda.core.experimental import launch as _cuda_launch

import os as _os

# Initialize the CUDA device used for both compilation and eager launches.
_cuda_device = _CudaDevice()
_cuda_device.set_current()
# A dedicated cuda.core stream. Note: the eager op launches on PyTorch's
# current stream instead, so this stream is not used elsewhere in this example.
_cuda_stream = _cuda_device.create_stream()

# Locate the CUDA toolkit headers. Honor CUDA_HOME / CUDA_PATH so non-default
# installs (conda, custom prefixes, Windows) work, and fall back to the
# conventional Linux location. The current kernel includes no headers, but
# keeping a correct include path lets it grow to use them.
_cuda_home = (
    _os.environ.get("CUDA_HOME")
    or _os.environ.get("CUDA_PATH")
    or "/usr/local/cuda"
)

# Configure compilation options
_program_options = _CudaProgramOptions(
    std="c++17",
    arch=f"sm_{_cuda_device.arch}",  # Target the current GPU architecture
    include_path=[_os.path.join(_cuda_home, "include")],
)

# Create and compile the program to PTX, then look the kernel up by its
# unmangled (extern "C") name.
_program = _CudaProgram(cu_code, code_type="c++", options=_program_options)
_module = _program.compile("ptx", name_expressions=("pointwise_sigmoid_kernel_nvrtc",))
_kernel = _module.get_kernel("pointwise_sigmoid_kernel_nvrtc")


# ============================================================================
# 3. Register Custom Op in PyTorch
# ============================================================================
# We register the custom operation with PyTorch so it can be used in models.
# The 'mutates_args=()' argument tells PyTorch this op is functional (doesn't modify inputs in-place).


@torch.library.custom_op("pointwise_sigmoid_ops::pointwise_sigmoid", mutates_args=())  # type: ignore[misc]
def pointwise_sigmoid(X: torch.Tensor) -> torch.Tensor:
    """Eager implementation of ``pointwise_sigmoid_ops::pointwise_sigmoid``.

    Launches the NVRTC-compiled ``pointwise_sigmoid_kernel_nvrtc`` on PyTorch's
    current CUDA stream, computing ``1 / (1 + exp(-X))`` element-wise.

    Args:
        X: float32 CUDA tensor of any shape and any strides.

    Returns:
        A new tensor allocated with ``torch.empty_like(X)`` (same shape, dtype,
        device and, for dense inputs, strides) holding the sigmoid of ``X``.

    Raises:
        AssertionError: if ``X`` is not a float32 CUDA tensor.
    """
    assert X.is_cuda, "Tensor must be on CUDA device."
    assert X.dtype == torch.float32, "For this test, expected float32 input"

    # Allocate the output exactly as the fake implementation does
    # (``torch.empty_like``) so eager and traced layouts agree.
    Y = torch.empty_like(X)
    N = int(X.numel())
    if N == 0:
        # Nothing to compute; skip the kernel launch entirely.
        return Y

    # The kernel walks N consecutive floats starting at each data pointer.
    # That matches the tensor's elements only when X is dense and
    # non-overlapping, which is exactly the case in which ``empty_like``
    # copies X's strides into Y. Otherwise (e.g. a strided slice or an
    # expanded view), run the kernel on contiguous buffers and copy back.
    if X.stride() == Y.stride():
        src, dst = X, Y
    else:
        src = X.contiguous()
        dst = torch.empty_like(src)

    block = 256
    grid_x = (N + block - 1) // block  # ceil(N / block); >= 1 because N > 0
    config = _LaunchConfig(grid=grid_x, block=block)

    # Wrap PyTorch's current stream (instead of using a separate cuda.core
    # stream) so the launch is ordered with surrounding PyTorch work,
    # including the ``contiguous()`` / ``copy_()`` calls in this function.
    stream = _cuda_device.create_stream(
        _PyTorchStreamWrapper(torch.cuda.current_stream())
    )

    # Kernel signature: (const float* input, int size, float* output).
    _cuda_launch(stream, config, _kernel, src.data_ptr(), N, dst.data_ptr())

    if dst is not Y:
        Y.copy_(dst)
    return Y


class _PyTorchStreamWrapper:
    """Expose a PyTorch CUDA stream through cuda.core's ``__cuda_stream__`` protocol.

    Defined once at module level instead of inside the op body so the class
    is not re-created on every call.
    """

    def __init__(self, pt_stream: torch.cuda.Stream) -> None:
        self.pt_stream = pt_stream

    def __cuda_stream__(self) -> Tuple[int, int]:
        # (protocol version, stream handle) pair consumed by cuda.core; the
        # handle is the raw stream address PyTorch reports via ``cuda_stream``.
        return (0, self.pt_stream.cuda_stream)


# ============================================================================
# 4. Register Fake Implementation (Meta Kernel)
# ============================================================================
# The fake implementation is crucial for TorchDynamo. It tells the compiler
# about the output shape and data type without actually running the kernel.
# This is used during the tracing phase.


@torch.library.register_fake("pointwise_sigmoid_ops::pointwise_sigmoid")
def _(input: torch.Tensor) -> torch.Tensor:
    """Fake implementation for TorchDynamo tracing of base operation.

    Performs no computation: it only returns an uninitialized tensor that
    mirrors ``input`` (same shape, dtype, device and, via ``empty_like``,
    memory format), which is all the tracer needs to know about the output.
    """
    result = torch.empty_like(input)
    return result


# ============================================================================
# 5. Define TensorRT AOT Plugin
# ============================================================================
# Now we define how this operation should be handled within TensorRT.
# We use the torch_tensorrt plugin auto generation feature and the AOT implementation using NVRTC.

import tensorrt.plugin as trtp
from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions

# Ask Torch-TensorRT to auto-generate a TensorRT plugin definition for the
# custom op registered above, identified by its "namespace::name" string.
# NOTE(review): presumably the plugin's output shape/dtype description is
# derived from the op's registered fake implementation; confirm against the
# Torch-TensorRT plugin documentation.
torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
    "pointwise_sigmoid_ops::pointwise_sigmoid"
)


# This is where the magic happens. We provide the compiled PTX code and
# launch parameters to TensorRT. This code runs during engine building.
@trtp.aot_impl("pointwise_sigmoid_ops::pointwise_sigmoid")
def sigmoid_aot_nvrtc_impl(
    X: trtp.TensorDesc,
    outputs: Tuple[trtp.TensorDesc],
    tactic: int,
) -> Tuple[
    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
]:
    """Tell TensorRT how to launch the NVRTC-compiled sigmoid kernel.

    Args:
        X: Descriptor of the op's single input; only its symbolic element
            count is used.
        outputs: Output descriptors (unused here).
        tactic: Tactic chosen by TensorRT (unused here).

    Returns:
        ``(kernel_name, ptx, launch_params, extra_args)``: the extern "C"
        kernel name, the PTX text from the module-level NVRTC compile, a 1-D
        launch of 256-thread blocks covering every element, and a single
        extra int32 argument carrying the element count.
    """
    threads_per_block = 256
    num_elements = X.shape_expr.numel()

    # 1-D launch: ceil(num_elements / threads_per_block) blocks, no shared memory.
    params = trtp.KernelLaunchParams()
    params.grid_x = trtp.cdiv(num_elements, threads_per_block)
    params.block_x = threads_per_block
    params.shared_mem = 0

    # The element count becomes the kernel's ``size`` argument.
    scalar_args = trtp.SymIntExprs(1)
    scalar_args[0] = trtp.SymInt32(num_elements)

    # PTX produced by the module-level NVRTC compile, as text.
    ptx_text = _module.code.decode("utf-8")

    return "pointwise_sigmoid_kernel_nvrtc", ptx_text, params, scalar_args


# ============================================================================
# 6. Generate Plugin Converter
# ============================================================================
# This registers the mapping between the PyTorch custom op and the TensorRT plugin.
# It tells Torch-TensorRT: "When you see 'pointwise_sigmoid_ops::pointwise_sigmoid',
# replace it with the TensorRT plugin we just defined."

# Register a converter that lowers the custom op to the generated plugin.
torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "pointwise_sigmoid_ops::pointwise_sigmoid",
    # The output shape equals the input shape (not data-dependent), so
    # TensorRT's output allocator is not needed.
    requires_output_allocator=False,
    supports_dynamic_shapes=True,
)


# ============================================================================
# 7. Test the Model
# ============================================================================


class PointwiseSigmoidModel_WithTRTWrapper(torch.nn.Module):
    """Minimal test module whose forward pass is a single custom-op call.

    The op ``pointwise_sigmoid_ops::pointwise_sigmoid`` is registered with
    ``torch.library.custom_op``; under Torch-TensorRT it is intended to be
    replaced by the generated TensorRT plugin.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        sigmoid_op = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid
        return sigmoid_op(input)


if __name__ == "__main__":
    model = PointwiseSigmoidModel_WithTRTWrapper().to("cuda").eval()
    # Named ``x`` rather than ``input`` to avoid shadowing the builtin.
    x = torch.randn(1, 1024, device="cuda", dtype=torch.float32)

    print("PyTorch baseline result:")
    expected = torch.sigmoid(x)  # ground-truth reference, computed once
    print(expected)

    print("Custom Op eager result:")
    eager_out = model(x)
    print(eager_out)
    # Validate the NVRTC kernel itself before TensorRT is involved, so a
    # kernel bug cannot hide behind a TRT-vs-eager comparison.
    assert torch.allclose(
        eager_out, expected, rtol=1e-3, atol=1e-3
    ), "Eager custom op does not match torch.sigmoid!"

    print("\nCompiling with Torch-TensorRT...")
    with torch_tensorrt.logging.debug():
        trt_inputs = [x]
        model_trt = torch_tensorrt.compile(
            model,
            inputs=trt_inputs,
            enabled_precisions={torch.float32},
            min_block_size=1,
        )
        print("Model compiled successfully!")

        print("Running inference with compiled model...")
        with torch.no_grad():
            for _ in range(10):
                res = model_trt(x)
                assert torch.allclose(
                    res, expected, rtol=1e-2, atol=1e-2
                ), "Results do not match!"

    print("Inference successful!")

Total running time of the script: ( 0 minutes 0.000 seconds)

Gallery generated by Sphinx-Gallery

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources