(beta) Running the compiled optimizer with an LR Scheduler#
Created On: May 21, 2024 | Last Updated: May 21, 2024 | Last Verified: Nov 05, 2024
Author: Michael Lazos
The optimizer is a key algorithm for training any deep learning model.
In this example, we will show how to pair an optimizer that has been compiled with torch.compile with an LR scheduler to accelerate training convergence.
Note
This tutorial requires PyTorch 2.3.0 or later.
Model Setup#
For this example, we’ll use a simple sequence of linear layers.
import torch

# Create a simple model: a stack of bias-free linear layers
model = torch.nn.Sequential(
    *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
)
input = torch.rand(1024, device="cuda")

# Run the forward pass
output = model(input)

# Run backward to populate the grads for our optimizer below
output.sum().backward()
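Since the optimizer examples below only need the gradients to exist, a quick sanity check (illustrative, not part of the original example) can confirm that every parameter now has a populated .grad:

# Illustrative check: every parameter should now have a gradient populated
assert all(p.grad is not None for p in model.parameters())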
Setting up and running the compiled optimizer with LR Scheduler#
In this section, we’ll use the Adam optimizer paired with a LinearLR scheduler, and create a helper function to wrap the step() call for each of them in torch.compile().
Note
torch.compile is only supported on CUDA devices that have a compute capability of 7.0 or higher.
# Exit cleanly if we are on a device that doesn't support ``torch.compile``
if torch.cuda.get_device_capability() < (7, 0):
    print("Exiting because torch.compile is not supported on this device.")
    import sys
    sys.exit(0)
# !!! IMPORTANT !!! Wrap the lr in a Tensor if we are pairing the
# optimizer with an LR Scheduler.
# Without this, torch.compile will recompile as the value of the LR
# changes.
opt = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)
@torch.compile(fullgraph=False)
def fn():
    opt.step()
    sched.step()


# Warmup runs to compile the function
for _ in range(5):
    fn()
    print(opt.param_groups[0]["lr"])
tensor(0.0047)
tensor(0.0060)
tensor(0.0073)
tensor(0.0087)
tensor(0.0100)
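In a real training loop, each iteration would also recompute the forward and backward passes before calling the compiled step. The loop below is a minimal sketch of how fn() could be used in that setting; the sum-based loss is a placeholder, not part of the original example.

# Minimal training-loop sketch (placeholder loss, illustrative only)
for _ in range(5):
    opt.zero_grad()
    out = model(input)
    loss = out.sum()     # stand-in for a real loss function
    loss.backward()
    fn()                 # compiled optimizer step + scheduler step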
Extension: What happens with a non-tensor LR?#
For the curious, we will show how to peek into what happens with torch.compile when we don’t wrap the LR in a tensor.
# No longer wrap the LR in a tensor here
opt = torch.optim.Adam(model.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)
@torch.compile(fullgraph=False)
def fn():
    opt.step()
    sched.step()
# Set up logging to view recompiles
torch._logging.set_logs(recompiles=True)
# Warmup runs to compile the function
# We will now recompile on each iteration
# as the value of the lr is mutated.
for _ in range(5):
    fn()
V0801 20:16:15.498000 30642 torch/_dynamo/guards.py:2997] [1/1] [__recompiles] Recompiling function wrapper in /usr/local/lib/python3.10/dist-packages/torch/optim/optimizer.py:465
V0801 20:16:15.498000 30642 torch/_dynamo/guards.py:2997] [1/1] [__recompiles] triggered by the following guard failure(s):
V0801 20:16:15.498000 30642 torch/_dynamo/guards.py:2997] [1/1] [__recompiles] - 1/0: Cache line invalidated because L['args'][0] got deallocated
V0801 20:16:15.513000 30642 torch/_dynamo/guards.py:2997] [2/1] [__recompiles] Recompiling function step in /usr/local/lib/python3.10/dist-packages/torch/optim/adam.py:212
V0801 20:16:15.513000 30642 torch/_dynamo/guards.py:2997] [2/1] [__recompiles] triggered by the following guard failure(s):
V0801 20:16:15.513000 30642 torch/_dynamo/guards.py:2997] [2/1] [__recompiles] - 2/0: Cache line invalidated because L['self'] got deallocated
V0801 20:16:18.590000 30642 torch/_dynamo/guards.py:2997] [2/2] [__recompiles] Recompiling function step in /usr/local/lib/python3.10/dist-packages/torch/optim/adam.py:212
V0801 20:16:18.590000 30642 torch/_dynamo/guards.py:2997] [2/2] [__recompiles] triggered by the following guard failure(s):
V0801 20:16:18.590000 30642 torch/_dynamo/guards.py:2997] [2/2] [__recompiles] - 2/1: ___as_tensor(self.param_groups[0]['lr']).item() == 0.003333333333333333 # (unknown source ___as_tensor(self.param_groups[0]['lr']).item(), please file a bug)
V0801 20:16:18.590000 30642 torch/_dynamo/guards.py:2997] [2/2] [__recompiles] - 2/0: Cache line invalidated because L['self'] got deallocated
V0801 20:16:20.635000 30642 torch/_dynamo/guards.py:2997] [2/3] [__recompiles] Recompiling function step in /usr/local/lib/python3.10/dist-packages/torch/optim/adam.py:212
V0801 20:16:20.635000 30642 torch/_dynamo/guards.py:2997] [2/3] [__recompiles] triggered by the following guard failure(s):
V0801 20:16:20.635000 30642 torch/_dynamo/guards.py:2997] [2/3] [__recompiles] - 2/2: ___as_tensor(self.param_groups[0]['lr']).item() == 0.004666666666666667 # (unknown source ___as_tensor(self.param_groups[0]['lr']).item(), please file a bug)
V0801 20:16:20.635000 30642 torch/_dynamo/guards.py:2997] [2/3] [__recompiles] - 2/1: ___as_tensor(self.param_groups[0]['lr']).item() == 0.003333333333333333 # (unknown source ___as_tensor(self.param_groups[0]['lr']).item(), please file a bug)
V0801 20:16:20.635000 30642 torch/_dynamo/guards.py:2997] [2/3] [__recompiles] - 2/0: Cache line invalidated because L['self'] got deallocated
V0801 20:16:22.677000 30642 torch/_dynamo/guards.py:2997] [2/4] [__recompiles] Recompiling function step in /usr/local/lib/python3.10/dist-packages/torch/optim/adam.py:212
V0801 20:16:22.677000 30642 torch/_dynamo/guards.py:2997] [2/4] [__recompiles] triggered by the following guard failure(s):
V0801 20:16:22.677000 30642 torch/_dynamo/guards.py:2997] [2/4] [__recompiles] - 2/3: ___as_tensor(self.param_groups[0]['lr']).item() == 0.006000000000000001 # (unknown source ___as_tensor(self.param_groups[0]['lr']).item(), please file a bug)
V0801 20:16:22.677000 30642 torch/_dynamo/guards.py:2997] [2/4] [__recompiles] - 2/2: ___as_tensor(self.param_groups[0]['lr']).item() == 0.004666666666666667 # (unknown source ___as_tensor(self.param_groups[0]['lr']).item(), please file a bug)
V0801 20:16:22.677000 30642 torch/_dynamo/guards.py:2997] [2/4] [__recompiles] - 2/1: ___as_tensor(self.param_groups[0]['lr']).item() == 0.003333333333333333 # (unknown source ___as_tensor(self.param_groups[0]['lr']).item(), please file a bug)
V0801 20:16:22.677000 30642 torch/_dynamo/guards.py:2997] [2/4] [__recompiles] - 2/0: Cache line invalidated because L['self'] got deallocated
V0801 20:16:24.956000 30642 torch/_dynamo/guards.py:2997] [2/5] [__recompiles] Recompiling function step in /usr/local/lib/python3.10/dist-packages/torch/optim/adam.py:212
V0801 20:16:24.956000 30642 torch/_dynamo/guards.py:2997] [2/5] [__recompiles] triggered by the following guard failure(s):
V0801 20:16:24.956000 30642 torch/_dynamo/guards.py:2997] [2/5] [__recompiles] - 2/4: ___as_tensor(self.param_groups[0]['lr']).item() == 0.007333333333333335 # (unknown source ___as_tensor(self.param_groups[0]['lr']).item(), please file a bug)
V0801 20:16:24.956000 30642 torch/_dynamo/guards.py:2997] [2/5] [__recompiles] - 2/3: ___as_tensor(self.param_groups[0]['lr']).item() == 0.006000000000000001 # (unknown source ___as_tensor(self.param_groups[0]['lr']).item(), please file a bug)
V0801 20:16:24.956000 30642 torch/_dynamo/guards.py:2997] [2/5] [__recompiles] - 2/2: ___as_tensor(self.param_groups[0]['lr']).item() == 0.004666666666666667 # (unknown source ___as_tensor(self.param_groups[0]['lr']).item(), please file a bug)
V0801 20:16:24.956000 30642 torch/_dynamo/guards.py:2997] [2/5] [__recompiles] - 2/1: ___as_tensor(self.param_groups[0]['lr']).item() == 0.003333333333333333 # (unknown source ___as_tensor(self.param_groups[0]['lr']).item(), please file a bug)
V0801 20:16:24.956000 30642 torch/_dynamo/guards.py:2997] [2/5] [__recompiles] - 2/0: Cache line invalidated because L['self'] got deallocated
With this example, we can see that we recompile the optimizer several times due to the guard failure on the lr in param_groups[0].
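As a quick check that the tensor LR is the fix, the sketch below clears the compilation cache with torch._dynamo.reset(), recreates the optimizer with a tensor lr, and reruns the same loop; after the initial compilation, no further LR-related [__recompiles] messages should appear. This cleanup and rerun is illustrative and not part of the original example.

# Illustrative: clear compiled code, switch back to a tensor LR, and rerun.
# The recompile logging enabled above stays on, so any new recompiles would
# still be printed; none should be triggered by the changing LR this time.
torch._dynamo.reset()

opt = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)

@torch.compile(fullgraph=False)
def fn():
    opt.step()
    sched.step()

for _ in range(5):
    fn()

# Reset logging to its defaults when done
torch._logging.set_logs()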
Conclusion#
In this tutorial, we showed how to pair an optimizer compiled with torch.compile with an LR scheduler to accelerate training convergence. We used a model consisting of a simple sequence of linear layers, with the Adam optimizer paired with a LinearLR scheduler, to demonstrate the LR changing across iterations.
See also:
Compiled optimizer tutorial - an intro into the compiled optimizer.
Compiling the optimizer with PT2 - deeper technical details on the compiled optimizer.
Total running time of the script: (0 minutes 16.061 seconds)