# Where to apply torch.compile?

We recommend applying `torch.compile` to the highest-level function that doesn’t cause excessive problems. Typically, that is:

- your `train` or `eval` step, including the optimizer but excluding the loop,
- your top-level `nn.Module`,
- or some sub-`nn.Module`s.

In particular, `torch.compile` doesn’t handle distributed wrapper modules such as DDP or FSDP very well, so consider applying `torch.compile` to the inner module that is passed to the wrapper.

```python
# inference
model = ...
model.compile()  # compile the top-level module in place

for _ in range(N_ITERS):
    inp = ...
    out = model(inp)
```

```python
# training
import torch

model = ...
opt = torch.optim.Adam(model.parameters())

@torch.compile
def train(mod, data):
    # one full training step: forward, loss, backward, optimizer update
    opt.zero_grad(set_to_none=True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()

# the loop itself stays outside the compiled function
for _ in range(N_ITERS):
    inp = ...
    train(model, inp)
```

```python
# DistributedDataParallel
from torch.nn.parallel import DistributedDataParallel

model = ...
model.compile()  # compile the inner module, not the DDP wrapper
model_ddp = DistributedDataParallel(model, ...)

for _ in range(N_ITERS):
    inp = ...
    out = model_ddp(inp)
```

## `compile(model)` vs `model.compile()`

Because of nuances in how `torch.compile` interacts with `nn.Module` instances, we advise using the `.compile()` method of an `nn.Module` instance if you wish to compile it as a top-level function. Nested module calls will be traced correctly, so there is no need to call `.compile()` on the nested modules.

```python
# DO NOT DO THIS
model = MyModel()
model = torch.compile(model)
model(inp)

# DO THIS
model = MyModel()
model.compile()
model(inp)

# this is also acceptable
@torch.compile
def fn(model, inp):
    return model(inp)

model = MyModel()
fn(model, inp)
```