Complex Number Support#
Note
This page documents the design for complex number support in Torch-TensorRT. Original design discussion: RFC #3456.
Goal#
TensorRT does not natively support complex-dtype tensors (torch.complex64,
torch.complex128). Complex numbers appear in models that use rotary position
embeddings (RoPE), for example in Llama 3, where frequency vectors are computed
in polar form (torch.polar) and applied via complex multiplication.
The goal is to allow such models to be compiled end-to-end by Torch-TensorRT through a graph-rewrite lowering pass that eliminates all complex-dtype nodes before the graph reaches TensorRT.
The primary motivation was enabling end-to-end compilation of Llama 3 in
distributed (multi-GPU) settings where the torch.compile + distributed-tensor
workflow hoists freqs_cis (a complex64 tensor) to a graph input.
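As a rough illustration of where the complex tensor comes from, the sketch below builds a freqs_cis table with torch.polar, in the style of the public Llama reference code. The function name precompute_freqs_cis, the theta default, and the toy sizes are assumptions for illustration and are not part of Torch-TensorRT.

```python
import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # One rotation frequency per pair of channels.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    t = torch.arange(end, dtype=torch.float32)  # token positions
    angles = torch.outer(t, freqs)              # shape (end, dim // 2)
    # Unit-magnitude complex numbers exp(i * angle), built in polar form.
    return torch.polar(torch.ones_like(angles), angles)


freqs_cis = precompute_freqs_cis(dim=8, end=4)
print(freqs_cis.dtype, tuple(freqs_cis.shape))
```

The result is a complex64 tensor, which is exactly the kind of value TensorRT cannot accept directly when it is hoisted to a graph input.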
Rotary Embedding Pattern#
The canonical complex-number subgraph in RoPE looks like:
def apply_rotary_emb(xq, xk, freqs_cis):
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
After export+lowering the critical sub-pattern is:
placeholder (complex freq) ──► reshape ───────────► mul (complex) ──► view_as_real
placeholder (real xq)      ──► view_as_complex ──┘
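To make the dtype transitions in this sub-pattern concrete, the following minimal sketch runs the same chain of ops in eager mode on toy tensors and prints the dtype and shape at each step. The tensor sizes and the hard-coded broadcast reshape are assumptions chosen only for illustration.

```python
import torch

# Assumed toy sizes: batch=1, seq=4, heads=2, head_dim=8.
xq = torch.randn(1, 4, 2, 8)                                  # real activations
freqs_cis = torch.polar(torch.ones(4, 4), torch.randn(4, 4))  # complex64, (seq, head_dim // 2)

xq_c = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # real -> complex
freqs = freqs_cis.reshape(1, 4, 1, 4)       # broadcast over batch and heads
prod = xq_c * freqs                         # complex multiplication
out = torch.view_as_real(prod).flatten(3)   # complex -> real, back to head_dim

for name, t in [("xq", xq), ("xq_c", xq_c), ("freqs", freqs), ("prod", prod), ("out", out)]:
    print(f"{name:6s} {str(t.dtype):16s} {tuple(t.shape)}")
```

Only the values between view_as_complex and view_as_real are complex; the boundaries of the pattern are already real, which is what makes a rewrite to real arithmetic possible.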
Implementation Overview#
The rewrite is a lowering pass in
py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py.
It operates in three conceptual stages:
Stage 1 — Detection#
The pass anchors on view_as_real nodes and walks backward through the graph
to identify all nodes participating in complex arithmetic. A node is included in
the complex subgraph if its output dtype is complex or if it is view_as_complex.
The resulting ComplexOpSubGraphInfo records:
anchor_nodes: the view_as_real nodes that terminate the complex subgraph.
subgraph_nodes: all nodes between the inputs and the anchors.
input_nodes: nodes feeding into the subgraph from outside.
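The backward walk can be sketched on a toy graph without any torch.fx machinery. Everything in this block is hypothetical: ToyNode, find_complex_subgraph, and the rule that placeholder and get_attr nodes are always treated as subgraph inputs are assumptions made for illustration, not the pass's actual data structures.

```python
from dataclasses import dataclass, field


@dataclass
class ToyNode:
    name: str
    op: str
    is_complex: bool                      # stands in for "output dtype is complex"
    inputs: list = field(default_factory=list)


def find_complex_subgraph(anchor: ToyNode) -> dict:
    """Walk backward from a view_as_real anchor and classify the nodes reached."""
    subgraph, inputs, seen = set(), set(), set()
    stack = list(anchor.inputs)
    while stack:
        node = stack.pop()
        if node.name in seen:
            continue
        seen.add(node.name)
        is_leaf = node.op in ("placeholder", "get_attr")
        if not is_leaf and (node.is_complex or node.op == "view_as_complex"):
            subgraph.add(node.name)       # participates in complex arithmetic
            stack.extend(node.inputs)     # keep walking toward the inputs
        else:
            inputs.add(node.name)         # boundary: feeds the subgraph from outside
    return {
        "anchor_nodes": [anchor.name],
        "subgraph_nodes": sorted(subgraph),
        "input_nodes": sorted(inputs),
    }


# Toy version of the RoPE sub-pattern.
freqs = ToyNode("freqs_cis", "placeholder", True)
xq = ToyNode("xq", "placeholder", False)
reshape = ToyNode("reshape", "reshape", True, [freqs])
vac = ToyNode("view_as_complex", "view_as_complex", True, [xq])
mul = ToyNode("mul", "mul", True, [vac, reshape])
anchor = ToyNode("view_as_real", "view_as_real", False, [mul])

info = find_complex_subgraph(anchor)
print(info)
```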
Stage 2 — Input Node Replacement#
Each complex input node is replaced with a real-dtype equivalent:
``get_attr`` buffers (constant complex tensors): a new _unpacked_complex buffer is registered on the graph module using torch.stack([real, imag], dim=-1), which has dtype float32 and one additional trailing dimension of size 2.
``placeholder`` inputs (runtime complex tensors): the placeholder's metadata (meta["val"]) is updated to reflect the new float32 shape with the appended 2 dimension. SymInt dynamic dimensions are preserved.
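The buffer case can be illustrated in a few lines of eager PyTorch. This is a minimal sketch: the buffer name freqs_unpacked_complex and the bare torch.nn.Module container are assumptions standing in for the real graph module and naming scheme.

```python
import torch

# A constant complex64 tensor, as it might appear in a get_attr buffer.
freqs = torch.polar(torch.ones(4, 3), torch.linspace(0.0, 1.0, 12).reshape(4, 3))

# Unpack into a real tensor with a trailing (real, imag) dimension of size 2.
unpacked = torch.stack([freqs.real, freqs.imag], dim=-1)

# Register the real-valued replacement on a module (hypothetical buffer name).
mod = torch.nn.Module()
mod.register_buffer("freqs_unpacked_complex", unpacked)

print(freqs.dtype, tuple(freqs.shape))
print(unpacked.dtype, tuple(unpacked.shape))
```

The stacked tensor has the same layout as torch.view_as_real(freqs), so the original complex values can be recovered with torch.view_as_complex if needed.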
Stage 3 — Subgraph Rewrite#
Once inputs are real, the complex ops within the subgraph are rewritten:
``view_as_complex``: erased (the input is already real with trailing dim 2).
``view_as_real``: erased (the output is already real).
``aten.mul.Tensor`` on complex tensors: replaced with the manual complex-multiplication identity
\[(a + bi)(c + di) = (ac - bd) + (ad + bc)i\]
Implemented as:
# a, b = real/imag parts of left operand (shape [..., 2])
# c, d = real/imag parts of right operand (shape [..., 2])
real = a * c - b * d
imag = a * d + b * c
result = torch.stack([real, imag], dim=-1)
``permute`` on complex tensors: the dims list is extended by appending the original last dimension index so the trailing 2 dimension (real/imag) is permuted correctly.
``reshape``/``slice``: trailing-dimension arguments are updated to account for the new ...×2 layout.
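The multiplication and permute rewrites can be checked against PyTorch's native complex ops. The sketch below is an eager-mode approximation, not the pass itself: the helper names complex_mul_unpacked and extend_permute_dims are hypothetical, and the handling of negative dims is an assumption about what a correct extension needs.

```python
import torch


def complex_mul_unpacked(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Multiply two complex tensors stored as real tensors of shape [..., 2]."""
    a, b = x[..., 0], x[..., 1]   # real/imag parts of the left operand
    c, d = y[..., 0], y[..., 1]   # real/imag parts of the right operand
    real = a * c - b * d
    imag = a * d + b * c
    return torch.stack([real, imag], dim=-1)


def extend_permute_dims(dims):
    """Extend a permute dims list so the trailing real/imag axis stays last."""
    n = len(dims)
    # Normalize negative indices first: after unpacking, -1 would otherwise
    # refer to the new real/imag axis instead of the original last dimension.
    return [d % n for d in dims] + [n]


x = torch.randn(2, 3, dtype=torch.complex64)
y = torch.randn(2, 3, dtype=torch.complex64)
mul_ok = torch.allclose(
    complex_mul_unpacked(torch.view_as_real(x), torch.view_as_real(y)),
    torch.view_as_real(x * y),
    atol=1e-6,
)

z = torch.randn(2, 3, 4, dtype=torch.complex64)
dims = (2, 0, 1)
perm_ok = torch.equal(
    torch.view_as_real(z).permute(extend_permute_dims(dims)),
    torch.view_as_real(z.permute(dims)),
)
print(mul_ok, perm_ok)
```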
Runtime Changes#
At runtime the TRT engine receives a real-valued tensor with shape
(*orig_complex_shape, 2) instead of the original complex tensor. The three
runtime modules handle the conversion:
prepare_inputs (dynamo/utils.py): builds the Input spec with the view_as_real shape/dtype but retains the original complex tensor in inp.torch_tensor for tracing.
_PythonTorchTensorRTModule.forward: applies torch.view_as_real(i).contiguous() for each complex input before feeding it to the engine.
_TorchTensorRTModule.forward: same view_as_real conversion.
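The per-input conversion described for the forward methods amounts to the following. This is a minimal sketch of the idea only: unpack_complex_inputs is a hypothetical helper, not a function in the runtime modules.

```python
import torch


def unpack_complex_inputs(inputs):
    """Replace each complex tensor with a contiguous real tensor of shape (*shape, 2)."""
    return [
        torch.view_as_real(i).contiguous() if i.is_complex() else i
        for i in inputs
    ]


freqs_cis = torch.polar(torch.ones(4, 3), torch.randn(4, 3))  # complex64 runtime input
xq = torch.randn(2, 4)                                        # ordinary real input

engine_inputs = unpack_complex_inputs([freqs_cis, xq])
for t in engine_inputs:
    print(t.dtype, tuple(t.shape), t.is_contiguous())
```

Real inputs pass through untouched, so the helper can be applied to the whole input list without checking which positions were originally complex.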
Key Implementation Invariants#
``originally_complex`` set: the set of nodes that were complex-dtype before any rewrites. After replace_input_node, complex placeholders become float32, so is_complex_dtype() returns False. The originally_complex set is used to decide which mul.Tensor nodes need the complex mul rewrite.
FakeTensorMode reuse: propagate_metadata must use the FakeTensorMode from existing placeholder fake tensors (not a fresh mode) to avoid mode-mismatch errors under torch.compile and to preserve SymInt for dynamic shapes.
Dotted buffer names: register_buffer rejects names containing a dot (.). Nested submodule parameter names (e.g. layers.0.weight) must have . replaced with __ before registration.
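The dotted-name constraint is easy to reproduce with a bare module. In this sketch, sanitize_buffer_name is a hypothetical helper that applies the replacement described above; it is not the pass's actual function.

```python
import torch


def sanitize_buffer_name(name: str) -> str:
    """Replace dots so a nested parameter name is accepted by register_buffer."""
    return name.replace(".", "__")


mod = torch.nn.Module()

# register_buffer raises KeyError for names that contain a dot.
try:
    mod.register_buffer("layers.0.weight", torch.zeros(1))
    rejected = False
except KeyError:
    rejected = True

# The sanitized name registers without error.
safe_name = sanitize_buffer_name("layers.0.weight")
mod.register_buffer(safe_name, torch.zeros(1))
print(rejected, safe_name)
```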