```{code-cell}
:tags: [remove-cell]

import torch
import header_code

torch._logging.set_logs(recompiles=True)
```

# Dealing with Recompilations

Recompilations are necessary for `torch.compile` soundness, but they can result in significantly increased compile time. Thus, minimizing recompilations while preserving soundness is essential for reducing compile time.

You can view recompilations and their reasons using tlparse or `TORCH_LOGS=recompiles`.

## Is Dynamic Shapes Enabled?

In the example below, we recompile due to mismatched shapes:

```{code-cell}
@torch.compile
def fn(x):
    return x + 1

fn(torch.ones(3))
fn(torch.ones(4))
```

Make sure that the `dynamic` option of `torch.compile` is not set to `False`. The default option, `dynamic=None`, will only attempt dynamic shapes after the first compilation. You can set `dynamic=True` to compile as dynamically as possible up front:

```{code-cell}
@torch.compile(dynamic=True)
def gn(x):
    return x + 1

gn(torch.ones(3))
gn(torch.ones(4))
```

For more information on dynamic shapes, including dealing with errors/recompilations due to dynamic shapes, see [the dynamic shapes manual](https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit?tab=t.0#heading=h.fh8zzonyw8ng).

## Wrapping Constants with Tensors

By default, `int` / `float` variables are treated as constants and are guarded on their exact value. In the example below, we have a recompilation for each function call:

```{code-cell}
@torch.compile
def fn(x, c):
    return x + c

for i in range(5):
    fn(torch.ones(i), 0.5 + i)
```

In particular, for LR schedulers, initializing with a constant can lead to recompilations:

```{code-cell}
mod = torch.nn.Linear(3, 3)
opt = torch.optim.Adam(mod.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, 0.9)

@torch.compile
def gn(inp):
    opt.zero_grad(True)
    out = mod(inp).sum()
    out.backward()
    opt.step()
    sched.step()

for i in range(5):
    gn(torch.ones(3, 3))
```

In both examples, we can wrap `float` variables in tensors in order to prevent recompilations:

```{code-cell}
:tags: [remove-cell]

torch._dynamo.reset()
```

```{code-cell}
# first example
for i in range(5):
    fn(torch.ones(i), torch.tensor(0.5 + i))

# second example
opt = torch.optim.Adam(mod.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.ExponentialLR(opt, torch.tensor(0.9))
for i in range(5):
    gn(torch.ones(3, 3))
```

(programming_model.recompilation.changing_cache_size_limit)=
## Changing the Cache Size Limit

There is a limit to how many times a function can be recompiled, determined by `torch._dynamo.config.cache_size_limit` and `torch._dynamo.config.accumulated_cache_size_limit` (the exact difference between these two values is detailed in [`torch/_dynamo/cache_size.py`](https://github.com/pytorch/pytorch/blob/4ce6e6ec8890a3f6ee604c9efb3ff153825ce575/torch/_dynamo/cache_size.py#L14)). If the Dynamo cache limit is hit, then all future compilation attempts **will result in the function being skipped (run eagerly)**. Dynamo will still attempt to use previously compiled bytecode for future function calls, if the guards pass. Note that in the case of a recompilation limit hit, **all nested function calls WILL be skipped** (Dynamo will try to use previously compiled bytecode for the nested functions). Dynamo will also issue a warning containing the affected function and which limit was hit.
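Both limits can be read (and overridden) through `torch._dynamo.config`. As a minimal sketch, the cell below simply prints the current values; defaults may differ across PyTorch versions, and the exact semantics of each limit are described in the `cache_size.py` file linked above:

```{code-cell}
# Inspect the current recompile limits (values may vary by PyTorch version).
print(torch._dynamo.config.cache_size_limit)
print(torch._dynamo.config.accumulated_cache_size_limit)
```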
In the example below, each function call results in a recompile attempt. When we hit the cache size limit (by default, 8), we stop attempting to recompile. (Note that we set `dynamic=False` for demonstration purposes to force a recompilation every time.)

```{code-cell}
@torch.compile(dynamic=False)
def fn(x):
    return x + 1

for i in range(1, 10):
    # recompile every time due to dynamic=False
    fn(torch.ones(i))
```

If you know that the number of recompilations has a reasonable constant upper bound, you can raise the cache size limit. If the cost of recompilation outweighs the benefit of compilation, then you can consider lowering the cache size limit.

```{code-cell}
torch._dynamo.config.cache_size_limit = 16

@torch.compile(dynamic=False)
def gn(x):
    return x + 1

for i in range(1, 10):
    gn(torch.ones(i))
```

## Graph Breaking to Reduce Recompilation Costs

If a large graph is recompiling and causing high compile time, you can intentionally introduce a graph break in order to reduce recompilation costs, at the expense of introducing a performance hit.

```{code-cell}
def very_large_function(x):
    return x + 1

@torch.compile(dynamic=False)
def fn(x, c):
    y = very_large_function(x)  # recompiled every time
    return y + c

for i in range(1, 5):
    fn(torch.ones(3), i)

@torch.compile(dynamic=False)
def gn(x, c):
    y = very_large_function(x)  # compiled only once
    torch._dynamo.graph_break()
    return y + c  # recompiled every time

for i in range(1, 5):
    gn(torch.ones(3), i)
```
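A similar effect can be achieved without an explicit `torch._dynamo.graph_break()` call by splitting the work into two separately compiled functions, so that only the small, frequently changing part ever recompiles. The sketch below is illustrative (the `big_part` / `small_part` split is our own, not part of the original example) and reuses `very_large_function` from above:

```{code-cell}
# Alternative sketch: split the work into two compiled functions instead of
# inserting an explicit graph break. The large function is compiled once;
# only the small function recompiles when the constant `c` changes.
@torch.compile(dynamic=False)
def big_part(x):
    return very_large_function(x)  # compiled only once

@torch.compile(dynamic=False)
def small_part(y, c):
    return y + c  # recompiled for each new value of c

for i in range(1, 5):
    small_part(big_part(torch.ones(3)), i)
```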