Training with PyTorch
Created On: Nov 30, 2021 | Last Updated: May 31, 2023 | Last Verified: Nov 05, 2024
Follow along with the video on YouTube.
Introduction
In past videos, we’ve discussed and demonstrated:
Building models with the neural network layers and functions of the torch.nn module
The mechanics of automated gradient computation, which is central to gradient-based model training
Using TensorBoard to visualize training progress and other activities
In this video, we’ll be adding some new tools to your inventory:
We’ll get familiar with the dataset and dataloader abstractions, and how they ease the process of feeding data to your model during a training loop
We’ll discuss specific loss functions and when to use them
We’ll look at PyTorch optimizers, which implement algorithms to adjust model weights based on the outcome of a loss function
Finally, we’ll pull all of these together and see a full PyTorch training loop in action.
Dataset and DataLoader
The Dataset and DataLoader classes encapsulate the process of pulling your data from storage and exposing it to your training loop in batches.
The Dataset is responsible for accessing and processing single instances of data.
The DataLoader pulls instances of data from the Dataset (either automatically or with a sampler that you define), collects them in batches, and returns them for consumption by your training loop. The DataLoader works with all kinds of datasets, regardless of the type of data they contain.
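To make this division of labor concrete, here is a simplified pure-Python sketch of what a DataLoader does with a map-style dataset (one that implements __len__ and __getitem__). The names SquaresDataset and simple_loader are hypothetical, and this is not PyTorch's implementation: the real DataLoader also supports custom samplers, worker processes, and a collate function that stacks samples into tensors, whereas this sketch only groups them into Python lists.

```python
import random


class SquaresDataset:
    """Hypothetical map-style dataset: instance i is the (input, label) pair (i, i * i)."""

    def __len__(self):
        return 10

    def __getitem__(self, i):
        return i, i * i


def simple_loader(dataset, batch_size, shuffle=False, seed=None):
    """Hypothetical, simplified stand-in for DataLoader that yields (inputs, labels) batches."""
    indices = list(range(len(dataset)))
    if shuffle:
        # Plays the role of DataLoader's random sampler
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        # Fetch single instances from the dataset...
        samples = [dataset[i] for i in indices[start:start + batch_size]]
        # ...then "collate" them: split the pairs into a list of inputs and a list of labels
        inputs, labels = zip(*samples)
        yield list(inputs), list(labels)


batches = list(simple_loader(SquaresDataset(), batch_size=4, shuffle=True, seed=0))
for inputs, labels in batches:
    print(inputs, labels)
```

As with DataLoader's default drop_last=False, the final partial batch (2 instances here) is kept rather than discarded.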
For this tutorial, we’ll be using the Fashion-MNIST dataset provided by TorchVision. transforms.ToTensor() converts each image to a tensor with pixel values in [0, 1], and torchvision.transforms.Normalize((0.5,), (0.5,)) then computes (x - 0.5) / 0.5, which zero-centers the values and rescales them to [-1, 1]. We also download both the training and validation data splits.
import torch
import torchvision
import torchvision.transforms as transforms
# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])
# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)
# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)
# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')
# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
Training set has 60000 instances
Validation set has 10000 instances
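The two transforms are simple per-pixel arithmetic. The sketch below reproduces it in plain Python for a single 8-bit pixel value, assuming one grayscale channel as in Fashion-MNIST (the function names are hypothetical). It also shows that the img / 2 + 0.5 step in the matplotlib_imshow helper used for visualization undoes Normalize((0.5,), (0.5,)).

```python
def to_tensor_scale(pixel):
    # ToTensor() maps 8-bit pixel values in [0, 255] to floats in [0.0, 1.0]
    return pixel / 255.0


def normalize(x, mean=0.5, std=0.5):
    # Normalize((0.5,), (0.5,)) computes (x - mean) / std for the single channel
    return (x - mean) / std


def unnormalize(x):
    # The display helper's inverse: img / 2 + 0.5 is the same as x * std + mean
    return x / 2 + 0.5


for pixel in (0, 255):
    scaled = to_tensor_scale(pixel)
    print(pixel, scaled, normalize(scaled))
```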
As always, let’s visualize the data as a sanity check:
import matplotlib.pyplot as plt
import numpy as np
# Helper function for inline image display
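def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        # make_grid() copies 1-channel images into 3 channels; average them back into one
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        # Matplotlib expects (height, width, channels); tensors are (channels, height, width)
        plt.imshow(np.transpose(npimg, (1, 2, 0)))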
dataiter = iter(training_loader)
images, labels = next(dataiter)
# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print(' '.join(classes[labels[j]] for j in range(4)))

T-shirt/top Dress Bag Trouser
The Model
The model we’ll use in this example is a variant of LeNet-5; it should be familiar if you’ve watched the previous videos in this series.
import torch.nn as nn
import torch.nn.functional as F
# PyTorch models inherit from torch.nn.Module
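class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)        # 1 grayscale input channel, 6 output channels, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)         # 2x2 max pooling halves height and width
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)  # 16 feature maps of 4x4 after the second pooling
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)           # one output score per class

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)             # flatten to (batch, 256)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)                        # raw scores (logits); the loss applies softmax
        return x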
model = GarmentClassifier()
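The 16 * 4 * 4 input size of fc1 follows from the image size. This sketch traces a 28x28 Fashion-MNIST image through the layers using the standard output-size formula for convolution and pooling, floor((size + 2 * padding - kernel) / stride) + 1, assuming PyTorch's defaults of no padding and dilation 1 (conv_out is a hypothetical helper).

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Output length along one spatial dimension for Conv2d / MaxPool2d
    # (dilation 1 and floor rounding, PyTorch's defaults)
    return (size + 2 * padding - kernel) // stride + 1


size = 28                                   # Fashion-MNIST images are 28x28
size = conv_out(size, kernel=5)             # conv1: 28 -> 24
size = conv_out(size, kernel=2, stride=2)   # pool:  24 -> 12
size = conv_out(size, kernel=5)             # conv2: 12 -> 8
size = conv_out(size, kernel=2, stride=2)   # pool:   8 -> 4
flat_features = 16 * size * size            # conv2 produces 16 channels
print(size, flat_features)
```

If you change the input resolution or kernel sizes, rerunning this arithmetic gives the new in_features for fc1 and the matching argument to x.view().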
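Loss Function
For this example, we’ll be using a cross-entropy loss (torch.nn.CrossEntropyLoss), which expects raw, unnormalized class scores (logits) and integer class labels. For demonstration purposes, we’ll create batches of dummy output and label values, run them through the loss function, and examine the result.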
loss_fn = torch.nn.CrossEntropyLoss()
# NB: Loss functions expect data in batches, so we're creating batches of 4
# Stand-in for the model's raw, unnormalized scores (logits) for each of the 10 classes, per input
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])
print(dummy_outputs)
print(dummy_labels)
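# CrossEntropyLoss averages over the batch by default (reduction='mean'),
# so the "total" printed below is the mean loss per instance
loss = loss_fn(dummy_outputs, dummy_labels)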
print('Total loss for this batch: {}'.format(loss.item()))
tensor([[0.1764, 0.1619, 0.8875, 0.1805, 0.7036, 0.3358, 0.6931, 0.0364, 0.8430,
0.0767],
[0.0182, 0.0407, 0.6103, 0.2460, 0.9907, 0.0978, 0.3805, 0.2288, 0.7804,
0.3488],
[0.5060, 0.9863, 0.2487, 0.8829, 0.4201, 0.4798, 0.3157, 0.4678, 0.8579,
0.3824],
[0.0289, 0.1472, 0.3540, 0.7207, 0.7277, 0.7830, 0.7991, 0.6037, 0.3332,
0.2823]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.362215042114258
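To see what CrossEntropyLoss computes, here is a plain-Python sketch: for each instance it takes the negative log of the softmax probability assigned to the correct class, then averages over the batch (the default reduction='mean'). The function name is hypothetical; the inputs are the four rows of dummy_outputs printed above, which are rounded to four decimal places, so the result agrees with the printed loss only approximately.

```python
import math


def cross_entropy(logits_batch, labels):
    """Mean cross-entropy over a batch of raw scores, like CrossEntropyLoss(reduction='mean')."""
    total = 0.0
    for logits, label in zip(logits_batch, labels):
        # log-sum-exp, shifted by the max score for numerical stability
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        # -log(softmax(logits)[label]) equals log_sum_exp - logits[label]
        total += log_sum_exp - logits[label]
    return total / len(labels)


# The dummy outputs printed above (rounded to 4 decimal places)
dummy_outputs = [
    [0.1764, 0.1619, 0.8875, 0.1805, 0.7036, 0.3358, 0.6931, 0.0364, 0.8430, 0.0767],
    [0.0182, 0.0407, 0.6103, 0.2460, 0.9907, 0.0978, 0.3805, 0.2288, 0.7804, 0.3488],
    [0.5060, 0.9863, 0.2487, 0.8829, 0.4201, 0.4798, 0.3157, 0.4678, 0.8579, 0.3824],
    [0.0289, 0.1472, 0.3540, 0.7207, 0.7277, 0.7830, 0.7991, 0.6037, 0.3332, 0.2823],
]
dummy_labels = [1, 5, 3, 7]
manual_loss = cross_entropy(dummy_outputs, dummy_labels)
print(manual_loss)
```

With identical scores for all 10 classes the loss is ln(10), about 2.303, so a value near 2.36 for random scores is roughly what an untrained model should produce.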
Optimizer
For this example, we’ll be using simple stochastic gradient descent with momentum.
It can be instructive to try some variations on this optimization scheme:
Learning rate determines the size of the steps the optimizer takes. What does a different learning rate do to your training results, in terms of accuracy and convergence time?
Momentum accumulates a decaying sum of past gradients, so the optimizer keeps moving in directions that stay consistent over multiple steps. What does changing this value do to your results?
Try some different optimization algorithms, such as averaged SGD, Adagrad, or Adam. How do your results differ?
# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
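As a rough sketch of what momentum does, the plain-Python function below applies the SGD update PyTorch documents for the default dampening=0 and nesterov=False: a velocity buffer accumulates a decaying sum of gradients, v = momentum * v + g, and the parameter moves by -lr * v. The function name and the toy objective f(x) = x**2 are illustrative assumptions, not part of torch.optim.

```python
def sgd_momentum(grad_fn, param, lr, momentum, steps):
    """Hypothetical scalar sketch of SGD with momentum (dampening=0, no Nesterov)."""
    velocity = 0.0
    history = []
    for _ in range(steps):
        grad = grad_fn(param)
        # Running sum of past gradients, each one decayed by `momentum` per step
        velocity = momentum * velocity + grad
        param = param - lr * velocity
        history.append(param)
    return history


# Minimize f(x) = x**2 (gradient 2x) from x = 1.0, with and without the tutorial's momentum
with_momentum = sgd_momentum(lambda x: 2 * x, param=1.0, lr=0.1, momentum=0.9, steps=50)
plain_sgd = sgd_momentum(lambda x: 2 * x, param=1.0, lr=0.1, momentum=0.0, steps=50)
print(with_momentum[:4])
print(plain_sgd[:4])
```

The momentum run moves faster at first (0.46 versus 0.64 after two steps) but, because the velocity keeps carrying it, it overshoots the minimum at 0 on the fourth step and then oscillates. To try another algorithm from the list above, swap the optimizer line for, for example, torch.optim.Adam(model.parameters(), lr=0.001).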
The Training Loop
Below, we have a function that performs one training epoch. It enumerates data from the DataLoader, and on each pass of the loop does the following:
Gets a batch of training data from the DataLoader
Zeros the optimizer’s gradients
Performs a forward pass: gets predictions from the model for the input batch
Calculates the loss for that set of predictions vs. the labels on the dataset
Calculates the backward gradients over the learning weights
Tells the optimizer to perform one learning step: adjust the model’s learning weights based on the observed gradients for this batch, according to the optimization algorithm we chose
Every 1000 batches, it prints the average per-batch loss over those batches and logs it to TensorBoard.
Finally, it returns the average per-batch loss for the last 1000 batches, for comparison with a validation run.
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print(' batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss
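The six steps in train_one_epoch() are not specific to PyTorch. The following plain-Python sketch runs the same loop for a hypothetical one-feature linear model, y = w * x + b, on made-up noise-free data generated from y = 2x + 1, with a mean-squared-error loss. There is no autograd here, so the backward step uses gradients derived by hand; they are accumulated into grad_w and grad_b much as backward() accumulates into each parameter's .grad, which is why they are reset at the start of every batch.

```python
import random

# Hypothetical, noise-free training data generated from y = 2x + 1
data = [(i / 10, 2 * (i / 10) + 1) for i in range(-10, 11)]

w, b = 0.0, 0.0              # the model's learning weights
lr, batch_size = 0.1, 4
rng = random.Random(0)       # seeded so the shuffle order is reproducible

for epoch in range(200):
    rng.shuffle(data)                                  # like shuffle=True in DataLoader
    for start in range(0, len(data), batch_size):
        batch = data[start:start + batch_size]         # 1. get a batch of training data
        grad_w, grad_b = 0.0, 0.0                      # 2. zero the gradients
        for x, y in batch:
            pred = w * x + b                           # 3. forward pass
            err = pred - y                             # 4. loss for this instance is err ** 2
            # 5. backward: accumulate d(mean squared error)/dw and /db over the batch
            grad_w += 2 * err * x / len(batch)
            grad_b += 2 * err / len(batch)
        w -= lr * grad_w                               # 6. optimizer step (plain SGD)
        b -= lr * grad_b

print(w, b)
```

Moving step 2 outside the batch loop would let gradients from earlier batches pile up in every later update, which is the mistake optimizer.zero_grad() guards against in the real loop.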
Per-Epoch Activity
There are a couple of things we’ll want to do once per epoch:
Perform validation by computing the loss on a set of data that was not used for training, and report it alongside the training loss
Save a copy of the model whenever the validation loss improves
Here, we’ll do our reporting in TensorBoard. This requires starting TensorBoard from the command line (for example, tensorboard --logdir=runs) and opening it in another browser tab.
# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0
EPOCHS = 5
best_vloss = 1_000_000.
for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Put the model in training mode (this affects layers such as dropout
    # and batch normalization), then do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)

    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                       { 'Training' : avg_loss, 'Validation' : avg_vloss },
                       epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1
EPOCH 1:
batch 1000 loss: 1.9988314672708511
batch 2000 loss: 0.9193895362205804
batch 3000 loss: 0.7223440953306853
batch 4000 loss: 0.6701288196891546
batch 5000 loss: 0.581597259292379
batch 6000 loss: 0.5741387488790788
batch 7000 loss: 0.5606894336201949
batch 8000 loss: 0.5268562002686085
batch 9000 loss: 0.4830455190619687
batch 10000 loss: 0.4596923828284489
batch 11000 loss: 0.49248371993168255
batch 12000 loss: 0.44985741829127074
batch 13000 loss: 0.46217175873881206
batch 14000 loss: 0.43543626031989696
batch 15000 loss: 0.43752258202363736
LOSS train 0.43752258202363736 valid 0.430812269449234
EPOCH 2:
batch 1000 loss: 0.4037419849329745
batch 2000 loss: 0.39954532159899825
batch 3000 loss: 0.40559865027916386
batch 4000 loss: 0.3833038282005873
batch 5000 loss: 0.3906617384759011
batch 6000 loss: 0.37513707852845984
batch 7000 loss: 0.37699090542938213
batch 8000 loss: 0.3761951707657863
batch 9000 loss: 0.36876384827977743
batch 10000 loss: 0.3674912270232453
batch 11000 loss: 0.3550953941930784
batch 12000 loss: 0.3561198842830054
batch 13000 loss: 0.35298089532964516
batch 14000 loss: 0.36295080438960575
batch 15000 loss: 0.35744707003689835
LOSS train 0.35744707003689835 valid 0.36409085988998413
EPOCH 3:
batch 1000 loss: 0.33181712135783165
batch 2000 loss: 0.3163912760570238
batch 3000 loss: 0.34046520160492216
batch 4000 loss: 0.33481437444103357
batch 5000 loss: 0.3392046003730993
batch 6000 loss: 0.3134670601064281
batch 7000 loss: 0.3283420227258321
batch 8000 loss: 0.31831632618168076
batch 9000 loss: 0.3080377133266884
batch 10000 loss: 0.3223431621780182
batch 11000 loss: 0.316907833893878
batch 12000 loss: 0.3190856933850737
batch 13000 loss: 0.3103358497906811
batch 14000 loss: 0.3308342638386093
batch 15000 loss: 0.3177569979402469
LOSS train 0.3177569979402469 valid 0.35152819752693176
EPOCH 4:
batch 1000 loss: 0.27613517147956007
batch 2000 loss: 0.3150471937338443
batch 3000 loss: 0.28435196808750696
batch 4000 loss: 0.2808161023599605
batch 5000 loss: 0.30603527145737824
batch 6000 loss: 0.30359787086395956
batch 7000 loss: 0.3039901812107855
batch 8000 loss: 0.3200699506045785
batch 9000 loss: 0.3013748074879259
batch 10000 loss: 0.29429625780676544
batch 11000 loss: 0.2937111030754313
batch 12000 loss: 0.28333777568516233
batch 13000 loss: 0.2914983189167979
batch 14000 loss: 0.28703211334972
batch 15000 loss: 0.30034291974519145
LOSS train 0.30034291974519145 valid 0.30364754796028137
EPOCH 5:
batch 1000 loss: 0.27076393893184647
batch 2000 loss: 0.27389251034590417
batch 3000 loss: 0.2679886224323709
batch 4000 loss: 0.27351233246324047
batch 5000 loss: 0.2770088437672639
batch 6000 loss: 0.2842327052370747
batch 7000 loss: 0.26012170134465123
batch 8000 loss: 0.26852404565333793
batch 9000 loss: 0.27785209107911213
batch 10000 loss: 0.2856649016340425
batch 11000 loss: 0.27832065098140446
batch 12000 loss: 0.29101420920167587
batch 13000 loss: 0.27228796556679846
batch 14000 loss: 0.29488918731830743
batch 15000 loss: 0.26893940189028853
LOSS train 0.26893940189028853 valid 0.3395138680934906
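The checkpointing rule in the loop saves a new file only when the validation loss beats the best seen so far. Replaying that rule over the five validation losses logged above (a sketch with a hypothetical helper, checkpoint_epochs) shows that checkpoints were written for epoch_number 0 through 3 but not 4: in the last epoch the training loss kept falling while the validation loss rose, so the file from epoch_number 3 remains the best model.

```python
def checkpoint_epochs(val_losses):
    """Hypothetical helper: epoch indices at which `avg_vloss < best_vloss` triggers a save."""
    best_vloss = 1_000_000.
    saved = []
    for epoch_number, avg_vloss in enumerate(val_losses):
        if avg_vloss < best_vloss:
            best_vloss = avg_vloss
            saved.append(epoch_number)
    return saved


# Validation losses from the five "LOSS train ... valid ..." lines above
logged_vlosses = [0.430812269449234, 0.36409085988998413, 0.35152819752693176,
                  0.30364754796028137, 0.3395138680934906]
saved = checkpoint_epochs(logged_vlosses)
print(saved)
```

A rising validation loss alongside a falling training loss can be an early sign of overfitting, though a single epoch is not conclusive.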
To load a saved version of the model:
saved_model = GarmentClassifier()
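# model_path holds the file name of the last checkpoint saved in the training loop;
# substitute the path of whichever checkpoint you want to restore
saved_model.load_state_dict(torch.load(model_path))
Once you’ve loaded the model, it’s ready for whatever you need it for: more training, inference, or analysis. Before running inference, put it in evaluation mode with saved_model.eval().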
Note that if your model has constructor parameters that affect model structure, you’ll need to provide them and configure the model identically to the state in which it was saved.
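That constructor caveat can be demonstrated with a small round trip. This sketch uses a hypothetical TinyNet whose hidden-layer width is a constructor argument, and an in-memory io.BytesIO buffer in place of a file path (torch.save() and torch.load() accept file-like objects). Loading the state dict into an identically configured model reproduces the original outputs; loading it into a model built with a different width fails with a size-mismatch RuntimeError.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model whose structure depends on a constructor argument."""

    def __init__(self, hidden=8):
        super().__init__()
        self.fc1 = nn.Linear(4, hidden)
        self.fc2 = nn.Linear(hidden, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


original = TinyNet(hidden=8)
buffer = io.BytesIO()                      # in-memory stand-in for a checkpoint file
torch.save(original.state_dict(), buffer)

# Same constructor arguments: every saved tensor fits its parameter
buffer.seek(0)
restored = TinyNet(hidden=8)
restored.load_state_dict(torch.load(buffer))
restored.eval()

x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.equal(original(x), restored(x))
print('outputs match:', same_output)

# Different constructor arguments: parameter shapes no longer match
buffer.seek(0)
try:
    TinyNet(hidden=16).load_state_dict(torch.load(buffer))
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print('load failed:', type(err).__name__)
```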
Other Resources
Docs on the data utilities, including Dataset and DataLoader, at pytorch.org
A note on the use of pinned memory for GPU training
Documentation on the datasets available in TorchVision, TorchText, and TorchAudio
Documentation on the loss functions available in PyTorch
Documentation on the torch.optim package, which includes optimizers and related tools, such as learning rate scheduling
A detailed tutorial on saving and loading models
The Tutorials section of pytorch.org contains tutorials on a broad variety of training tasks, including classification in different domains, generative adversarial networks, reinforcement learning, and more