3D Rotary Position Embedding (RoPE) + Attention compiled with Torch-TensorRT#
3D RoPE is the positional encoding used in video generation transformers such as CogVideoX, Wan, and HunyuanVideo. Unlike 1D RoPE (used in language models), which encodes a single sequence index, 3D RoPE independently encodes three axes — temporal (T), height (H), and width (W) — and assigns each axis a dedicated slice of the per-head frequency vector:
frequency slots 0 .. d//6-1     → temporal frequencies
frequency slots d//6 .. d//3-1  → height frequencies
frequency slots d//3 .. d//2-1  → width frequencies
(with d = head_dim: the per-head complex frequency vector has d//2 slots, so each axis gets d//6 of them)
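Concretely, with the script's default head_dim = 48 the split works out to 8 slots per axis; this mirrors the half / d_t / d_h / d_w arithmetic in precompute_freqs_3d below:
head_dim = 48
half = head_dim // 2    # 24 complex frequency slots per head
d_t = half // 3         # 8 slots: temporal (indices  0..7)
d_h = half // 3         # 8 slots: height   (indices  8..15)
d_w = half - d_t - d_h  # 8 slots: width    (indices 16..23)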
The rotation is expressed with complex arithmetic:
xq_rotated = view_as_real(view_as_complex(xq) * freqs_cis)
PyTorch complex ops (view_as_complex, complex mul) are not natively supported by TensorRT. Torch-TensorRT’s complex_graph_detection lowering pass intercepts them before partitioning and rewrites the subgraph to equivalent real arithmetic — splitting the last dimension into (…, 2) real/imag pairs and computing (ac-bd, ad+bc) manually — so the TRT engine only sees standard float32 ops and the caller never needs to change anything.
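To see the equivalence the pass relies on, here is a minimal eager-mode sketch (illustrative only — the pass itself rewrites the exported FX graph, not eager tensors): a complex multiply over (…, 2) real/imag pairs reduces to (ac-bd, ad+bc).
import torch

x = torch.randn(4, 8)                                 # 4 real/imag pairs per row
f = torch.polar(torch.ones(4, 4), torch.randn(4, 4))  # unit-complex frequencies

# Complex path, as the model expresses it.
ref = torch.view_as_real(torch.view_as_complex(x.reshape(4, 4, 2)) * f)

# Real-arithmetic path, as the TRT engine ultimately executes it.
a, b = x.reshape(4, 4, 2).unbind(-1)  # real/imag parts of x
c, d = f.real, f.imag                 # real/imag parts of the frequencies
out = torch.stack((a * c - b * d, a * d + b * c), dim=-1)

assert torch.allclose(ref, out, atol=1e-6)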
This example:

1. Defines a 3D-RoPE frequency precomputation helper (complex64 output).
2. Defines a VideoAttentionBlock: linear QKV projection → 3D RoPE → SDPA.
3. Runs a PyTorch baseline forward pass.
4. Exports with torch.export.export(), optionally with dynamic T/H/W dimensions.
5. Compiles to TensorRT via torch_tensorrt.dynamo.compile().
6. Verifies numerical accuracy (cosine similarity on the output tensor).
7. (Optional) benchmarks latency of both backends.
Usage#
Quick correctness check (static shapes)#
python examples/dynamo/torch_export_3d_rope.py
Dynamic T/H/W shapes#
python examples/dynamo/torch_export_3d_rope.py --dynamic
Larger config + benchmark#
python examples/dynamo/torch_export_3d_rope.py --heads 16 --head-dim 96 --t 8 --h 16 --w 16 --benchmark
[ ]:
import argparse
import timeit
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt
from torch.export import Dim
DEVICE = torch.device("cuda:0")
# ---------------------------------------------------------------------------
# Frequency precomputation
# ---------------------------------------------------------------------------
def precompute_freqs_3d(
head_dim: int,
t: int,
h: int,
w: int,
theta: float = 10000.0,
) -> torch.Tensor:
"""Pre-compute 3D RoPE unit-complex frequency tensor.
    Returns a complex64 tensor of shape (t, h, w, head_dim // 2) where the
    last dimension is split evenly across the temporal, height, and width axes.
Args:
head_dim: Channels per attention head (must be even, head_dim//2
must be divisible by 3).
t: Number of temporal frames.
h: Spatial height in patches.
w: Spatial width in patches.
theta: Base for the geometric frequency progression.
"""
half = head_dim // 2
d_t = half // 3
d_h = half // 3
d_w = half - d_t - d_h # absorbs any remainder from integer division
def _axis_freqs(d: int, n: int) -> torch.Tensor:
"""1-D complex exponentials, shape (n, d)."""
inv_freq = 1.0 / (theta ** (torch.arange(0, d * 2, 2).float() / (d * 2)))
positions = torch.arange(n, dtype=torch.float32)
angles = torch.outer(positions, inv_freq)
return torch.polar(torch.ones_like(angles), angles) # complex64
freqs_t = _axis_freqs(d_t, t)[:, None, None, :].expand(t, h, w, d_t)
freqs_h = _axis_freqs(d_h, h)[None, :, None, :].expand(t, h, w, d_h)
freqs_w = _axis_freqs(d_w, w)[None, None, :, :].expand(t, h, w, d_w)
# Concatenate along last dim → (t, h, w, half), complex64
return torch.cat([freqs_t, freqs_h, freqs_w], dim=-1).contiguous()
# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
class VideoAttentionBlock(nn.Module):
"""Single attention block for video latents with 3D RoPE.
Inputs
------
x : (B, T, H, W, C) float32 video patch features
freqs_cis_real: (T, H, W, C // n_heads) float32
The RoPE frequency tensor pre-flattened from complex64 via
``view_as_real(...).flatten(-2)``. The module reconstructs the
complex form internally with ``view_as_complex``.
Passing frequencies as a plain real-valued input avoids exposing a
complex tensor at the model boundary (TRT inputs must be real).
Output
------
(B, T, H, W, C) float32
"""
def __init__(self, channels: int = 512, n_heads: int = 8) -> None:
super().__init__()
assert channels % n_heads == 0
self.n_heads = n_heads
self.head_dim = channels // n_heads
self.norm = nn.LayerNorm(channels)
self.qkv = nn.Linear(channels, 3 * channels, bias=False)
self.proj = nn.Linear(channels, channels, bias=False)
def _apply_rope(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
"""Apply 3D RoPE to a single Q or K tensor.
The complex multiply ``xc * freqs_cis`` is what Torch-TensorRT rewrites
to real arithmetic via the complex_graph_detection lowering pass.
Args:
x : (B, T, H, W, n_heads, head_dim) float32
freqs_cis: (T, H, W, head_dim // 2) complex64
Returns:
Rotated tensor, same shape as ``x``, float32.
"""
B, T, H, W, Nh, D = x.shape
# Interpret consecutive pairs of head-dim channels as complex numbers.
xc = torch.view_as_complex(x.reshape(B, T, H, W, Nh, D // 2, 2))
# freqs_cis broadcast over batch (dim 0) and head (dim 4).
freqs = freqs_cis[None, :, :, :, None, :] # (1, T, H, W, 1, D//2)
return torch.view_as_real(xc * freqs).flatten(-2) # (B,T,H,W,Nh,D)
def forward(
self,
x: torch.Tensor,
freqs_cis_real: torch.Tensor,
) -> torch.Tensor:
B, T, H, W, C = x.shape
Nh, D = self.n_heads, self.head_dim
h = self.norm(x)
qkv = self.qkv(h).reshape(B, T, H, W, 3, Nh, D)
q, k, v = qkv.unbind(dim=4) # each (B, T, H, W, Nh, D)
# Recover complex frequencies from the real-valued input.
# freqs_cis_real: (T, H, W, D) → reshape to (T, H, W, D//2, 2) → complex
freqs_cis = torch.view_as_complex(freqs_cis_real.reshape(T, H, W, D // 2, 2))
q = self._apply_rope(q, freqs_cis)
k = self._apply_rope(k, freqs_cis)
# Flatten spatial dims for attention: (B, Nh, T*H*W, D)
N = T * H * W
q = q.reshape(B, N, Nh, D).permute(0, 2, 1, 3)
k = k.reshape(B, N, Nh, D).permute(0, 2, 1, 3)
v = v.reshape(B, N, Nh, D).permute(0, 2, 1, 3)
out = F.scaled_dot_product_attention(q, k, v) # (B, Nh, N, D)
out = out.permute(0, 2, 1, 3).reshape(B, T, H, W, C)
return x + self.proj(out)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def make_inputs(
B: int, T: int, H: int, W: int, C: int, n_heads: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""Return (x, freqs_cis_real) on DEVICE."""
x = torch.randn(B, T, H, W, C, dtype=torch.float32, device=DEVICE)
freqs_cis = precompute_freqs_3d(C // n_heads, t=T, h=H, w=W).to(DEVICE)
freqs_cis_real = torch.view_as_real(freqs_cis).flatten(-2) # (T,H,W,D)
return x, freqs_cis_real
def benchmark(fn, *args, iterations: int = 20, label: str = "") -> float:
fn(*args) # warmup
torch.cuda.synchronize()
total = 0.0
for _ in range(iterations):
t0 = timeit.default_timer()
fn(*args)
torch.cuda.synchronize()
total += timeit.default_timer() - t0
avg_ms = total / iterations * 1000
print(f"[{label}] avg latency over {iterations} iters: {avg_ms:.2f} ms")
return avg_ms
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def parse_args():
p = argparse.ArgumentParser(
description="3D RoPE attention block compiled with Torch-TensorRT"
)
p.add_argument("--heads", type=int, default=8, help="Number of attention heads")
p.add_argument(
"--head-dim",
dest="head_dim",
type=int,
default=48,
help="Channels per head. head_dim//2 must be divisible by 3 (default: 48)",
)
p.add_argument("--t", type=int, default=4, help="Temporal frames (default: 4)")
p.add_argument(
"--h", type=int, default=8, help="Spatial height patches (default: 8)"
)
p.add_argument(
"--w", type=int, default=8, help="Spatial width patches (default: 8)"
)
p.add_argument(
"--dynamic",
action="store_true",
help="Export with dynamic T/H/W dims and compile with min/opt/max shapes",
)
p.add_argument(
"--benchmark", action="store_true", help="Benchmark PyTorch vs TRT latency"
)
p.add_argument("--iterations", type=int, default=20)
return p.parse_args()
def main():
args = parse_args()
if (args.head_dim // 2) % 3 != 0:
raise ValueError(
f"head_dim // 2 = {args.head_dim // 2} must be divisible by 3 "
"for the T/H/W frequency split. Try --head-dim 48, 60, 96, or 192."
)
B, T, H, W = 1, args.t, args.h, args.w
C = args.heads * args.head_dim
print(f"VideoAttentionBlock with 3D RoPE")
print(f" heads={args.heads} head_dim={args.head_dim} channels={C}")
print(f" input shape: ({B}, {T}, {H}, {W}, {C})")
model = VideoAttentionBlock(channels=C, n_heads=args.heads).eval().to(DEVICE)
# ------------------------------------------------------------------
# 1. Build inputs
# ------------------------------------------------------------------
x, freqs_cis_real = make_inputs(B, T, H, W, C, args.heads)
inputs = (x, freqs_cis_real)
print(f"\n x shape : {x.shape}")
print(f" freqs_cis_real shape: {freqs_cis_real.shape}")
# ------------------------------------------------------------------
# 2. PyTorch baseline
# ------------------------------------------------------------------
with torch.inference_mode():
pyt_out = model(*inputs)
print(f"\n--- PyTorch baseline ---")
print(f" output shape: {pyt_out.shape} dtype: {pyt_out.dtype}")
# ------------------------------------------------------------------
# 3. Export
# ------------------------------------------------------------------
print("\nExporting model ...")
if args.dynamic:
t_dim = Dim("T", min=1, max=32)
h_dim = Dim("H", min=4, max=64)
w_dim = Dim("W", min=4, max=64)
dynamic_shapes = (
# x: (B, T, H, W, C)
{1: t_dim, 2: h_dim, 3: w_dim},
# freqs_cis_real: (T, H, W, D)
{0: t_dim, 1: h_dim, 2: w_dim},
)
ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
print(" Exported with dynamic T / H / W dimensions.")
else:
ep = torch.export.export(model, inputs)
print(" Exported with static shapes.")
# ------------------------------------------------------------------
# 4. Compile with Torch-TensorRT
#
# No special flags are required for the complex arithmetic rewrite.
# The complex_graph_detection lowering pass automatically detects
# view_as_complex / complex-mul / view_as_real subgraphs and rewrites
# them to real-arithmetic ops before the TRT engine is built.
# ------------------------------------------------------------------
print("\nCompiling with Torch-TensorRT ...")
D = C // args.heads # freqs_cis_real last dim
if args.dynamic:
trt_inputs = [
torch_tensorrt.Input(
min_shape=(B, 1, 4, 4, C),
opt_shape=(B, T, H, W, C),
max_shape=(B, 32, 64, 64, C),
dtype=torch.float32,
),
torch_tensorrt.Input(
min_shape=(1, 4, 4, D),
opt_shape=(T, H, W, D),
max_shape=(32, 64, 64, D),
dtype=torch.float32,
),
]
else:
trt_inputs = list(inputs)
trt_model = torch_tensorrt.dynamo.compile(
ep,
inputs=trt_inputs,
min_block_size=1,
)
# ------------------------------------------------------------------
# 5. TRT inference & accuracy check
# ------------------------------------------------------------------
with torch.inference_mode():
trt_out = trt_model(*inputs)
pyt_flat = pyt_out.float().flatten()
trt_flat = trt_out.float().flatten()
cos_sim = (pyt_flat @ trt_flat / (pyt_flat.norm() * trt_flat.norm())).item()
max_diff = (pyt_out.float() - trt_out.float()).abs().max().item()
print(f"\n--- TensorRT vs PyTorch ---")
print(f" output shape : {trt_out.shape}")
print(f" cosine sim : {cos_sim:.6f}")
print(f" max |Δ| : {max_diff:.2e}")
assert cos_sim > 0.99, f"Cosine similarity {cos_sim:.4f} below threshold 0.99!"
print(" PASSED")
# ------------------------------------------------------------------
# 6. (Optional) benchmark
# ------------------------------------------------------------------
if args.benchmark:
print("\n--- Benchmarking ---")
with torch.inference_mode():
benchmark(model, *inputs, iterations=args.iterations, label="PyTorch")
benchmark(trt_model, *inputs, iterations=args.iterations, label="TensorRT")
if __name__ == "__main__":
main()