Complex Number Support#

Note

This page documents the design for complex number support in Torch-TensorRT. Original design discussion: RFC #3456.

Goal#

TensorRT does not natively support complex-dtype tensors (torch.complex64, torch.complex128). Complex numbers appear in models that use rotary position embeddings (RoPE), for example in Llama 3, where frequency vectors are computed in polar form (torch.polar) and applied via complex multiplication.

The goal is to allow such models to be compiled end-to-end by Torch-TensorRT through a graph-rewrite lowering pass that eliminates all complex-dtype nodes before the graph reaches TensorRT.

The primary motivation was enabling end-to-end compilation of Llama 3 in distributed (multi-GPU) settings where the torch.compile + distributed-tensor workflow hoists freqs_cis (a complex64 tensor) to a graph input.

Rotary Embedding Pattern#

The canonical complex-number subgraph in RoPE looks like:

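import torch

def apply_rotary_emb(xq, xk, freqs_cis):
    # Pair up the last dimension and reinterpret each (real, imag) pair as one complex number.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # reshape_for_broadcast (not shown here) reshapes freqs_cis so it broadcasts against xq_.
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # Rotate by complex multiplication, then return to the real layout and flatten the pairs.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)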

After export and lowering, the critical sub-pattern is:

placeholder (complex freq) ──► reshape ──► mul (complex) ──► view_as_real
placeholder (real xq)      ──► view_as_complex ──┘

Implementation Overview#

The rewrite is a lowering pass in py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py. It operates in three conceptual stages:

Figure: Complex number support overview diagram.

Stage 1 — Detection#

The pass anchors on view_as_real nodes and walks backward through the graph to identify all nodes participating in complex arithmetic. A node is included in the complex subgraph if its output dtype is complex or if it is view_as_complex.

The resulting ComplexOpSubGraphInfo records:

  • anchor_nodes — the view_as_real nodes that terminate the complex subgraph.

  • subgraph_nodes — all nodes between the inputs and the anchors.

  • input_nodes — nodes feeding into the subgraph from outside.
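The backward walk described above can be sketched on a small torch.fx graph. This is an illustrative approximation, not the pass itself: it traces with torch.fx.symbolic_trace (so targets are torch.view_as_real rather than the ATen ops seen after export), it reads dtypes from ShapeProp metadata, and the names rope_like, is_complex, and find_complex_subgraph are hypothetical.

```python
import torch
import torch.fx as fx
from torch.fx.passes.shape_prop import ShapeProp


def rope_like(x, freqs):
    # x: real tensor [..., 2]; freqs: complex tensor matching view_as_complex(x)
    return torch.view_as_real(torch.view_as_complex(x) * freqs)


def is_complex(node):
    # ShapeProp stores a TensorMetadata record under meta["tensor_meta"].
    tm = node.meta.get("tensor_meta")
    return tm is not None and tm.dtype.is_complex


def find_complex_subgraph(gm):
    # Anchors are the view_as_real nodes that terminate the complex region.
    anchors = [n for n in gm.graph.nodes if n.target is torch.view_as_real]
    subgraph, inputs, seen = [], [], set()
    stack = list(anchors)
    while stack:
        for inp in stack.pop().all_input_nodes:
            if inp in seen:
                continue
            if not (is_complex(inp) or inp.target is torch.view_as_complex):
                continue  # real-valued producer: outside the complex subgraph
            seen.add(inp)
            if inp.op in ("placeholder", "get_attr"):
                inputs.append(inp)  # complex tensor entering from outside
            else:
                subgraph.append(inp)  # complex op that will need rewriting
                stack.append(inp)
    return anchors, subgraph, inputs


gm = fx.symbolic_trace(rope_like)
example_x = torch.randn(4, 8, 2)
example_freqs = torch.polar(torch.ones(4, 8), torch.randn(4, 8))
ShapeProp(gm).propagate(example_x, example_freqs)  # fills in dtype metadata

anchors, subgraph, inputs = find_complex_subgraph(gm)
print([n.name for n in inputs])
```

In this toy graph the only complex input is the freqs placeholder, and the subgraph contains the complex multiply and the view_as_complex node; the real-valued x placeholder is left out.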

Stage 2 — Input Node Replacement#

Each complex input node is replaced with a real-dtype equivalent:

  • ``get_attr`` buffers (constant complex tensors): a new _unpacked_complex buffer is registered on the graph module using torch.stack([real, imag], dim=-1), which has dtype float32 and one additional trailing dimension of size 2.

  • ``placeholder`` inputs (runtime complex tensors): the placeholder’s metadata (meta["val"]) is updated to reflect the new float32 shape with the appended 2 dimension. SymInt dynamic dimensions are preserved.
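As a minimal sketch of the get_attr case, the snippet below unpacks a complex buffer into a real one with a trailing dimension of 2. The Holder module, the unpack_complex_buffer helper, and the exact buffer naming are assumptions made for illustration; only the torch.stack([real, imag], dim=-1) layout comes from the description above.

```python
import torch


class Holder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A constant complex64 buffer, standing in for a precomputed freqs_cis.
        freqs = torch.polar(torch.ones(16, 4), torch.randn(16, 4))
        self.register_buffer("freqs", freqs)


def unpack_complex_buffer(module, name):
    """Register a real-valued copy of a complex buffer with a trailing dim of 2."""
    c = getattr(module, name)
    unpacked = torch.stack([c.real, c.imag], dim=-1)
    new_name = f"{name}_unpacked_complex"  # naming scheme assumed for illustration
    module.register_buffer(new_name, unpacked)
    return new_name


m = Holder()
new_name = unpack_complex_buffer(m, "freqs")
unpacked = getattr(m, new_name)
print(unpacked.dtype, tuple(unpacked.shape))  # torch.float32 (16, 4, 2)
```

The stacked layout holds the same values as torch.view_as_real(m.freqs), which is why the later view_as_complex and view_as_real nodes can simply be erased.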

Figure: Graph rewrite, source pattern (blue arrows = modifications).

Stage 3 — Subgraph Rewrite#

Once inputs are real, the complex ops within the subgraph are rewritten:

  • ``view_as_complex`` — erased (the input is already real with trailing dim 2).

  • ``view_as_real`` — erased (the output is already real).

  • ``aten.mul.Tensor`` on complex tensors — replaced with the manual complex-multiplication identity:

    \[(a + bi)(c + di) = (ac - bd) + (ad + bc)i\]

    Implemented as:

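    # lhs, rhs: real-valued operands of shape [..., 2] (names chosen for illustration)
    a, b = lhs[..., 0], lhs[..., 1]  # real/imag parts of the left operand
    c, d = rhs[..., 0], rhs[..., 1]  # real/imag parts of the right operand
    real = a * c - b * d
    imag = a * d + b * c
    result = torch.stack([real, imag], dim=-1)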
    
  • ``permute`` on complex tensors — the dims list is extended by appending the index of the new trailing dimension (equal to the rank of the original complex tensor), so the size-2 real/imag dimension stays in the last position after the permute.

  • ``reshape``/``slice`` — trailing-dimension arguments are updated to account for the new ...×2 layout.
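The multiplication and permute rewrites can be checked numerically against PyTorch's native complex ops. The sketch below is a standalone check rather than the pass's code: complex_mul_real and permute_real are hypothetical helper names, and permute_real assumes non-negative dims.

```python
import torch


def complex_mul_real(x, y):
    # x, y: real tensors of shape [..., 2] holding (real, imag) pairs.
    a, b = x[..., 0], x[..., 1]
    c, d = y[..., 0], y[..., 1]
    # (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    return torch.stack([a * c - b * d, a * d + b * c], dim=-1)


def permute_real(x, dims):
    # x carries one extra trailing axis. Appending its index (the rank of the
    # original complex tensor, i.e. len(dims)) keeps the real/imag axis last.
    # Assumes every entry of dims is non-negative.
    return x.permute(*dims, len(dims))


torch.manual_seed(0)
p = torch.randn(2, 3, 4, dtype=torch.complex64)
q = torch.randn(2, 3, 4, dtype=torch.complex64)

# Reference results computed with native complex arithmetic.
mul_ref = torch.view_as_real(p * q)
perm_ref = torch.view_as_real(p.permute(2, 0, 1).contiguous())

# Same operations carried out on the real [..., 2] layout.
mul_out = complex_mul_real(torch.view_as_real(p), torch.view_as_real(q))
perm_out = permute_real(torch.view_as_real(p), (2, 0, 1))

print(torch.allclose(mul_out, mul_ref, atol=1e-5))
print(torch.equal(perm_out, perm_ref))
```

The multiply is compared with a tolerance because the real-arithmetic formula and the native complex kernel may round differently; the permute only moves data, so it is compared exactly.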

Figure: Modified target graph after complex rewrite.

Runtime Changes#

At runtime the TRT engine receives a real-valued tensor with shape (*orig_complex_shape, 2) instead of the original complex tensor. Three runtime components handle the conversion:

  • prepare_inputs (dynamo/utils.py) — builds the Input spec with the view_as_real shape/dtype but retains the original complex tensor in inp.torch_tensor for tracing.

  • _PythonTorchTensorRTModule.forward — applies torch.view_as_real(i).contiguous() for each complex input before feeding it to the engine.

  • _TorchTensorRTModule.forward — same view_as_real conversion.
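The per-input conversion performed in forward can be illustrated in isolation. This is a sketch under the assumption that inputs arrive as a flat list of tensors; to_engine_inputs is a hypothetical helper, not a function in Torch-TensorRT.

```python
import torch


def to_engine_inputs(inputs):
    """Convert complex tensors to contiguous real tensors of shape (*shape, 2)."""
    return [
        torch.view_as_real(t).contiguous() if t.is_complex() else t
        for t in inputs
    ]


xq = torch.randn(1, 8, 4, 64)                                 # real activation
freqs_cis = torch.polar(torch.ones(8, 32), torch.randn(8, 32))  # complex64

real_inputs = to_engine_inputs([xq, freqs_cis])
print([tuple(t.shape) for t in real_inputs])  # [(1, 8, 4, 64), (8, 32, 2)]
```

Real inputs pass through untouched, while the complex freqs_cis gains a trailing dimension of 2 that matches the shape the rewritten graph expects.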

Key Implementation Invariants#

  • ``originally_complex`` set — the set of nodes that were complex-dtype before any rewrites. After replace_input_node, complex placeholders become float32 so is_complex_dtype() returns False. The originally_complex set is used to decide which mul.Tensor nodes need the complex mul rewrite.

  • FakeTensorMode reuse: propagate_metadata must use the FakeTensorMode from existing placeholder fake tensors (not a fresh mode) to avoid mode-mismatch errors under torch.compile and to preserve SymInt for dynamic shapes.

  • Dotted buffer names: register_buffer rejects names containing a dot ("."). Nested submodule parameter names (e.g. layers.0.weight) must have each "." replaced with "__" before registration.
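The dotted-name restriction is easy to reproduce with a bare torch.nn.Module. The sketch below shows register_buffer raising KeyError for a dotted name and accepting the sanitized one; safe_buffer_name is a hypothetical helper that applies the "." to "__" substitution described above.

```python
import torch


def safe_buffer_name(name):
    # Replace every "." so the name is accepted by register_buffer.
    return name.replace(".", "__")


m = torch.nn.Module()
value = torch.zeros(2)

try:
    m.register_buffer("layers.0.weight", value)
    rejected = False
except KeyError:
    # torch.nn.Module refuses buffer names that contain "."
    rejected = True

m.register_buffer(safe_buffer_name("layers.0.weight"), value)
print(rejected, hasattr(m, "layers__0__weight"))  # True True
```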