torch.compiler.nonstrict_trace
- torch.compiler.nonstrict_trace(traceable_fn)
Decorator to mark a function as nonstrict-traceable for dynamo.
A nonstrict-traced function appears as an opaque call in the dynamo graph. Dynamo does not trace into the function body (hence the “nonstrict”), but aot_autograd will trace into it.
This is similar to allow_in_graph but with enhanced support for:
- User-defined classes as inputs (must be registered with pytree)
- nn.Module as input arguments (parameters and buffers are tracked for autograd)
- Global/captured tensors treated as constants (assumed not updated during execution)
Note
With backend="eager", the original Python function runs directly. With backend="aot_eager", the graph traced by aot_autograd runs. With backend="inductor", the traced graph is compiled with inductor.
Training is supported: you can call .backward() on outputs, and gradients will flow through the nonstrict-traced function.
- Dangerous patterns (may cause silent incorrectness):
Side effects between the nonstrict_trace'd function and the compiled region: the function should not depend on variables mutated by other code inside the compiled function, and code after the call should not depend on mutations made by it.
Implicit inputs (closures/globals): Tensors captured from enclosing scopes are treated as constants. Gradients will NOT flow back to them. Pass tensors as explicit arguments if gradients are needed.
- Restrictions:
Both inputs and outputs must use pytree-compatible types. User-defined classes must be registered via torch.utils._pytree.register_pytree_node(), torch.utils._pytree.register_dataclass(), or torch.utils._pytree.register_constant(). Tensors, Python primitives (int, float, bool, str), symbolic types (SymInt, SymFloat, SymBool), and built-in containers (list, tuple, dict) are already handled by default.
Primitive values and container structure are specialized per call site: each call site expects the same primitives and structure on every execution.
Example:
>>> import torch
>>> @torch.compiler.nonstrict_trace
... def traced_forward(model, x):
...     # It's OK to have dynamo graph break within nonstrict_trace region
...     torch._dynamo.graph_break()
...     return model(x) + x
...
>>> class MyModule(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.inner = torch.nn.Linear(10, 10)
...
...     def forward(self, x):
...         return traced_forward(self.inner, x)
...
>>> # Compile and run
>>> model = MyModule()
>>> opt_model = torch.compile(model, backend="aot_eager", fullgraph=True)
>>> out = opt_model(torch.randn(10, 10))
>>> out.sum().backward()  # Gradients flow through traced_forward
- Return type:
Callable[_P, _R]