Complex Tensor Support#
TensorRT does not natively support complex64 or complex128 tensors. Torch-TensorRT
handles them automatically via the complex_graph_rewrite lowering pass, which
rewrites complex-valued subgraphs into equivalent real-valued arithmetic before
compilation.
This page explains what the rewriter does, which patterns are supported, and what limitations to be aware of when compiling models with complex inputs.
How the Rewriter Works#
The complex_graph_rewrite pass runs as part of the standard lowering pipeline. It:

- Detects complex subgraphs by anchoring on `view_as_real` nodes and walking backward through the graph to find all upstream complex operations.
- Replaces complex inputs with real-valued equivalents:
  - `placeholder` inputs of type `complex64`/`complex128` are replaced by new `float32`/`float64` placeholders with an appended trailing dimension of size 2 (real and imaginary parts interleaved as `(..., 2)`).
  - `get_attr` buffers that are complex are replaced by a new buffer produced by `torch.stack([original.real, original.imag], dim=-1)`.
- Rewrites complex multiply as explicit real arithmetic: `(a+bi) * (c+di) = (ac - bd) + (ad + bc)i`.
- Bypasses `view_as_real` and `view_as_complex` nodes — they become identity-like operations after the rewrite and are erased from the graph.
The net result is a fully real-valued graph that TRT can compile natively.
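The multiply rewrite above can be sketched as a pure-real function over `(..., 2)` tensors. This is an illustration of the arithmetic, not the pass itself — the actual rewriter manipulates FX graph nodes rather than calling a helper like this:

```python
import torch

def complex_mul_as_real(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Multiply two complex tensors stored as real (..., 2) tensors."""
    ar, ai = a[..., 0], a[..., 1]
    br, bi = b[..., 0], b[..., 1]
    # (a+bi) * (c+di) = (ac - bd) + (ad + bc)i
    return torch.stack([ar * br - ai * bi, ar * bi + ai * br], dim=-1)

# Sanity check against PyTorch's native complex multiply
x = torch.randn(4, 8, dtype=torch.complex64)
y = torch.randn(4, 8, dtype=torch.complex64)
out = complex_mul_as_real(torch.view_as_real(x), torch.view_as_real(y))
assert torch.allclose(out, torch.view_as_real(x * y), atol=1e-5)
```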
Supported Patterns#
The rewriter handles the following patterns inside a complex subgraph:
| Pattern | How it is handled |
|---|---|
| Complex `placeholder` input | Replaced by a real `float32`/`float64` placeholder with an appended trailing dimension of size 2 |
| Complex `get_attr` buffer | Replaced by stacked real+imag buffer with shape `(..., 2)` |
| Complex `mul.Tensor` | Rewritten to explicit real arithmetic |
| `view_as_complex` | Erased (the input real tensor flows through unchanged) |
| `view_as_real` | Erased (the output is already real after the rewrite) |
| Other ops inside the detected subgraph | Handled — the trailing dimension of size 2 is propagated through |
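The `get_attr` row can be illustrated directly in eager PyTorch. The stacked buffer has the same layout that `torch.view_as_real` produces (`weight` here is a hypothetical complex buffer, not from any specific model):

```python
import torch

# Hypothetical complex buffer as it might appear as a get_attr target
weight = torch.randn(16, 32, dtype=torch.complex64)

# The rewritten buffer: real and imaginary parts stacked on a new trailing dim
real_weight = torch.stack([weight.real, weight.imag], dim=-1)

assert real_weight.shape == (16, 32, 2)
assert real_weight.dtype == torch.float32
# Identical layout to what view_as_real produces
assert torch.equal(real_weight, torch.view_as_real(weight))
```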
Usage#
No API changes are needed. The rewriter runs automatically whenever the exported graph contains complex-valued nodes:
```python
import torch
import torch_tensorrt


class RoPEModel(torch.nn.Module):
    def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
        # x is real; view_as_complex converts it to complex for the mul with
        # the complex64 freqs tensor
        x_complex = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
        x_rotated = x_complex * freqs
        return torch.view_as_real(x_rotated).flatten(2)


model = RoPEModel().eval().cuda()
x = torch.randn(1, 16, 64).cuda()
freqs = torch.randn(1, 16, 32, dtype=torch.complex64).cuda()

exp_program = torch.export.export(model, (x, freqs))
trt_gm = torch_tensorrt.dynamo.compile(
    exp_program,
    arg_inputs=[x, freqs],
    use_explicit_typing=True,  # enabled_precisions is deprecated
    min_block_size=1,
)
output = trt_gm(x, freqs)
```
output = trt_gm(x, freqs)
The compiler detects the view_as_real node, walks the complex subgraph backward,
replaces the complex64 input freqs with a float32 placeholder of shape
(1, 16, 32, 2), and rewrites the multiply.
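You can verify the shape and dtype the rewriter produces for this input by applying `torch.view_as_real` to the same tensor in eager mode:

```python
import torch

freqs = torch.randn(1, 16, 32, dtype=torch.complex64)
freqs_real = torch.view_as_real(freqs)

assert freqs_real.shape == (1, 16, 32, 2)  # trailing dim 2 = (real, imag)
assert freqs_real.dtype == torch.float32   # complex64 → float32
```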
Passing complex inputs at runtime:
When the compiled model has complex input placeholders, pass the complex tensor directly.
The Torch-TensorRT runtime modules automatically call torch.view_as_real on complex
inputs before handing them to the TRT engine:
```python
# freqs is still complex64 at call time — the runtime handles the conversion
output = trt_gm(x, freqs)
```
truncate_double#
By default, complex128 inputs are lowered to float64 (two doubles). Set
`truncate_double=True` in the compilation settings to truncate them to
float32 instead:
```python
trt_gm = torch_tensorrt.dynamo.compile(
    exp_program,
    arg_inputs=inputs,
    truncate_double=True,  # complex128 → float32 (saves memory, loses precision)
)
```
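Conceptually, the truncation amounts to the following dtype chain (a sketch of the effect, not the actual pass implementation):

```python
import torch

z = torch.randn(4, dtype=torch.complex128)
z_real = torch.view_as_real(z)       # default lowering: float64, shape (4, 2)
assert z_real.dtype == torch.float64

z_trunc = z_real.to(torch.float32)   # effect of truncate_double=True
assert z_trunc.dtype == torch.float32
assert z_trunc.shape == (4, 2)
```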
Limitations#
- **Only `view_as_real`-anchored subgraphs are detected.** If your model uses complex arithmetic without `view_as_real` as the output boundary (e.g. a complex output tensor is returned directly), the subgraph will not be detected and compilation will fail.
- **`view_as_complex` must be paired with `view_as_real` in the same subgraph.** Standalone `view_as_complex` nodes outside the detected subgraph are not handled.
- **No support for complex convolution or complex batch norm.** Only element-wise `mul.Tensor` is rewritten; complex convolution patterns must be decomposed manually into real arithmetic before compilation.
- **`complex128` on GPU requires `float64` support in TRT.** Most consumer GPUs have limited `float64` throughput; use `truncate_double=True` for performance-critical workloads.
- **Parameters shaped `(d, 2)` (intentional, not complex).** A real parameter that happens to have a trailing dimension of 2 and is consumed by a node the detector considers "complex" will not be mistakenly rewritten, because the rewriter only rewrites nodes whose `meta["val"].dtype` is `complex64` or `complex128`.
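For the first limitation, a practical workaround is to make `view_as_real` the output boundary yourself. The sketch below contrasts the two shapes of model; `ComplexOut` and `RealOut` are hypothetical module names, not part of the library:

```python
import torch

class ComplexOut(torch.nn.Module):
    # Returns a complex tensor directly — no view_as_real anchor,
    # so the rewriter cannot detect the subgraph.
    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        return torch.view_as_complex(x) * w

class RealOut(torch.nn.Module):
    # Wrapping the output in view_as_real gives the rewriter its anchor.
    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        return torch.view_as_real(torch.view_as_complex(x) * w)

x = torch.randn(2, 8, 2)                       # real input, trailing dim 2
w = torch.randn(2, 8, dtype=torch.complex64)
out = RealOut()(x, w)
assert out.dtype == torch.float32 and out.shape == (2, 8, 2)
```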