
Where to apply torch.compile?#

Created On: Jul 28, 2025 | Last Updated On: Jul 28, 2025

We recommend applying torch.compile to the highest-level function that doesn't cause excessive problems. Typically, this is one of the following:

  • your train or eval step with the optimizer but without the loop,

  • your top-level nn.Module,

  • or some sub-nn.Modules (a sketch of this case follows the examples below).

In particular, torch.compile doesn't handle distributed wrapper modules such as DDP or FSDP very well, so consider applying torch.compile to the inner module passed to the wrapper, as in the DDP and FSDP examples below.

# inference
model = ...
model.compile()

for _ in range(N_ITERS):
    inp = ...
    out = model(inp)

# training
model = ...
opt = torch.optim.Adam(model.parameters())

@torch.compile
def train(mod, data):
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()

for _ in range(N_ITERS):
    inp = ...  # an (input, target) pair, consumed as data[0] and data[1] in train()
    train(model, inp)

# DistributedDataParallel
from torch.nn.parallel import DistributedDataParallel

model = ...
model.compile()
model_ddp = DistributedDataParallel(model, ...)

for _ in range(N_ITERS):
    inp = ...
    out = model_ddp(inp)
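
The same advice applies to FSDP. Below is a minimal sketch that mirrors the DDP example, compiling the inner module before wrapping it; the FSDP constructor arguments are elided:

# FullyShardedDataParallel (sketch mirroring the DDP example above)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = ...
model.compile()
model_fsdp = FSDP(model, ...)

for _ in range(N_ITERS):
    inp = ...
    out = model_fsdp(inp)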

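If compiling the whole model causes problems, you can instead compile only selected sub-modules. A minimal sketch, assuming a hypothetical model whose expensive computation lives in a submodule named layers:

# compiling a sub-nn.Module (sketch; the .layers submodule is hypothetical)
model = ...
model.layers.compile()  # only the submodule's forward is compiled

for _ in range(N_ITERS):
    inp = ...
    out = model(inp)  # the uncompiled outer forward calls the compiled submodule
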
torch.compile(model) vs model.compile()#

Due to nuances in how torch.compile interacts with nn.Module instances, we advise using the .compile() method of an nn.Module instance if you wish to compile it as a top-level function. Nested module calls are traced correctly, so there is no need to call .compile() on inner modules.

# DO NOT DO THIS
model = MyModel()
model = torch.compile(model)  # returns an OptimizedModule wrapper, replacing the original model reference
model(inp)

# DO THIS
model = MyModel()
model.compile()  # compiles the module's forward in place; model keeps its original class
model(inp)

# this is also acceptable
@torch.compile
def fn(model, inp):
    return model(inp)

model = MyModel()
fn(model, inp)