functorch.compile.memory_efficient_fusion
functorch.compile.memory_efficient_fusion(fn, static_argnums=None, **kwargs)
Wrapper function over aot_function() and aot_module() to perform memory-efficient fusion. It uses the min_cut_rematerialization_partition() partitioner to decide which intermediate values are saved for the backward pass and which are recomputed there, and it uses NVFuser to compile the generated forward and backward graphs.
Warning
This API is experimental and likely to change.
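The save-versus-recompute decision made by a min-cut partitioner can be illustrated with a toy, pure-Python sketch. It is not the functorch implementation: it assumes a hypothetical graph description (a `sizes` dict of element counts, an `edges` list of data dependencies, the graph inputs, and the values the backward pass reads), splits each value v into a `v_in -> v_out` edge whose capacity is its size, and runs Edmonds-Karp max-flow. The `_in -> _out` edges crossing the minimum cut are the values to save; values past the cut are recomputed in the backward pass.

```python
from collections import deque
import math

INF = math.inf


def build_graph(sizes, edges, primals, needed_by_backward):
    """Toy flow network: each value v is split into v_in -> v_out with capacity = its size."""
    cap = {}

    def add(u, v, c):
        cap.setdefault(u, {})
        cap.setdefault(v, {})
        cap[u][v] = cap[u].get(v, 0) + c
        cap[v].setdefault(u, 0)  # reverse entry used by the residual graph

    for v, size in sizes.items():
        add(f"{v}_in", f"{v}_out", size)  # cutting this edge means "save v"
    for u, v in edges:
        add(f"{u}_out", f"{v}_in", INF)  # data dependency u -> v, never cut
    for p in primals:
        add("source", f"{p}_in", INF)  # inputs are always available in the forward pass
    for v in needed_by_backward:
        add(f"{v}_out", "sink", INF)  # the backward pass must obtain v somehow
    return cap


def min_cut_saved_values(cap, sizes):
    """Edmonds-Karp max-flow; returns (values to save, total size saved). Mutates cap."""
    total = 0
    while True:
        # Breadth-first search for a shortest augmenting path in the residual graph.
        parent = {"source": None}
        queue = deque(["source"])
        while queue and "sink" not in parent:
            u = queue.popleft()
            for v, c in cap[u].items():
                if c > 0 and v not in parent:
                    parent[v] = u
                    queue.append(v)
        if "sink" not in parent:
            break  # no augmenting path left: the flow is maximal
        # Bottleneck capacity along the path.
        bottleneck, v = INF, "sink"
        while parent[v] is not None:
            bottleneck = min(bottleneck, cap[parent[v]][v])
            v = parent[v]
        if bottleneck == INF:
            raise ValueError("infinite flow: some path has no finite-capacity edge")
        # Push flow along the path and update residual capacities.
        v = "sink"
        while parent[v] is not None:
            u = parent[v]
            cap[u][v] -= bottleneck
            cap[v][u] += bottleneck
            v = u
        total += bottleneck
    # `parent` now holds every node reachable from the source; a value is saved
    # when its _in node is on the source side and its _out node is not.
    reachable = set(parent)
    saved = sorted(v for v in sizes if f"{v}_in" in reachable and f"{v}_out" not in reachable)
    return saved, total


# Toy forward graph (sizes are element counts chosen for illustration):
#   x (input, 100) -> s = reduce(x) (10) -> t = exp(s) (10) -> u = expand(t) (100)
# The backward pass is assumed to read t and u (110 elements if saved directly).
sizes = {"x": 100, "s": 10, "t": 10, "u": 100}
edges = [("x", "s"), ("s", "t"), ("t", "u")]
cap = build_graph(sizes, edges, primals=["x"], needed_by_backward=["t", "u"])
print(min_cut_saved_values(cap, sizes))  # (['s'], 10)
```

Saving the 10-element `s` and recomputing `t` and `u` in the backward pass keeps 10 elements alive instead of 110. When several cuts are equally cheap (`s` or `t` here), this sketch picks the one closest to the inputs.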
- Parameters
fn (Union[Callable, nn.Module]) – A Python function or an nn.Module that takes one or more arguments. Must return one or more Tensors.
static_argnums (Optional[Tuple[int]]) – An optional tuple of ints marking which arguments of the function are static.
**kwargs – Any other overrides you want to make to the settings passed to aot_function() or aot_module().
- Returns
Returns a Callable or nn.Module that retains the eager behavior of the original fn, but whose forward and backward graphs have gone through recomputation optimizations and have been compiled with NVFuser.
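How `fn`, `static_argnums`, and `**kwargs` might combine inside a wrapper like this can be sketched in plain Python. This is a hypothetical model, not functorch's source: `FakeModule`, `fake_aot_function`, `fake_aot_module`, and the placeholder default strings are stand-ins so the sketch runs without PyTorch. The setting names `fw_compiler`, `bw_compiler`, and `partition_fn` mirror parameters of `aot_function()`, and the sketch assumes that any keyword passed through `**kwargs` replaces the corresponding default.

```python
from typing import Any, Callable, Dict, Optional, Tuple


class FakeModule:
    """Hypothetical stand-in for nn.Module so the sketch runs without PyTorch."""

    def __call__(self, x):
        return x


def fake_aot_function(fn: Callable, **config: Any) -> Dict[str, Any]:
    # Hypothetical stand-in for aot_function(): reports the settings it received.
    return {"wrapper": "aot_function", "target": fn, **config}


def fake_aot_module(mod: FakeModule, **config: Any) -> Dict[str, Any]:
    # Hypothetical stand-in for aot_module(): reports the settings it received.
    return {"wrapper": "aot_module", "target": mod, **config}


def memory_efficient_fusion_sketch(
    fn, static_argnums: Optional[Tuple[int, ...]] = None, **kwargs: Any
):
    # Placeholder defaults; the documented wrapper uses NVFuser-based compilers
    # and min_cut_rematerialization_partition() here.
    config: Dict[str, Any] = {
        "fw_compiler": "default_fw_compiler",
        "bw_compiler": "default_bw_compiler",
        "partition_fn": "min_cut_rematerialization_partition",
        "static_argnums": static_argnums,
    }
    # Caller-supplied keywords override the defaults.
    config.update(kwargs)
    # Modules and plain callables are routed to different AOT wrappers.
    if isinstance(fn, FakeModule):
        return fake_aot_module(fn, **config)
    return fake_aot_function(fn, **config)


def f(a, b):
    return a * b


result = memory_efficient_fusion_sketch(f, static_argnums=(1,), bw_compiler="my_bw_compiler")
print(result["wrapper"], result["bw_compiler"], result["fw_compiler"])
# aot_function my_bw_compiler default_fw_compiler
```

Only the overridden setting changes; the other defaults are kept, and passing a module instead of a function selects the module wrapper.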