functorch.compile.aot_module
functorch.compile.aot_module(mod, *args, **kwargs)
Traces the forward and backward graphs of mod using the torch dispatch tracing mechanism. It is a wrapper function that uses aot_function() underneath to perform tracing and compilation. aot_module() lifts the parameters and buffers of the nn.Module as inputs to a new callable, which is then compiled through aot_function().
Warning
This API is experimental and likely to change.
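"Lifting the parameters and buffers as inputs" can be made concrete with torch.func.functional_call, which runs a module using a state dictionary passed in as an ordinary argument. The sketch below only illustrates that idea and is not aot_module()'s internal code. It assumes PyTorch 2.0 or later (for torch.func), and Scale and lifted_forward are hypothetical names chosen for the example.

```python
import torch
import torch.nn as nn
from torch.func import functional_call


class Scale(nn.Module):
    """Hypothetical module with parameters (in self.linear) and one buffer."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)
        self.register_buffer("scale", torch.tensor(2.0))

    def forward(self, x):
        return self.linear(x) * self.scale


mod = Scale()

# Gather parameters and buffers into a single name -> tensor mapping.
state = {**dict(mod.named_parameters()), **dict(mod.named_buffers())}


def lifted_forward(state, x):
    # The module's state is now an explicit input instead of hidden attributes.
    return functional_call(mod, state, (x,))


x = torch.randn(3, 4)
eager_out = mod(x)
lifted_out = lifted_forward(state, x)

# Swap in a different buffer value without modifying `mod` itself.
new_state = dict(state, scale=torch.tensor(3.0))
rescaled_out = lifted_forward(new_state, x)
```

Because the state is an ordinary input, the lifted function depends only on its arguments. Swapping in a different scale changes the result while mod stays untouched. This is the property that lets a module be traced as a pure function of its inputs.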
- Parameters
mod (Callable) – An nn.Module module.
args – Positional arguments passed through to aot_function(), such as fw_compiler.
kwargs – Keyword arguments passed through to aot_function().
- Returns
Returns an nn.Module that retains the eager behavior of the original mod, but with its forward and backward graphs compiled.
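Here is a minimal usage sketch, assuming a PyTorch build where functorch.compile is importable. record_compiler is a hypothetical "compiler" that records and prints each traced FX graph and returns it unchanged, so nothing is actually optimized. fw_compiler and bw_compiler are keyword arguments that aot_module() forwards to aot_function().

```python
import torch
import torch.nn as nn
from functorch.compile import aot_module

compiled_graphs = []


def record_compiler(fx_module: torch.fx.GraphModule, example_inputs):
    # Hypothetical compiler: invoked for the forward graph and for the
    # backward graph; it records and prints each one, then returns the
    # traced graph unchanged so it runs as-is.
    compiled_graphs.append(fx_module)
    print(fx_module.code)
    return fx_module


mod = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
compiled_mod = aot_module(mod, fw_compiler=record_compiler, bw_compiler=record_compiler)

x = torch.randn(3, 4)
out = compiled_mod(x)   # traces and "compiles" the forward graph
out.sum().backward()    # runs the compiled backward graph
```

In practice, record_compiler would be replaced by a function that lowers the FX graph to an optimized callable. The wrapper is still called and backpropagated through like the original mod, and gradients land on mod's own parameters.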