Complex Tensor Support#

TensorRT does not natively support complex64 or complex128 tensors. Torch-TensorRT handles them automatically via the complex_graph_rewrite lowering pass, which rewrites complex-valued subgraphs into equivalent real-valued arithmetic before compilation.

This page explains what the rewriter does, which patterns are supported, and what limitations to be aware of when compiling models with complex inputs.


How the Rewriter Works#

The complex_graph_rewrite pass runs as part of the standard lowering pipeline. It:

  1. Detects complex subgraphs by anchoring on view_as_real nodes and walking backward through the graph to find all upstream complex operations.

  2. Replaces complex inputs with real-valued equivalents:

     • placeholder inputs of type complex64/complex128 are replaced by new float32/float64 placeholders with an appended trailing dimension of size 2 (real and imaginary parts stacked as (..., 2)).

     • get_attr buffers that are complex are replaced by a new buffer produced by torch.stack([original.real, original.imag], dim=-1).

  3. Rewrites complex multiply as explicit real arithmetic: (a+bi) * (c+di) = (ac - bd) + (ad + bc)i

  4. Bypasses view_as_real and view_as_complex nodes — they become identity-like operations after the rewrite and are erased from the graph.

The net result is a fully real-valued graph that TRT can compile natively.
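The multiply rewrite in step 3 can be sketched as plain tensor arithmetic on the (..., 2) representation. This is an illustrative standalone function (complex_mul_as_real is not part of the pass), checked against PyTorch's native complex multiply:

```python
import torch

def complex_mul_as_real(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Multiply two complex tensors stored as real tensors of shape (..., 2).

    Implements (a+bi) * (c+di) = (ac - bd) + (ad + bc)i, the same identity
    the rewriter emits in place of aten.mul.Tensor on complex operands.
    """
    a, b = x[..., 0], x[..., 1]
    c, d = y[..., 0], y[..., 1]
    return torch.stack([a * c - b * d, a * d + b * c], dim=-1)

# Check against native complex multiplication
p = torch.randn(4, 8, dtype=torch.complex64)
q = torch.randn(4, 8, dtype=torch.complex64)
expected = torch.view_as_real(p * q)
actual = complex_mul_as_real(torch.view_as_real(p), torch.view_as_real(q))
assert torch.allclose(expected, actual, atol=1e-6)
```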


Supported Patterns#

The rewriter handles the following patterns inside a complex subgraph:

| Pattern | How it is handled |
| --- | --- |
| complex64 / complex128 input placeholder | Replaced by float32 / float64 placeholder with shape (..., 2) |
| complex64 / complex128 model buffer (get_attr) | Replaced by stacked real+imag buffer with shape (..., 2) |
| aten.mul.Tensor between complex tensors | Rewritten to (ac - bd) + (ad + bc)i real arithmetic |
| aten.view_as_complex nodes | Erased (the input real tensor flows through unchanged) |
| aten.view_as_real anchor nodes | Erased (the output is already real after the rewrite) |
| aten.permute.default on complex tensors | Handled: the trailing 2 dimension is appended to the dims list |
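The permute case can be verified directly in eager PyTorch: permuting a complex tensor is equivalent to permuting its (..., 2) real view with the index of the trailing size-2 dimension appended to the dims list.

```python
import torch

z = torch.randn(2, 3, 4, dtype=torch.complex64)
zr = torch.view_as_real(z)          # shape (2, 3, 4, 2)

dims = (0, 2, 1)                    # permute requested on the complex tensor
dims_real = (*dims, zr.ndim - 1)    # trailing size-2 dim appended: (0, 2, 1, 3)

expected = torch.view_as_real(z.permute(*dims))
actual = zr.permute(*dims_real)
assert torch.equal(expected, actual)
```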


Usage#

No API changes are needed. The rewriter runs automatically whenever the exported graph contains complex-valued nodes:

import torch
import torch_tensorrt

class RoPEModel(torch.nn.Module):
    def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
        # x and freqs are real; view_as_complex converts to complex for mul
        x_complex = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
        x_rotated = x_complex * freqs
        return torch.view_as_real(x_rotated).flatten(2)

model = RoPEModel().eval().cuda()
x = torch.randn(1, 16, 64).cuda()
freqs = torch.randn(1, 16, 32, dtype=torch.complex64).cuda()

exp_program = torch.export.export(model, (x, freqs))
trt_gm = torch_tensorrt.dynamo.compile(
    exp_program,
    arg_inputs=[x, freqs],
    use_explicit_typing=True,  # strong typing; the enabled_precisions workflow is deprecated
    min_block_size=1,
)

output = trt_gm(x, freqs)

The compiler detects the view_as_real node, walks the complex subgraph backward, replaces the complex64 input freqs with a float32 placeholder of shape (1, 16, 32, 2), and rewrites the multiply.
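The replacement placeholder holds exactly what torch.view_as_real would produce for freqs, which can be checked in plain PyTorch:

```python
import torch

freqs = torch.randn(1, 16, 32, dtype=torch.complex64)
freqs_real = torch.view_as_real(freqs)   # what the new placeholder receives

assert freqs_real.shape == (1, 16, 32, 2)
assert freqs_real.dtype == torch.float32
assert torch.equal(freqs_real[..., 0], freqs.real)
assert torch.equal(freqs_real[..., 1], freqs.imag)
```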

Passing complex inputs at runtime:

When the compiled model has complex input placeholders, pass the complex tensor directly. The Torch-TensorRT runtime modules automatically call torch.view_as_real on complex inputs before handing them to the TRT engine:

# freqs is still complex64 at call time — the runtime handles the conversion
output = trt_gm(x, freqs)
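Conceptually, the runtime's input handling amounts to the following (a simplified sketch, not the actual runtime module code; prepare_inputs is a hypothetical name):

```python
import torch

def prepare_inputs(inputs):
    # Complex inputs are viewed as real (..., 2) tensors before the
    # TRT engine sees them; real inputs pass through untouched.
    return [torch.view_as_real(t) if t.is_complex() else t for t in inputs]

x = torch.randn(1, 16, 64)
freqs = torch.randn(1, 16, 32, dtype=torch.complex64)
prepped = prepare_inputs([x, freqs])
assert prepped[0].shape == (1, 16, 64)
assert prepped[1].shape == (1, 16, 32, 2) and not prepped[1].is_complex()
```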

truncate_double#

By default, complex128 inputs are lowered to float64 (two doubles). Set truncate_double=True in the compilation settings to truncate them to float32 instead:

trt_gm = torch_tensorrt.dynamo.compile(
    exp_program,
    arg_inputs=inputs,
    truncate_double=True,   # complex128 → float32 (saves memory, loses precision)
)
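The effect of the truncation on the lowered representation can be illustrated in eager PyTorch: the (..., 2) view of a complex128 tensor is float64, and truncation casts it down to float32 at the cost of precision beyond roughly seven significant digits.

```python
import torch

z = torch.randn(8, dtype=torch.complex128)
as_f64 = torch.view_as_real(z)        # default lowering: float64, shape (8, 2)
as_f32 = as_f64.to(torch.float32)     # what truncate_double=True produces

assert as_f64.dtype == torch.float64
assert as_f32.dtype == torch.float32
# Truncation halves storage per element but drops float64 precision
assert torch.allclose(as_f64, as_f32.double(), atol=1e-6)
```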

Limitations#

  • Only ``view_as_real``-anchored subgraphs are detected. If your model uses complex arithmetic without view_as_real as the output boundary (e.g. a complex output tensor is returned directly), the subgraph will not be detected and the compilation will fail.

  • ``view_as_complex`` must be paired with ``view_as_real`` in the same subgraph. Standalone view_as_complex nodes outside the detected subgraph are not handled.

  • No support for complex convolution or complex batch norm — only element-wise mul.Tensor is rewritten. Complex convolution patterns must be decomposed manually into real arithmetic before compilation.

  • ``complex128`` on GPU requires float64 support in TRT. Most consumer GPUs have limited float64 throughput; use truncate_double=True for performance-critical workloads.

  • Parameters shaped ``(d, 2)`` (intentional, not complex): a real parameter that happens to have a trailing dimension of 2 is never mistakenly rewritten, even if it is consumed by a node inside a detected complex subgraph, because the rewriter only rewrites nodes whose meta["val"].dtype is complex64 or complex128.
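That dtype guard can be sketched as a simple predicate (is_complex_node_val is an illustrative name; in the actual pass the value comes from an FX node's meta["val"]):

```python
import torch

def is_complex_node_val(val) -> bool:
    # Only tensors whose dtype is complex are candidates for rewriting,
    # regardless of any trailing dimension of size 2.
    return isinstance(val, torch.Tensor) and val.dtype in (
        torch.complex64,
        torch.complex128,
    )

real_param = torch.randn(16, 2)                        # intentional (d, 2) shape
complex_buf = torch.randn(16, dtype=torch.complex64)   # true complex buffer
assert not is_complex_node_val(real_param)
assert is_complex_node_val(complex_buf)
```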