torch_tensorrt.kernels#

torch_tensorrt.kernels (experimental)#

Register custom CUDA C++ kernels — compiled at runtime with NVRTC via cuda-python — as TensorRT Quick Deployable Plugins (QDP). Tensor-only declarative kernels use AOT plugin launches when available; kernels with ScalarInput use TensorRT’s QDP JIT path so runtime scalar attributes can be forwarded by value.

The module exposes two registration entry points, one for CUDA source kernels and one for pre-compiled PTX:

  • cuda_kernel_op: fully declarative for the common cases, with optional overrides for everything else. Describe the kernel via KernelSpec (inputs, outputs, extras, launch geometry), and the meta / eager / aot functions plus the PyTorch schema are derived for you. For shape-changing kernels, multi-output kernels, or anything outside the declarative DSL, pass meta_fn= / eager_fn= / aot_fn= / schema= keyword arguments and the corresponding KernelSpec fields become optional.

  • ptx_op: registers a kernel from pre-compiled PTX bytes. Skips NVRTC entirely; you supply meta_fn / eager_fn / aot_fn directly. Useful when the PTX comes from an external compiler (Triton, a cached NVRTC output, etc.).

Minimal example — declarative cuda_kernel_op:

import torch, torch_tensorrt
import torch_tensorrt.kernels as ttk

cu_code = """
extern "C" __global__ void my_relu(const float* x, int n, float* y) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] = x[i] > 0.f ? x[i] : 0.f;
}
"""

ttk.cuda_kernel_op(
    "myns::relu",
    ttk.KernelSpec(
        kernel_source=cu_code,
        kernel_name="my_relu",
        inputs=[ttk.InputDecl("x")],
        outputs=[ttk.OutputDecl("y", shape=ttk.SameAs(0))],
        extras=[ttk.Numel("x")],
        geometry=ttk.Elementwise(block=(256,), layout="flat"),
    ),
    supports_dynamic_shapes=True,
)

For a shape-changing kernel, leave outputs / geometry off the KernelSpec and pass hand-written meta_fn / eager_fn / aot_fn to cuda_kernel_op directly — see examples/dynamo/cuda_kernel_op.py.

Note

This module is experimental. It requires cuda-python at runtime and TensorRT >=10.7.0 (and not 10.14.x) for Quick Deployable Plugin (QDP) support. Install cuda-python with pip install cuda-python.

Overview#

The kernels module registers NVRTC-compiled CUDA C++ kernels as TensorRT Quick Deployable Plugins. Tensor-only declarative kernels use Ahead-of-Time (AOT) plugin launches when available; kernels with ScalarInput compile through TensorRT’s QDP JIT path because QDP AOT extra arguments currently support symbolic integer expressions, not arbitrary runtime floats.

A single function — cuda_kernel_op() — handles both the declarative case (drive everything from a KernelSpec dataclass) and the override case (supply meta_fn / eager_fn / aot_fn / schema keyword arguments when the declarative DSL doesn’t cover your kernel). ptx_op() is a parallel entry point for kernels that are already compiled to PTX bytes.

Entry points#

torch_tensorrt.kernels.cuda_kernel_op(op_name: str, spec: KernelSpec, *, meta_fn: Callable[[...], Any] | None = None, eager_fn: Callable[[...], Any] | None = None, aot_fn: Callable[[...], Any] | None = None, schema: str | None = None, supports_dynamic_shapes: bool = True, requires_output_allocator: bool = False, priority: ConverterPriority = ConverterPriority.STANDARD, capability_validator: Callable[[...], Any] | None = None) -> None[source]#

Register a CUDA kernel as a TensorRT QDP plugin end-to-end.

Two paths share one entry point:

  • Declarative — pass a fully-populated KernelSpec and the meta fn, eager fn, AOT fn, and PyTorch schema are all derived for you. Covers Elementwise / Reduction kernels out of the box.

  • Override — pass any of meta_fn / eager_fn / aot_fn / schema as keyword arguments and the corresponding KernelSpec fields become optional. Use this for shape-changing kernels, multi-output kernels, or anything outside the declarative DSL.

Override rules (validated at registration time):

  • meta_fn provided → spec.outputs may be omitted.

  • eager_fn and aot_fn both provided → spec.geometry may be omitted.

  • schema omitted → the schema is inferred from spec.inputs / spec.outputs if both exist, else from meta_fn type hints.
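The two omission rules above can be read as a small decision table. The sketch below restates them in plain Python; `missing_overrides` and its boolean parameters are hypothetical names chosen for illustration, not part of torch_tensorrt, and the real registration-time validation may check more than this.

```python
def missing_overrides(has_outputs, has_geometry, has_meta_fn, has_eager_fn, has_aot_fn):
    """Return the problems a registration would have under the omission rules.

    Hypothetical helper: it only mirrors the documented rules, not the
    library's actual validator.
    """
    problems = []
    # spec.outputs may be omitted only when meta_fn is supplied.
    if not has_outputs and not has_meta_fn:
        problems.append("outputs omitted without meta_fn")
    # spec.geometry may be omitted only when eager_fn and aot_fn are both supplied.
    if not has_geometry and not (has_eager_fn and has_aot_fn):
        problems.append("geometry omitted without both eager_fn and aot_fn")
    return problems
```

A fully populated spec needs no overrides, and a spec without outputs or geometry is acceptable once all three functions are passed.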

The kernel must follow the calling convention (input_ptrs..., scalar_inputs..., extras..., output_ptrs...).

KernelSpec DSL#

class torch_tensorrt.kernels.KernelSpec(kernel_source: str, kernel_name: str, inputs: Sequence[InputDecl | ScalarInput] | None = None, outputs: Sequence[OutputDecl] | None = None, extras: Sequence[Numel | DimSize] = <factory>, geometry: Elementwise | Reduction | Custom | None = None, include_paths: List[str] | None = None, compile_std: str = 'c++17', arch_override: str | None = None)[source]#

Declarative description of a CUDA kernel.

Kernel signature convention: the __global__ function receives arguments in this fixed order: input pointers, then any ScalarInput values (passed by value), then extras in extras order, then output pointers.

All DSL fields beyond kernel_source / kernel_name are optional. Whichever fields are populated drive auto-derivation of the matching meta / eager / aot / schema artifacts; whatever is missing must be supplied as an override keyword argument to cuda_kernel_op().

class torch_tensorrt.kernels.InputDecl(name: str, dtype: dtype | None = None)[source]#

Tensor kernel input.

The corresponding kernel argument is a T* (data pointer) at the input pointer position in the calling convention.

class torch_tensorrt.kernels.ScalarInput(name: str, py_type: type)[source]#

Scalar (non-tensor) kernel input — e.g. float alpha or int k.

Scalars are forwarded by value to the kernel at the input position (after all preceding tensor/scalar inputs, before extras and output pointers). py_type must be float, int, or bool.

ScalarInput values are represented as TensorRT plugin attributes during compilation and are forwarded by value to the CUDA kernel. Tensor-only cuda_kernel_op registrations use AOT plugin launches; registrations with ScalarInput use QDP JIT plugin execution so scalar floats / ints / bools can be passed correctly.

class torch_tensorrt.kernels.OutputDecl(name: str, shape: ShapeRel, dtype_from: str | None = None)[source]#

Shape relations#

class torch_tensorrt.kernels.SameAs(input_idx: int | str = 0)[source]#

Output has the same shape as the referenced tensor input.

input_idx may be either the integer position into the tensor-only input list (ScalarInput entries are skipped) or the name of a tensor input declared via InputDecl. The name form is preferred because it stays correct when the input list is reordered.
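As an illustration of the two reference forms, the sketch below resolves an integer or a name against a list of declared inputs. `resolve_tensor_input` and the `(kind, name)` pair representation are hypothetical, and treating negative integers as out of range is an assumption of this sketch rather than documented behavior.

```python
def resolve_tensor_input(inputs, ref):
    """Resolve a SameAs-style reference to the name of a tensor input.

    `inputs` is a list of (kind, name) pairs where kind is "tensor" or
    "scalar". Hypothetical helper, not the library's implementation.
    """
    tensor_names = [name for kind, name in inputs if kind == "tensor"]
    if isinstance(ref, int):
        # Integer form: position in the tensor-only list, scalars skipped.
        if not 0 <= ref < len(tensor_names):
            raise ValueError(f"input index {ref} is out of range")
        return tensor_names[ref]
    # Name form: must match a declared tensor input.
    if ref not in tensor_names:
        raise ValueError(f"{ref!r} is not a tensor input")
    return ref


inputs = [("tensor", "x"), ("scalar", "alpha"), ("tensor", "bias")]
second_tensor = resolve_tensor_input(inputs, 1)  # skips the scalar "alpha"
by_name = resolve_tensor_input(inputs, "x")
```

Index 1 lands on "bias" because the scalar entry does not count, which is why the name form is the safer choice when inputs may be reordered.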

class torch_tensorrt.kernels.ReduceDims(input_idx: int | str, dims: Tuple[int, ...], keepdim: bool = False)[source]#

Output = the referenced tensor input with dims removed.

If keepdim=True those axes are kept with size 1 instead of removed. Negative axes are allowed. input_idx accepts either the integer position into the tensor-only input list or the input name.
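The shape rule can be sketched in plain Python over shape tuples. `reduced_shape` is a hypothetical helper written for this illustration; it is not the library's meta function.

```python
def reduced_shape(shape, dims, keepdim=False):
    """Shape of `shape` after reducing over `dims` (negative axes allowed)."""
    rank = len(shape)
    axes = set()
    for d in dims:
        if not -rank <= d < rank:
            raise ValueError(f"axis {d} is out of range for rank {rank}")
        # Normalize negative axes, e.g. -1 becomes rank - 1.
        axes.add(d % rank)
    if keepdim:
        # Reduced axes stay in place with size 1.
        return tuple(1 if i in axes else s for i, s in enumerate(shape))
    # Reduced axes are dropped.
    return tuple(s for i, s in enumerate(shape) if i not in axes)
```

For an input of shape (4, 8, 16), reducing over dims=(-1,) gives (4, 8), or (4, 8, 1) with keepdim=True.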

Extra scalar args#

Extras are passed to the kernel between the inputs and the output pointers, in KernelSpec.extras order.

class torch_tensorrt.kernels.Numel(input_name: str)[source]#

Pass inputs[input_name].numel() as an int extra.

class torch_tensorrt.kernels.DimSize(input_name: str, axis: int)[source]#

Pass inputs[input_name].shape[axis] as an int extra.

Negative axis allowed.
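To make the two extras concrete, the sketch below computes the integers they would pass for a given input shape. The helper names and the `shapes` dictionary are hypothetical stand-ins for real tensors.

```python
from math import prod


def numel_extra(shapes, input_name):
    """Value passed by a Numel extra: total element count of the named input."""
    return prod(shapes[input_name])


def dim_size_extra(shapes, input_name, axis):
    """Value passed by a DimSize extra: size of one axis (negative axis allowed)."""
    return shapes[input_name][axis]


shapes = {"x": (2, 3, 5)}
n = numel_extra(shapes, "x")
last = dim_size_extra(shapes, "x", -1)
```

With x of shape (2, 3, 5), Numel("x") corresponds to 30 and DimSize("x", -1) to 5.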

Launch geometry#

class torch_tensorrt.kernels.Elementwise(block: Tuple[int, ...] = (256,), layout: Literal['flat', 'nd'] = 'flat')[source]#

One thread per output element.

layout="flat": 1D launch over the flattened output numel, with block = (bx,) and grid = (cdiv(numel(out), bx),).

layout="nd": the trailing len(block) axes of the output are block-parallelized; any leading axes are folded into grid_z. block[0] maps to the last (innermost) axis, matching CUDA’s convention that grid_x / block_x varies fastest.
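The launch arithmetic can be sketched in plain Python. `elementwise_launch` is a hypothetical helper, not the library's implementation. The flat case follows the formula given above; for the nd case this sketch assumes at most two block axes and assumes that leading axes fold into grid_z as their product.

```python
from math import prod


def cdiv(a, b):
    """Ceiling division for positive integers."""
    return (a + b - 1) // b


def elementwise_launch(out_shape, block=(256,), layout="flat"):
    """Return (grid, block) as 3-tuples in CUDA (x, y, z) order."""
    if layout == "flat":
        if len(block) != 1:
            raise ValueError("layout='flat' expects a 1-D block tuple")
        # grid = (cdiv(numel(out), bx),)
        return (cdiv(prod(out_shape), block[0]), 1, 1), (block[0], 1, 1)

    if not 1 <= len(block) <= 2:
        raise ValueError("this sketch only covers 1-D and 2-D blocks")
    if len(out_shape) < len(block):
        raise ValueError("output rank is smaller than len(block)")

    split = len(out_shape) - len(block)
    leading, trailing = out_shape[:split], out_shape[split:]
    # block[0] pairs with the last (innermost) axis, so walk trailing axes in reverse.
    grid_xy = [cdiv(size, b) for size, b in zip(reversed(trailing), block)]
    grid_xy += [1] * (2 - len(grid_xy))
    block_xyz = tuple(block) + (1,) * (3 - len(block))
    # Assumption: leading axes fold into grid_z as their product.
    return (grid_xy[0], grid_xy[1], prod(leading)), block_xyz
```

For a (4, 100, 60) output with block=(16, 16) and layout="nd", the 60-wide innermost axis maps to x and the 100-wide axis to y, giving a grid of (4, 7, 4) under these assumptions.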

class torch_tensorrt.kernels.Reduction(reduce_dims: Tuple[int, ...], block_size: int = 256)[source]#

One block per output element; block threads cooperate across the reduction axes. reduce_dims are axes of the input (not output) that are collapsed. Grid = numel(output), block = block_size.
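A minimal sketch of the resulting launch numbers, using a hypothetical `reduction_launch` helper over shape tuples. For simplicity it only accepts non-negative axes; whether the library accepts negative reduce_dims is not stated here.

```python
from math import prod


def reduction_launch(in_shape, reduce_dims, block_size=256):
    """Grid and block sizes for a one-block-per-output-element reduction."""
    rank = len(in_shape)
    for d in reduce_dims:
        if not 0 <= d < rank:
            raise ValueError(f"axis {d} is out of range for rank {rank}")
    # Output keeps every input axis that is not reduced.
    out_shape = tuple(s for i, s in enumerate(in_shape) if i not in reduce_dims)
    return {
        "grid": prod(out_shape),  # one block per output element
        "block": block_size,      # threads cooperating inside each block
        "out_shape": out_shape,
        "elements_per_block": prod(in_shape[d] for d in reduce_dims),
    }
```

Reducing a (32, 128) input over axis 1 launches 32 blocks, each responsible for 128 input elements.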

class torch_tensorrt.kernels.Custom(fn: Callable[[...], Any])[source]#

Escape hatch. fn(inputs, outputs, tactic) returns the same pair as a hand-written aot_fn: (KernelLaunchParams, SymExprs).

Override path#

Pass any of the optional keyword arguments to cuda_kernel_op() to bypass the corresponding auto-derivation:

  • meta_fn — fake/meta impl: shape + dtype inference for tracing. When supplied, spec.outputs may be omitted.

  • eager_fn — CUDA device impl invoked when the op runs in PyTorch eager. Same positional signature as meta_fn.

  • aot_fn — TensorRT AOT impl with signature (inputs, outputs, tactic) -> (KernelLaunchParams, SymExprs | None). When both eager_fn and aot_fn are supplied, spec.geometry may be omitted.

  • schema — explicit Torch schema (for example "(Tensor x, float alpha) -> Tensor"). Falls back to deriving from spec.inputs / spec.outputs if both are present, else to inferring from meta_fn type hints.

Use the override path for shape-changing kernels, multi-output kernels, or anything that doesn’t fit the Elementwise / Reduction conventions.

Pre-compiled PTX entry point#

torch_tensorrt.kernels.ptx_op(op_name: str, ptx: bytes, kernel_name: str, meta_fn: Callable[[...], Any], eager_fn: Callable[[...], Any], aot_fn: Callable[[...], Any], *, supports_dynamic_shapes: bool = False, requires_output_allocator: bool = False, priority: ConverterPriority = ConverterPriority.STANDARD, capability_validator: Callable[[...], Any] | None = None, schema: str | None = None) -> None[source]#

Register a pre-compiled PTX kernel as a TensorRT QDP plugin.

Use this when the PTX comes from an external compiler (Triton, a cached NVRTC output, etc.) and NVRTC compilation should be skipped.

Kernel signature convention#

All entry points assume the __global__ kernel takes its arguments in the fixed order:

(input_ptrs..., scalar_inputs..., extras..., output_ptrs...)

The scalar_inputs group is empty unless the spec declares ScalarInput entries. Pointers are void* cast to the appropriate element type. Extras follow the order declared in KernelSpec.extras for the declarative path, or the order your aot_fn builds for the override path.
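To check a kernel signature against this convention, it can help to write the expected argument order out. The sketch below does so with a hypothetical `kernel_arg_order` helper, following the order stated for cuda_kernel_op (input pointers, scalar inputs, extras, output pointers), and applies it to the my_relu example from the top of the page.

```python
def kernel_arg_order(input_ptrs, scalar_inputs, extras, output_ptrs):
    """Concatenate argument groups in the documented calling-convention order."""
    return [*input_ptrs, *scalar_inputs, *extras, *output_ptrs]


# my_relu(const float* x, int n, float* y): one input, one Numel extra, one output.
relu_args = kernel_arg_order(["x"], [], ["n"], ["y"])

# A hypothetical kernel with a float scalar "alpha" declared as a ScalarInput.
scaled_args = kernel_arg_order(["x"], ["alpha"], ["n"], ["y"])
```

The first list matches the parameter order of my_relu; the second shows where a by-value scalar would sit under this convention.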

Error behavior#

cuda_kernel_op() validates the KernelSpec at registration time and raises ValueError for the common authoring mistakes:

  • Empty or duplicate-named inputs / outputs.

  • ReduceDims(input_idx=...) or SameAs(input_idx=...) where the reference is an out-of-range integer or a name that is not a tensor input. Both forms are accepted: an integer position into the tensor-only input list, or the input name (preferred — survives input reordering).

  • Numel / DimSize referencing a name that is not an input.

  • dtype_from pointing at an unknown input.

  • Elementwise(layout='flat') with a multi-dimensional block tuple.

  • Invalid block sizes in Elementwise, an invalid block_size in Reduction, or a non-callable Custom.fn.

  • A DSL field is missing and the corresponding override keyword argument was not supplied (e.g. outputs omitted without a meta_fn).

Shape-dependent errors, for example Elementwise(layout='nd', block=(16, 16)) invoked against a 1-D output, are raised at launch time as a ValueError because the offending ranks are only known when concrete tensors arrive.
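Two of the registration-time checks listed above (duplicate names, and Numel / DimSize referencing an unknown input) can be sketched in a few lines. `find_spec_errors` is a hypothetical helper that collects messages so they are easy to inspect; cuda_kernel_op() itself raises ValueError, and its actual messages may differ.

```python
def find_spec_errors(input_names, output_names, extra_refs):
    """Collect naming problems of the kind the registration checks describe."""
    errors = []
    for label, names in (("input", input_names), ("output", output_names)):
        seen = set()
        for name in names:
            if name in seen:
                errors.append(f"duplicate {label} name: {name}")
            seen.add(name)
    for ref in extra_refs:
        # Numel / DimSize must reference a declared input.
        if ref not in input_names:
            errors.append(f"extra references unknown input: {ref}")
    return errors
```

Running this kind of check before registration catches typos in input names without needing a GPU or a TensorRT build.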