SubgraphBuilder — Cursor-Based FX Node Insertion#
Writing lowering passes that replace one node with several new nodes requires
careful management of insertion order: each new node must be inserted
after the previous one so that the topological ordering of the graph is
preserved. Doing this by hand with repeated graph.inserting_after(cursor)
context managers is verbose and error-prone.
SubgraphBuilder is a lightweight context-manager helper in
torch_tensorrt.dynamo.lowering._SubgraphBuilder that automates this
cursor-tracking pattern.
Basic Usage#
Construct a SubgraphBuilder with the target graph and the anchor node —
the node immediately before where you want to start inserting. Then use it
as a callable inside a with block to add nodes one at a time:
from torch_tensorrt.dynamo.lowering._SubgraphBuilder import SubgraphBuilder
import torch.ops.aten as aten
# Inside a lowering pass, given a node `mul_node` to replace:
with SubgraphBuilder(gm.graph, mul_node) as b:
# Each call inserts a node after the current cursor and advances it.
re_a = b(aten.select.int, a, -1, 0) # a[..., 0] (real part of a)
im_a = b(aten.select.int, a, -1, 1) # a[..., 1] (imag part of a)
re_b = b(aten.select.int, b_node, -1, 0)
im_b = b(aten.select.int, b_node, -1, 1)
real = b(aten.sub.Tensor, b(aten.mul.Tensor, re_a, re_b),
b(aten.mul.Tensor, im_a, im_b)) # ac - bd
imag = b(aten.add.Tensor, b(aten.mul.Tensor, re_a, im_b),
b(aten.mul.Tensor, im_a, re_b)) # ad + bc
result = b(aten.stack, [real, imag], -1)
mul_node.replace_all_uses_with(result)
gm.graph.erase_node(mul_node)
On __exit__, the builder automatically calls graph.lint() to validate
the modified graph. If your code raises an exception inside the block, the
lint is skipped so you see the original error rather than a secondary graph
validation failure.
How It Works#
The builder maintains a cursor — initially the anchor node passed to
__init__. Every time you call it:
A new
call_functionnode is inserted viagraph.inserting_after(cursor).The cursor advances to the newly inserted node.
The new node is appended to an internal
_insertedlist for debug logging.
This ensures that successive calls produce a correctly ordered chain:
anchor → node_0 → node_1 → node_2 → ...
without any manual bookkeeping.
Debug Logging#
When the torch_tensorrt logger is set to DEBUG, the builder emits a
compact summary of all inserted nodes after a successful block, for example:
rewrite %mul_17[(4, 32, 2),torch.float32] ->
%select_72[(4, 32),torch.float32] = select_int(%inp_0, -1, 0)
%select_73[(4, 32),torch.float32] = select_int(%inp_0, -1, 1)
%mul_18[(4, 32),torch.float32] = mul_Tensor(%select_72, %select_73)
...
This makes it easy to trace exactly which nodes were produced by a particular rewrite rule.
API Reference#
- class torch_tensorrt.dynamo.lowering._SubgraphBuilder.SubgraphBuilder(graph: Graph, cursor: Node)[source]#
Cursor-based helper for inserting a sequence of FX
call_functionnodes.Construct it with the graph and an anchor node, then call it like a function to append each new node immediately after the current cursor:
with SubgraphBuilder(graph, node) as b: re = b(aten.select.int, inp, -1, 0) im = b(aten.select.int, inp, -1, 1) out = b(aten.add.Tensor, re, im)
Each call inserts one
call_functionnode right after the cursor and advances the cursor to that node. Scalar / list arguments are forwarded as-is.On
__exit__the graph is linted to catch any malformed nodes inserted during the block. Exceptions from user code propagate normally; lint errors are only raised when the block itself succeeds.- property cursor: Node#
When to Use SubgraphBuilder#
Use SubgraphBuilder whenever a lowering pass needs to expand one node into
a sequence of several nodes in a single linear chain. Typical use cases:
Replacing a complex-arithmetic op with real-arithmetic equivalents (e.g. the
complex_mul_replacementin Complex Number Support).Decomposing a high-level op (e.g.
layer_norm) into its ATen primitives when a custom replacement strategy is needed beyond the standard decomposition table.Inserting diagnostic nodes (shape probes, debug prints) around a target op.
If you only need to insert a single node, a plain
graph.inserting_after(node) is simpler. If you need to insert into multiple
disconnected locations in the same pass, create a separate SubgraphBuilder
for each anchor.