torch_tensorrt.kernels (experimental)
Register custom CUDA C++ kernels — compiled at runtime with NVRTC via
cuda-python — as TensorRT Quick Deployable Plugins (QDP). Tensor-only
declarative kernels use AOT plugin launches when available; kernels with
ScalarInput use TensorRT’s QDP JIT path so runtime scalar attributes can
be forwarded by value.
The module exposes a single registration entry point for source kernels, plus a parallel one for pre-compiled PTX:
- cuda_kernel_op: fully declarative for the common cases, with optional overrides for everything else. Describe the kernel via KernelSpec (inputs, outputs, extras, launch geometry) and the meta / eager / aot functions plus the PyTorch schema are derived for you. For shape-changing kernels, multi-output kernels, or anything outside the declarative DSL, pass meta_fn= / eager_fn= / aot_fn= / schema= keyword arguments and the corresponding KernelSpec fields become optional.
- ptx_op: register a kernel from pre-compiled PTX bytes. This skips NVRTC entirely; you supply meta_fn / eager_fn / aot_fn directly. Useful when the PTX comes from an external compiler (Triton, a cached NVRTC output, etc.).
Minimal example — declarative cuda_kernel_op:
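import torch, torch_tensorrt
import torch_tensorrt.kernels as ttk

# Kernel arguments follow the (input_ptrs..., extras..., output_ptrs...) convention:
# x is the input pointer, n is filled in by the Numel("x") extra, y is the output.
cu_code = """
extern "C" __global__ void my_relu(const float* x, int n, float* y) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] = x[i] > 0.f ? x[i] : 0.f;
}
"""

ttk.cuda_kernel_op(
    "myns::relu",
    ttk.KernelSpec(
        kernel_source=cu_code,
        kernel_name="my_relu",
        inputs=[ttk.InputDecl("x")],
        # Name form of SameAs: y has the same shape as input "x".
        outputs=[ttk.OutputDecl("y", shape=ttk.SameAs("x"))],
        extras=[ttk.Numel("x")],
        # One thread per element, 256 threads per block.
        geometry=ttk.Elementwise(block=(256,), layout="flat"),
    ),
    supports_dynamic_shapes=True,
)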
For a shape-changing kernel, leave outputs / geometry off the
KernelSpec and pass hand-written meta_fn / eager_fn / aot_fn
to cuda_kernel_op directly — see examples/dynamo/cuda_kernel_op.py.
Note
This module is experimental. It requires cuda-python at runtime
and TensorRT >=10.7.0 (and not 10.14.x) for Quick Deployable
Plugin (QDP) support. Install cuda-python with pip install
cuda-python.
Overview
The kernels module registers NVRTC-compiled CUDA C++ kernels as
TensorRT Quick Deployable Plugins. Tensor-only declarative kernels use
Ahead-of-Time (AOT) plugin launches when available; kernels with
ScalarInput compile through TensorRT’s QDP JIT path because QDP AOT
extra arguments currently support symbolic integer expressions, not
arbitrary runtime floats.
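The path-selection rule above can be sketched in a few lines of plain Python. This is a minimal illustration, not torch_tensorrt's code: InputDecl and ScalarInput below are hypothetical stand-ins carrying only the fields the rule needs, choose_plugin_path is a hypothetical helper, and the sketch ignores the "when available" caveat on AOT launches.

```python
from dataclasses import dataclass
from typing import Sequence, Union


# Hypothetical stand-ins for the DSL classes, carrying only the fields used here.
@dataclass
class InputDecl:
    name: str


@dataclass
class ScalarInput:
    name: str
    py_type: type


def choose_plugin_path(inputs: Sequence[Union[InputDecl, ScalarInput]]) -> str:
    """Return "jit" if any input is a runtime scalar, else "aot".

    QDP AOT extra arguments are symbolic integer expressions, so a runtime
    float/int/bool scalar cannot be expressed there and forces the JIT path.
    """
    if any(isinstance(decl, ScalarInput) for decl in inputs):
        return "jit"
    return "aot"


print(choose_plugin_path([InputDecl("x")]))                               # aot
print(choose_plugin_path([InputDecl("x"), ScalarInput("alpha", float)]))  # jit
```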
A single function — cuda_kernel_op() — handles both the declarative
case (drive everything from a KernelSpec dataclass) and the
override case (supply meta_fn / eager_fn / aot_fn / schema
keyword arguments when the declarative DSL doesn’t cover your kernel).
ptx_op() is a parallel entry point for kernels that are already
compiled to PTX bytes.
Entry points
- torch_tensorrt.kernels.cuda_kernel_op(op_name: str, spec: KernelSpec, *, meta_fn: Callable[..., Any] | None = None, eager_fn: Callable[..., Any] | None = None, aot_fn: Callable[..., Any] | None = None, schema: str | None = None, supports_dynamic_shapes: bool = True, requires_output_allocator: bool = False, priority: ConverterPriority = ConverterPriority.STANDARD, capability_validator: Callable[..., Any] | None = None) -> None
Register a CUDA kernel as a TensorRT QDP plugin end-to-end.
Two paths share one entry point:
  - Declarative: pass a fully populated KernelSpec and the meta fn, eager fn, AOT fn, and PyTorch schema are all derived for you. Covers Elementwise / Reduction kernels out of the box.
  - Override: pass any of meta_fn / eager_fn / aot_fn / schema as keyword arguments and the corresponding KernelSpec fields become optional. Use this for shape-changing kernels, multi-output kernels, or anything outside the declarative DSL.
Override rules (validated at registration time):
  - meta_fn provided → spec.outputs may be omitted.
  - eager_fn and aot_fn both provided → spec.geometry may be omitted.
  - schema omitted → the schema is derived from spec.inputs / spec.outputs if both exist, else inferred from meta_fn type hints.
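To show how the three override rules combine, here is a minimal sketch that checks a combination of populated KernelSpec fields and supplied keywords. check_overrides is a hypothetical helper, not part of torch_tensorrt.kernels; it models fields and keywords as sets of names and simplifies the schema rule by assuming any supplied meta_fn has usable type hints.

```python
from typing import List, Set


def check_overrides(spec_fields: Set[str], overrides: Set[str]) -> List[str]:
    """Apply the override rules to sets of field/keyword names.

    spec_fields: KernelSpec fields that are populated, e.g. {"inputs", "outputs"}.
    overrides:   override keywords supplied, e.g. {"meta_fn", "eager_fn"}.
    Returns a list of problems; an empty list means the combination is valid.
    """
    problems = []
    # Rule 1: outputs can be dropped only if meta_fn supplies shapes/dtypes.
    if "outputs" not in spec_fields and "meta_fn" not in overrides:
        problems.append("spec.outputs omitted without meta_fn")
    # Rule 2: geometry can be dropped only if both launch paths are hand-written.
    if "geometry" not in spec_fields and not {"eager_fn", "aot_fn"} <= overrides:
        problems.append("spec.geometry omitted without both eager_fn and aot_fn")
    # Rule 3: without an explicit schema, derive it from inputs/outputs or meta_fn.
    if "schema" not in overrides:
        derivable = {"inputs", "outputs"} <= spec_fields
        if not derivable and "meta_fn" not in overrides:
            problems.append("no schema, and nothing to derive one from")
    return problems


# Shape-changing kernel: outputs and geometry omitted, all three fns supplied.
print(check_overrides({"inputs"}, {"meta_fn", "eager_fn", "aot_fn"}))  # []
```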
The kernel must follow the calling convention
(input_ptrs..., scalar_inputs..., extras..., output_ptrs...).
KernelSpec DSL
- class torch_tensorrt.kernels.KernelSpec(kernel_source: str, kernel_name: str, inputs: Sequence[InputDecl | ScalarInput] | None = None, outputs: Sequence[OutputDecl] | None = None, extras: Sequence[Numel | DimSize] = <factory>, geometry: Elementwise | Reduction | Custom | None = None, include_paths: List[str] | None = None, compile_std: str = 'c++17', arch_override: str | None = None)
Declarative description of a CUDA kernel.
Kernel signature convention: the __global__ function receives its arguments in a fixed order: input pointers (with any ScalarInput values passed by value at their declared input positions), then extras in extras order, then output pointers.
All DSL fields beyond kernel_source / kernel_name are optional. Whichever fields are populated drive auto-derivation of the matching meta / eager / aot / schema artifacts; whatever is missing must be supplied as an override keyword argument to cuda_kernel_op().
- class torch_tensorrt.kernels.InputDecl(name: str, dtype: dtype | None = None)
Tensor kernel input.
The corresponding kernel argument is a T* (data pointer) at the input-pointer position in the calling convention.
- class torch_tensorrt.kernels.ScalarInput(name: str, py_type: type)
Scalar (non-tensor) kernel input, e.g. float alpha or int k. Scalars are forwarded by value to the kernel at their input position (after all preceding tensor/scalar inputs, before extras and output pointers).
py_type must be float, int, or bool.
ScalarInput values are represented as TensorRT plugin attributes during
compilation and are forwarded by value to the CUDA kernel. Tensor-only
cuda_kernel_op registrations use AOT plugin launches; registrations with
ScalarInput use QDP JIT plugin execution so scalar floats / ints / bools can
be passed correctly.
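- class torch_tensorrt.kernels.OutputDecl(name: str, shape: ShapeRel, dtype_from: Optional[str] = None)
Tensor kernel output. shape is a shape relation such as SameAs or ReduceDims; dtype_from, when set, names the input whose dtype the output takes.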
Shape relations
- class torch_tensorrt.kernels.SameAs(input_idx: int | str = 0)
Output has the same shape as the referenced tensor input.
input_idx may be either the integer position into the tensor-only input list (ScalarInput entries are skipped) or the name of a tensor input declared via InputDecl. The name form is preferred because it stays correct when the input list is reordered.
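The two accepted forms of input_idx can be resolved as in the following sketch. resolve_tensor_input and the two dataclasses are hypothetical stand-ins, not the library's implementation; treating negative integers as out of range is also an assumption of this sketch.

```python
from dataclasses import dataclass
from typing import List, Sequence, Union


@dataclass
class InputDecl:  # hypothetical stand-in: tensor input
    name: str


@dataclass
class ScalarInput:  # hypothetical stand-in: by-value scalar input
    name: str
    py_type: type


def resolve_tensor_input(inputs: Sequence[Union[InputDecl, ScalarInput]],
                         input_idx: Union[int, str]) -> int:
    """Map a SameAs/ReduceDims input_idx to a position in the tensor-only list."""
    # ScalarInput entries are skipped, so positions count tensors only.
    tensors: List[InputDecl] = [d for d in inputs if isinstance(d, InputDecl)]
    if isinstance(input_idx, str):
        for pos, decl in enumerate(tensors):
            if decl.name == input_idx:
                return pos
        raise ValueError(f"{input_idx!r} is not a tensor input")
    if not 0 <= input_idx < len(tensors):
        raise ValueError(f"input_idx {input_idx} out of range for {len(tensors)} tensor inputs")
    return input_idx


ins = [InputDecl("x"), ScalarInput("alpha", float), InputDecl("y")]
print(resolve_tensor_input(ins, "y"))  # 1: "alpha" is not counted
```

Because "y" resolves to 1 even though it is the third declared input, the name form keeps working if a scalar is inserted or the list is reordered.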
- class torch_tensorrt.kernels.ReduceDims(input_idx: int | str, dims: Tuple[int, ...], keepdim: bool = False)
The output shape is the referenced tensor input's shape with dims removed.
If keepdim=True, those axes are kept with size 1 instead of removed. Negative axes are allowed. input_idx accepts either the integer position into the tensor-only input list or the input name.
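The ReduceDims shape relation can be written out as a small function. reduce_dims_shape is a hypothetical helper for illustration only; it normalizes negative axes Python-style and raises ValueError for an axis outside the input rank, which is an assumption about how out-of-range dims are treated.

```python
from typing import Sequence, Tuple


def reduce_dims_shape(in_shape: Sequence[int], dims: Sequence[int],
                      keepdim: bool = False) -> Tuple[int, ...]:
    """Output shape of a reduction over `dims` of a tensor shaped `in_shape`."""
    rank = len(in_shape)
    axes = set()
    for d in dims:
        if not -rank <= d < rank:
            raise ValueError(f"dim {d} out of range for rank {rank}")
        axes.add(d % rank)  # normalize negative axes, e.g. -1 -> rank - 1
    if keepdim:
        # Reduced axes stay in place with size 1.
        return tuple(1 if i in axes else s for i, s in enumerate(in_shape))
    # Reduced axes are dropped entirely.
    return tuple(s for i, s in enumerate(in_shape) if i not in axes)


print(reduce_dims_shape((4, 8, 16), (-1,)))                # (4, 8)
print(reduce_dims_shape((4, 8, 16), (0, 2), keepdim=True))  # (1, 8, 1)
```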
Extra scalar args
Extras are passed to the kernel after the inputs and before the output pointers, in KernelSpec.extras order.
Launch geometry
- class torch_tensorrt.kernels.Elementwise(block: Tuple[int, ...] = (256,), layout: Literal['flat', 'nd'] = 'flat')
One thread per output element.
  - layout="flat": 1-D launch over the flattened output numel. block = (bx,) → grid = (cdiv(numel(out), bx),).
  - layout="nd": the trailing len(block) axes of the output are block-parallelized; any leading axes are folded into grid_z. block[0] maps to the last (innermost) axis, matching CUDA’s convention that grid_x / block_x varies fastest.
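The grid arithmetic for both layouts can be sketched as follows. elementwise_grid and cdiv are hypothetical helpers, not torch_tensorrt functions. The sketch assumes that, for layout="nd", grid_y is 1 when the block is 1-D and that all leading axes multiply into grid_z; it handles only 1-D and 2-D blocks and ignores CUDA's per-dimension grid limits.

```python
from math import prod
from typing import Sequence, Tuple


def cdiv(a: int, b: int) -> int:
    """Ceiling division: number of blocks of size b needed to cover a items."""
    return -(-a // b)


def elementwise_grid(out_shape: Sequence[int], block: Tuple[int, ...],
                     layout: str = "flat") -> Tuple[int, int, int]:
    if layout == "flat":
        if len(block) != 1:
            raise ValueError("layout='flat' takes a 1-D block")
        # One thread per element of the flattened output.
        return (cdiv(prod(out_shape), block[0]), 1, 1)
    if layout != "nd":
        raise ValueError(f"unknown layout {layout!r}")
    if len(block) not in (1, 2):
        raise ValueError("this sketch handles only 1-D or 2-D blocks for layout='nd'")
    rank, brank = len(out_shape), len(block)
    if rank < brank:
        # Mirrors the launch-time error: ranks are only known once shapes arrive.
        raise ValueError(f"output rank {rank} is smaller than block rank {brank}")
    # Innermost axis first, so block[0] pairs with the last output axis.
    trailing = list(reversed(out_shape[rank - brank:]))
    grid_xy = [cdiv(n, b) for n, b in zip(trailing, block)] + [1] * (2 - brank)
    grid_z = prod(out_shape[:rank - brank])  # leading axes folded into grid_z
    return (grid_xy[0], grid_xy[1], grid_z)


print(elementwise_grid((1000,), (256,)))                       # (4, 1, 1)
print(elementwise_grid((2, 3, 64, 100), (16, 16), layout="nd"))  # (7, 4, 6)
```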
Override path
Pass any of the optional keyword arguments to cuda_kernel_op() to bypass the corresponding auto-derivation:
- meta_fn: fake/meta impl, i.e. shape + dtype inference for tracing. When supplied, spec.outputs may be omitted.
- eager_fn: CUDA device impl invoked when the op runs in PyTorch eager mode. Same positional signature as meta_fn.
- aot_fn: TensorRT AOT impl with signature (inputs, outputs, tactic) -> (KernelLaunchParams, SymExprs | None). When both eager_fn and aot_fn are supplied, spec.geometry may be omitted.
- schema: explicit Torch schema (for example "(Tensor x, float alpha) -> Tensor"). When omitted, it is derived from spec.inputs / spec.outputs if both are present, else inferred from meta_fn type hints.
Use the override path for shape-changing kernels, multi-output kernels, or anything that doesn't fit the Elementwise / Reduction conventions.
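For a sense of what the derived schema looks like, here is a sketch that builds a Torch schema string from declared inputs and an output count, matching the "(Tensor x, float alpha) -> Tensor" example. derive_schema and the dataclasses are hypothetical; torch_tensorrt's own derivation may format or name things differently.

```python
from dataclasses import dataclass
from typing import Sequence, Union


@dataclass
class InputDecl:  # hypothetical stand-in: tensor input
    name: str


@dataclass
class ScalarInput:  # hypothetical stand-in: by-value scalar input
    name: str
    py_type: type


# ScalarInput.py_type must be float, int, or bool; these map to schema types.
_SCALAR_TYPES = {float: "float", int: "int", bool: "bool"}


def derive_schema(inputs: Sequence[Union[InputDecl, ScalarInput]],
                  num_outputs: int) -> str:
    args = []
    for decl in inputs:
        if isinstance(decl, ScalarInput):
            args.append(f"{_SCALAR_TYPES[decl.py_type]} {decl.name}")
        else:
            args.append(f"Tensor {decl.name}")
    # A single output returns Tensor; several return a tuple of Tensors.
    ret = "Tensor" if num_outputs == 1 else "(" + ", ".join(["Tensor"] * num_outputs) + ")"
    return f"({', '.join(args)}) -> {ret}"


print(derive_schema([InputDecl("x"), ScalarInput("alpha", float)], 1))
# (Tensor x, float alpha) -> Tensor
```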
Pre-compiled PTX entry point
- torch_tensorrt.kernels.ptx_op(op_name: str, ptx: bytes, kernel_name: str, meta_fn: Callable[..., Any], eager_fn: Callable[..., Any], aot_fn: Callable[..., Any], *, supports_dynamic_shapes: bool = False, requires_output_allocator: bool = False, priority: ConverterPriority = ConverterPriority.STANDARD, capability_validator: Callable[..., Any] | None = None, schema: str | None = None) -> None
Register a pre-compiled PTX kernel as a TensorRT QDP plugin.
Use this when the PTX comes from an external compiler (Triton, a cached NVRTC output, etc.) and NVRTC compilation should be skipped.
Kernel signature convention
All entry points assume the __global__ kernel takes its arguments in the fixed order:
(input_ptrs..., extras..., output_ptrs...)
ScalarInput values, when present, are passed by value at their declared position among the inputs. Tensor arguments arrive as raw device pointers; declare them in the kernel with the appropriate element type (for example const float* x). Extras follow the order declared in KernelSpec.extras for the declarative path, or the order your aot_fn builds for the override path.
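The ordering rule can be expressed as a function that lists the kernel parameters for a given declaration. kernel_arg_order is a hypothetical helper; it marks pointer parameters with a trailing "*" and, following the ScalarInput description, keeps scalars at their declared input position.

```python
from typing import List, Sequence, Tuple


def kernel_arg_order(inputs: Sequence[Tuple[str, str]],
                     extras: Sequence[str],
                     outputs: Sequence[str]) -> List[str]:
    """List __global__ parameters in (inputs..., extras..., output_ptrs...) order.

    inputs: (name, kind) pairs in declaration order; kind is "tensor" or "scalar".
    Pointer parameters are marked with a trailing "*".
    """
    order = []
    for name, kind in inputs:
        # Tensors arrive as device pointers; ScalarInput values by value,
        # at the position where they were declared.
        order.append(f"{name}*" if kind == "tensor" else name)
    order.extend(extras)                          # KernelSpec.extras order
    order.extend(f"{name}*" for name in outputs)  # output pointers last
    return order


# The my_relu example: my_relu(const float* x, int n, float* y)
print(kernel_arg_order([("x", "tensor")], ["n"], ["y"]))  # ['x*', 'n', 'y*']
```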
Error behavior
cuda_kernel_op() validates the KernelSpec at registration time and raises ValueError for the common authoring mistakes:
- Empty or duplicate-named inputs / outputs.
- ReduceDims(input_idx=...) or SameAs(input_idx=...) where the reference is an out-of-range integer or a name that is not a tensor input. Both forms are accepted: an integer position into the tensor-only input list, or the input name (preferred, since it survives input reordering).
- Numel / DimSize referencing a name that is not an input.
- dtype_from pointing at an unknown input.
- Elementwise(layout='flat') with a multi-dimensional block tuple.
- Invalid block sizes (including block_size in Reduction), or a non-callable Custom.fn.
- A missing DSL field whose corresponding override keyword argument was not supplied (e.g. outputs omitted without a meta_fn).
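A few of these registration-time checks can be sketched as follows. validate_spec is a hypothetical helper covering only names, block sizes, and the flat-layout rule; it reads "empty" as an empty name string and "invalid block sizes" as non-positive entries, both of which are assumptions about what the real validator checks.

```python
from typing import Sequence, Tuple


def validate_spec(input_names: Sequence[str], output_names: Sequence[str],
                  layout: str, block: Tuple[int, ...]) -> None:
    """Raise ValueError for a subset of the registration-time mistakes."""
    for label, names in (("inputs", input_names), ("outputs", output_names)):
        seen = set()
        for name in names:
            if not name:
                raise ValueError(f"{label}: empty name")
            if name in seen:
                raise ValueError(f"{label}: duplicate name {name!r}")
            seen.add(name)
    if not block or any(b <= 0 for b in block):
        raise ValueError(f"invalid block sizes {block}")
    if layout == "flat" and len(block) != 1:
        raise ValueError(f"layout='flat' needs a 1-D block, got {block}")


validate_spec(["x"], ["y"], "flat", (256,))  # valid: returns None
```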
Shape-dependent errors, for example Elementwise(layout='nd', block=(16, 16)) launched against a 1-D output, raise a ValueError at launch time, because the offending ranks are only known once concrete tensors arrive.