The impl/ Building-Block Library#

torch_tensorrt.dynamo.conversion.impl is a library of pre-built TRT layer primitives that converter authors can compose rather than writing raw trt.INetworkDefinition API calls. Each module handles shape broadcasting, type coercion, and naming boilerplate internally.

All modules are available directly under impl:

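from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl

# Example: first step of a GeLU approximation, squaring the input.
# ctx, target, and x are supplied by the enclosing converter.
x_sq = impl.elementwise.mul(ctx, target, source_ir=SourceIR.ATEN, name="x_sq",
                            lhs_val=x, rhs_val=x)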
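The snippet above only squares the input. As a plain-Python reference for what a full composition would have to compute, the sketch below spells out the tanh-based GeLU approximation one primitive at a time. The function name and the mapping from each step to an elementwise or activation call (noted in the comments) are illustrative assumptions, not code taken from the library.

```python
import math


def gelu_tanh_reference(x: float) -> float:
    """Scalar reference for the tanh GeLU approximation, one primitive per step."""
    x_sq = x * x                                  # elementwise mul(x, x)
    x_cu = x_sq * x                               # elementwise mul(x_sq, x)
    inner = x + 0.044715 * x_cu                   # scalar mul, then add
    scaled = math.sqrt(2.0 / math.pi) * inner     # scalar mul
    t = math.tanh(scaled)                         # tanh activation
    return 0.5 * x * (1.0 + t)                    # scalar add, then two muls
```

A converter built this way would give each intermediate its own name (for example a suffix per step) so that every generated layer stays identifiable in the engine.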

The standard call signature for most impl functions is:

def op(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    # ... op-specific args ...
) -> TRTTensor:
    ...

source_ir is the SourceIR enum value that identifies which IR level this call originated from (used for debug naming). Pass SourceIR.ATEN from ATen converters.


Module Reference#

activation#

``impl.activation`` — Standard activation functions.

Functions: relu, sigmoid, tanh, leaky_relu, elu, selu, softsign, softplus, gelu, hard_sigmoid, hard_swish, prelu, hardshrink, softshrink, tanhshrink.

Each function wraps ctx.net.add_activation() with the appropriate trt.ActivationType and handles dynamic range propagation for INT8.

out = impl.activation.relu(ctx, target, SourceIR.ATEN, name, input_val)

addmm#

``impl.addmm`` — Fused add + matrix multiply (beta * input + alpha * (mat1 @ mat2)). Uses add_matrix_multiply + add_elementwise with scale factors.
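As a reference for the arithmetic only, the formula above can be sketched in plain Python on nested lists. The function name is hypothetical, and the sketch assumes input already has the shape of mat1 @ mat2 (no broadcasting).

```python
from typing import List

Matrix = List[List[float]]


def addmm_reference(inp: Matrix, mat1: Matrix, mat2: Matrix,
                    beta: float = 1.0, alpha: float = 1.0) -> Matrix:
    """Compute beta * inp + alpha * (mat1 @ mat2) on nested lists."""
    rows, inner, cols = len(mat1), len(mat2), len(mat2[0])
    # Matrix-multiply step: prod = mat1 @ mat2
    prod = [[sum(mat1[i][k] * mat2[k][j] for k in range(inner))
             for j in range(cols)]
            for i in range(rows)]
    # Elementwise steps: scale both terms, then add them
    return [[beta * inp[i][j] + alpha * prod[i][j] for j in range(cols)]
            for i in range(rows)]
```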

arange#

``impl.arange`` — Generates a 1-D range tensor (torch.arange semantics). Implemented via TRT’s IFillLayer with LINSPACE fill.
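A linspace-style fill needs an element count, a start value, and a per-element delta. The sketch below shows how those follow from torch.arange arguments; the function name is hypothetical, and integer arguments are assumed so that floating-point rounding does not affect the count.

```python
import math
from typing import List


def arange_reference(start: int, end: int, step: int = 1) -> List[int]:
    """Values of a half-open range [start, end) generated as start + i * step."""
    if step == 0:
        raise ValueError("step must be non-zero")
    # Element count: ceil((end - start) / step), clamped at zero for empty ranges
    count = max(0, math.ceil((end - start) / step))
    # Linspace-style fill: element i equals start + i * delta, with delta = step
    return [start + i * step for i in range(count)]
```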

cast#

``impl.cast`` — Explicit dtype casts. Wraps IIdentityLayer with set_output_type(). Used internally by the autocast pass and type-enforcement decorators.

cat#

``impl.cat`` — Tensor concatenation along a given dimension. Wraps ctx.net.add_concatenation().

conv / deconv#

``impl.conv`` / ``impl.deconv`` — Convolution and transposed convolution. Wraps add_convolution_nd / add_deconvolution_nd, handles 1D/2D/3D, padding modes, dilation, groups, and optional bias.

out = impl.conv.convolution(
    ctx, target, SourceIR.ATEN, name,
    input=x, weight=w, bias=b,
    stride=stride, padding=padding, dilation=dilation, groups=groups,
    transposed=False, output_padding=output_padding,
)

dynamic_block_quantize#

``impl.dynamic_block_quantize`` — Block-wise dynamic quantization helpers for FP8 and FP4 workflows. Wraps TRT’s quantization layers.

elementwise#

``impl.elementwise`` — Element-wise binary and unary operations.

Binary ops: add, sub, mul, div, pow, floor_div, trunc_div, fmod, logical_and, logical_or, logical_xor, bitwise_and, bitwise_or, bitwise_xor, eq, ne, gt, ge, lt, le, max, min.

Unary ops: abs, neg, floor, ceil, round, sqrt, rsqrt, exp, log, sin, cos, tan, asin, acos, atan, sinh, cosh, sign, not_.

All binary ops handle scalar arguments and broadcasting via get_trt_tensor/broadcastable_fn utilities.
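TensorRT elementwise layers expect both operands to have the same rank, with broadcasting allowed only across dimensions of length 1. The sketch below illustrates the shape rule that a broadcasting helper has to implement: left-pad the lower-rank shape with 1s, then resolve each dimension. It is a standalone illustration with a hypothetical function name, not the library's utility.

```python
from typing import Sequence, Tuple


def broadcast_shapes(lhs: Sequence[int], rhs: Sequence[int]) -> Tuple[int, ...]:
    """Return the broadcast result shape, or raise if the shapes are incompatible."""
    rank = max(len(lhs), len(rhs))
    # Left-pad with 1s so both shapes have equal rank (a scalar becomes all 1s)
    lhs_p = (1,) * (rank - len(lhs)) + tuple(lhs)
    rhs_p = (1,) * (rank - len(rhs)) + tuple(rhs)
    out = []
    for a, b in zip(lhs_p, rhs_p):
        if a != b and a != 1 and b != 1:
            raise ValueError(f"cannot broadcast {tuple(lhs)} with {tuple(rhs)}")
        # A size-1 dimension stretches to match the other operand
        out.append(b if a == 1 else a)
    return tuple(out)
```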

embedding#

``impl.embedding`` — Implements torch.nn.Embedding and torch.nn.EmbeddingBag via add_gather (index gather on the weight matrix).

full#

``impl.full`` — Creates constant fill tensors (torch.full, torch.zeros, torch.ones) via IFillLayer.

grid#

``impl.grid`` — Grid sampling (torch.nn.functional.grid_sample) via TRT’s IGridSampleLayer.

linear#

``impl.linear`` — Fully-connected layer (torch.nn.Linear). Wraps add_matrix_multiply + optional add_elementwise for bias.

matmul#

``impl.matmul`` — General matrix multiplication. Handles batched matmul, dot products, and outer products by reshaping inputs before calling add_matrix_multiply.

nccl_ops#

``impl.nccl_ops`` — Fused NCCL collective wrappers for distributed inference (all_gather, reduce_scatter). Used by the fuse_distributed_ops lowering pass.

normalization#

``impl.normalization`` — Normalization layers: batch_norm, layer_norm, group_norm, instance_norm, rms_norm.

batch_norm handles constant-folded running stats (for inference) and delegates to add_scale for the affine transform.
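The folding step can be sketched as follows, assuming the standard inference-time formula y = gamma * (x - mean) / sqrt(var + eps) + beta. With constant running statistics this collapses to a per-channel scale and shift, which is the form a scale layer consumes. The function names are hypothetical.

```python
import math
from typing import List, Tuple


def fold_batch_norm(gamma: List[float], beta: List[float],
                    running_mean: List[float], running_var: List[float],
                    eps: float = 1e-5) -> Tuple[List[float], List[float]]:
    """Fold batch-norm parameters into per-channel (scale, shift)."""
    scale = [g / math.sqrt(v + eps) for g, v in zip(gamma, running_var)]
    shift = [b - m * s for b, m, s in zip(beta, running_mean, scale)]
    return scale, shift


def apply_scale(x: float, scale: float, shift: float) -> float:
    """Affine transform applied per element: x * scale + shift."""
    return x * scale + shift
```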

pad#

``impl.pad`` — Tensor padding: constant_pad, reflection_pad, replication_pad. Wraps add_padding_nd or add_slice depending on padding mode.

permutation#

``impl.permutation`` — Dimension reordering: transpose, permute. Wraps add_shuffle with the requested dimension permutation.

pool#

``impl.pool`` — Pooling: avg_pool, max_pool, adaptive_avg_pool, adaptive_max_pool. Handles 1D/2D/3D and the ceil_mode / count_include_pad options via add_pooling_nd.
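The effect of ceil_mode is easiest to see in the output-size arithmetic. The sketch below follows the output-shape rule documented for PyTorch pooling layers along one spatial dimension; it is a reference for the expected shapes, not the library's implementation, and the function name is hypothetical.

```python
import math


def pool_output_size(size: int, kernel: int, stride: int, padding: int = 0,
                     dilation: int = 1, ceil_mode: bool = False) -> int:
    """Output length of a pooling window sweep along one dimension."""
    numerator = size + 2 * padding - dilation * (kernel - 1) - 1
    rounding = math.ceil if ceil_mode else math.floor
    out = rounding(numerator / stride) + 1
    # With ceil_mode, drop the last window if it would start entirely in
    # the right-hand padding
    if ceil_mode and (out - 1) * stride >= size + padding:
        out -= 1
    return out
```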

prelu#

``impl.prelu`` — Parametric ReLU. Implemented as max(0, x) + slope * min(0, x) using elementwise ops.
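The decomposition above can be checked with a scalar reference. This is a minimal sketch of the formula only, with a hypothetical function name.

```python
def prelu_reference(x: float, slope: float) -> float:
    """PReLU as max(0, x) + slope * min(0, x)."""
    positive_part = max(0.0, x)   # passes positive inputs through unchanged
    negative_part = min(0.0, x)   # isolates negative inputs
    return positive_part + slope * negative_part
```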

quantize#

``impl.quantize`` — Quantize / dequantize layers for INT8 and FP8 calibration workflows. Wraps add_quantize / add_dequantize.
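For orientation, the arithmetic behind a quantize / dequantize pair can be sketched in scalar form. The sketch assumes symmetric INT8 quantization with a zero point of 0 and round-half-to-even (the behaviour of Python's round()); the function names are hypothetical.

```python
def quantize_int8(x: float, scale: float) -> int:
    """Quantize: round x / scale to the nearest integer and clamp to INT8."""
    q = round(x / scale)              # ties round to the nearest even integer
    return max(-128, min(127, q))     # saturate to the INT8 range


def dequantize_int8(q: int, scale: float) -> float:
    """Dequantize: map the integer code back to a real value."""
    return q * scale
```

Values outside the representable range saturate, so a round trip through both functions is lossy for large inputs.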

reduce#

``impl.reduce`` — Reduction operations: sum, mean, max, min, prod, any, all, norm, var, std. Wraps add_reduce with appropriate trt.ReduceOperation.

select#

``impl.select`` — Index selection and slicing primitives: gather, index, index_select, gather_nd. Wraps add_gather variants.

shape#

``impl.shape`` — Shape introspection: size, numel, shape. Returns ITensor objects holding shape values for use in dynamic-shape graphs.

shuffle#

``impl.shuffle`` — Reshape and view via add_shuffle. Used by converters for view, reshape, flatten, unsqueeze, squeeze.

slice#

``impl.slice`` (available but not listed in __init__.py directly) — Slicing and strided access via add_slice. Used by converters for aten.slice.Tensor and aten.select.int.

split#

``impl.split`` — Splits a tensor into chunks along a dimension. Implemented via repeated add_slice.
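Splitting by repeated slicing comes down to computing a (start, length) pair per chunk. The sketch below shows that bookkeeping for torch.split semantics with an integer split size, where the last chunk may be shorter; the function name is hypothetical.

```python
from typing import List, Tuple


def split_slices(dim_size: int, split_size: int) -> List[Tuple[int, int]]:
    """Return (start, length) for each chunk along the split dimension."""
    if split_size <= 0:
        raise ValueError("split_size must be positive")
    return [(start, min(split_size, dim_size - start))
            for start in range(0, dim_size, split_size)]
```

Each pair would then parameterise one slice along the chosen dimension, with all other dimensions taken in full.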

squeeze / unsqueeze#

``impl.squeeze`` / ``impl.unsqueeze`` — Remove / add size-1 dimensions. Both delegate to impl.shuffle with a reshaped output spec.
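Because both operations reduce to a reshape, the only real work is computing the target shape. A minimal sketch for static shapes, following torch.squeeze / torch.unsqueeze semantics for a single dim argument (function names are hypothetical):

```python
from typing import Sequence, Tuple


def unsqueeze_shape(shape: Sequence[int], dim: int) -> Tuple[int, ...]:
    """Shape after inserting a size-1 dimension at position dim."""
    rank = len(shape)
    if dim < 0:
        dim += rank + 1               # negative dims count from the new rank
    if not 0 <= dim <= rank:
        raise IndexError("dim out of range")
    return tuple(shape[:dim]) + (1,) + tuple(shape[dim:])


def squeeze_shape(shape: Sequence[int], dim: int) -> Tuple[int, ...]:
    """Shape after removing dimension dim, if and only if it has size 1."""
    rank = len(shape)
    if dim < 0:
        dim += rank
    if not 0 <= dim < rank:
        raise IndexError("dim out of range")
    if shape[dim] != 1:
        return tuple(shape)           # torch.squeeze leaves other sizes untouched
    return tuple(shape[:dim]) + tuple(shape[dim + 1:])
```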

topk#

``impl.topk`` — Implements torch.topk. Wraps add_topk with support for both largest-k and smallest-k selection.
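As a reference for the expected outputs on a 1-D input, torch.topk returns both the selected values and their indices. The sketch below reproduces that contract in plain Python; the function name is hypothetical and the tie-breaking order is an assumption of this sketch.

```python
from typing import List, Sequence, Tuple


def topk_reference(values: Sequence[float], k: int,
                   largest: bool = True) -> Tuple[List[float], List[int]]:
    """Return the k largest (or smallest) values and their original indices."""
    if not 0 <= k <= len(values):
        raise ValueError("k out of range")
    # Sort indices by value; the sort is stable, so ties keep input order
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=largest)[:k]
    return [values[i] for i in order], order
```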

upsample#

``impl.upsample`` — Implements torch.nn.functional.interpolate (nearest and bilinear modes). Wraps add_resize with the appropriate resize mode.


SourceIR#

Every impl function takes a source_ir: Optional[SourceIR] argument that tags the origin of the call for debug layer naming:

from enum import Enum, auto

class SourceIR(Enum):
    ATEN        = auto()  # Called from an ATen converter
    TORCHSCRIPT = auto()  # Called from a TorchScript converter (legacy)
    NN          = auto()  # Called from an nn.Module-level converter
    ACC         = auto()  # Called from an ACC (operator fusion) converter
    PRIM        = auto()  # Called from a prims-level converter
    CORE_ATEN   = auto()  # Called from a Core ATen op converter
    UNKNOWN     = auto()

The value is included in the generated TRT layer name, which makes it easier to trace a layer back to its originating converter during engine profiling and debugging. Always pass the appropriate value; use SourceIR.ATEN for all standard ATen-based converters.
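To illustrate the idea of tagging a layer name with its origin, here is a standalone sketch. The local SourceIR enum is a trimmed stand-in for the one above, and both the tagged_layer_name helper and its name format are hypothetical; the library's actual naming scheme is not shown in this section. Treating None as UNKNOWN is likewise an assumption of the sketch.

```python
from enum import Enum, auto
from typing import Optional


class SourceIR(Enum):
    """Local stand-in with a subset of the members listed above."""
    ATEN = auto()
    PRIM = auto()
    UNKNOWN = auto()


def tagged_layer_name(name: str, source_ir: Optional[SourceIR]) -> str:
    """Build a debug-friendly layer name that records the originating IR."""
    # Assumption: a missing source_ir is reported as UNKNOWN
    ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
    return f"[{ir.name.lower()}]-[{name}]"
```

With a scheme like this, filtering a profile for one IR tag shows which layers came from which converter family.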