Training with PyTorch
Created On: Nov 30, 2021 | Last Updated: May 31, 2023 | Last Verified: Nov 05, 2024
Follow along with the accompanying video on YouTube.
Introduction
In past videos, we’ve discussed and demonstrated:
Building models with the neural network layers and functions of the torch.nn module
The mechanics of automated gradient computation, which is central to gradient-based model training
Using TensorBoard to visualize training progress and other activities
In this video, we’ll be adding some new tools to your inventory:
We’ll get familiar with the dataset and dataloader abstractions, and how they ease the process of feeding data to your model during a training loop
We’ll discuss specific loss functions and when to use them
We’ll look at PyTorch optimizers, which implement algorithms to adjust model weights based on the outcome of a loss function
Finally, we’ll pull all of these together and see a full PyTorch training loop in action.
Dataset and DataLoader
The Dataset and DataLoader classes encapsulate the process of
pulling your data from storage and exposing it to your training loop in
batches.
The Dataset is responsible for accessing and processing single
instances of data.
The DataLoader pulls instances of data from the Dataset (either
automatically or with a sampler that you define), collects them in
batches, and returns them for consumption by your training loop. The
DataLoader works with all kinds of datasets, regardless of the type
of data they contain.
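The division of labor described above can be seen in a minimal sketch. It assumes a hypothetical toy dataset, `SquaresDataset`, in which instance i is the pair (i, i²). The Dataset only knows how to count and fetch single instances; the DataLoader picks the indices and collates the results into batches.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: instance i is the pair (i, i**2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # The DataLoader's default sampler uses this to know which indices exist
        return self.n

    def __getitem__(self, idx):
        # Return one (input, label) instance; no batching happens here
        x = torch.tensor([float(idx)])
        y = torch.tensor(float(idx) ** 2)
        return x, y


dataset = SquaresDataset(10)
# The DataLoader fetches instances by index and stacks them into batches
loader = DataLoader(dataset, batch_size=4, shuffle=False)

for inputs, labels in loader:
    print(inputs.shape, labels.shape)
# torch.Size([4, 1]) torch.Size([4])
# torch.Size([4, 1]) torch.Size([4])
# torch.Size([2, 1]) torch.Size([2])
```

Ten instances with `batch_size=4` yield two full batches and a final partial batch of two. Passing `drop_last=True` to the DataLoader would discard that last partial batch.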
For this tutorial, we’ll be using the Fashion-MNIST dataset provided by
TorchVision. We use torchvision.transforms.Normalize() to
zero-center and normalize the distribution of the image tile content,
and download both training and validation data splits.
import torch
import torchvision
import torchvision.transforms as transforms
# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)
# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)
# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')
# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
Training set has 60000 instances
Validation set has 10000 instances
As always, let’s visualize the data as a sanity check:
import matplotlib.pyplot as plt
import numpy as np
# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
dataiter = iter(training_loader)
images, labels = next(dataiter)
# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print(' '.join(classes[labels[j]] for j in range(4)))

T-shirt/top Shirt Pullover Coat
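The arithmetic behind `transforms.Normalize((0.5,), (0.5,))` and the `img / 2 + 0.5` "unnormalize" step in `matplotlib_imshow()` can be checked by hand. `Normalize` computes (input - mean) / std for each channel, and `ToTensor()` has already scaled pixel values into [0, 1]. The sketch below uses a few made-up pixel values to show that the data ends up in [-1, 1] and that the display step inverts it.

```python
import torch

# ToTensor() produces pixel values in [0.0, 1.0]; these stand in for a few pixels
pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])

mean, std = 0.5, 0.5
# What Normalize((0.5,), (0.5,)) computes for a single channel
normalized = (pixels - mean) / std   # values -1.0, -0.5, 0.0, 1.0
# The inverse applied before display in matplotlib_imshow()
restored = normalized / 2 + 0.5      # back to the original pixel values

print(normalized)
print(restored)
```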
The Model
The model we’ll use in this example is a variant of LeNet-5; it should be familiar if you’ve watched the previous videos in this series.
import torch.nn as nn
import torch.nn.functional as F
# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
model = GarmentClassifier()
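The `16 * 4 * 4` input size of `fc1` follows from the 28x28 Fashion-MNIST images: a 5x5 convolution without padding shrinks each side by 4, and each 2x2 max-pool halves it, so 28 → 24 → 12 → 8 → 4 with 16 output channels. The sketch below rebuilds the same convolution and pooling layers standalone and traces a dummy batch through them to confirm those shapes.

```python
import torch
import torch.nn as nn

# Same convolution and pooling layers as GarmentClassifier, built standalone
conv1 = nn.Conv2d(1, 6, 5)
conv2 = nn.Conv2d(6, 16, 5)
pool = nn.MaxPool2d(2, 2)

x = torch.zeros(4, 1, 28, 28)  # a batch of 4 single-channel 28x28 images
x = conv1(x)                   # 5x5 kernel, no padding: 28 - 5 + 1 = 24
print(x.shape)                 # torch.Size([4, 6, 24, 24])
x = pool(x)                    # 24 / 2 = 12
x = conv2(x)                   # 12 - 5 + 1 = 8
x = pool(x)                    # 8 / 2 = 4
print(x.shape)                 # torch.Size([4, 16, 4, 4])

flat = x.view(-1, 16 * 4 * 4)  # the same flattening used in forward()
print(flat.shape)              # torch.Size([4, 256])
```

If you change the input resolution or the kernel sizes, this is the number that has to change in `fc1`. `torch.flatten(x, 1)` is an alternative to `view(-1, ...)` that keeps the batch dimension explicit.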
Loss Function
For this example, we’ll be using a cross-entropy loss. For demonstration purposes, we’ll create batches of dummy output and label values, run them through the loss function, and examine the result.
loss_fn = torch.nn.CrossEntropyLoss()
# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's raw, unnormalized scores (logits) for each of the 10 classes, for each input
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])
print(dummy_outputs)
print(dummy_labels)
loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))
tensor([[0.7602, 0.4624, 0.3464, 0.8186, 0.5923, 0.5581, 0.9087, 0.8271, 0.9859,
0.3176],
[0.8673, 0.3265, 0.2883, 0.1268, 0.4297, 0.5847, 0.2480, 0.5701, 0.7214,
0.0967],
[0.7344, 0.0896, 0.5846, 0.6482, 0.2377, 0.0214, 0.2032, 0.6433, 0.5843,
0.1285],
[0.5110, 0.0831, 0.3339, 0.0799, 0.1741, 0.3747, 0.6260, 0.5203, 0.5520,
0.3843]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.2337210178375244
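`CrossEntropyLoss` treats each row of outputs as logits: it applies a log-softmax, takes the negative log-probability of the correct class, and, with its default `reduction='mean'`, averages over the batch. The number printed above is therefore the mean per-instance loss rather than a sum. The sketch below, which uses its own randomly generated batch, recomputes the loss by hand and compares it with the built-in `'mean'` and `'sum'` reductions.

```python
import torch

torch.manual_seed(0)
outputs = torch.rand(4, 10)          # raw scores (logits): 4 instances, 10 classes
labels = torch.tensor([1, 5, 3, 7])  # correct class index for each instance

# Manual cross-entropy: -log(softmax(outputs)[correct class]) for each row
log_probs = torch.log_softmax(outputs, dim=1)
per_instance = -log_probs[torch.arange(4), labels]
manual_mean = per_instance.mean()

builtin_mean = torch.nn.CrossEntropyLoss()(outputs, labels)  # reduction='mean'
builtin_sum = torch.nn.CrossEntropyLoss(reduction='sum')(outputs, labels)

print(manual_mean.item(), builtin_mean.item())
print(builtin_sum.item(), per_instance.sum().item())
```

Because `torch.rand` logits in [0, 1] give a nearly uniform softmax, the loss lands near ln(10) ≈ 2.30, the loss of an uninformed guess over 10 classes. This is consistent with the value of about 2.23 printed for the dummy batch above.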
Optimizer
For this example, we’ll be using simple stochastic gradient descent with momentum.
It can be instructive to try some variations on this optimization scheme:
Learning rate determines the size of the steps the optimizer takes. What does a different learning rate do to your training results, in terms of accuracy and convergence time?
Momentum accumulates an exponentially decaying sum of past gradients, so the optimizer keeps moving in directions that have been consistently downhill over multiple steps. What does changing this value do to your results?
Try some different optimization algorithms, such as averaged SGD, Adagrad, or Adam. How do your results differ?
# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
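To see what `momentum=0.9` does, the sketch below minimizes a hypothetical toy function, f(w) = sum(w²), twice: once with `torch.optim.SGD`, and once with a hand-written update using the formulation in PyTorch's SGD documentation with `dampening=0` and `nesterov=False`. That update is velocity = momentum · velocity + gradient, followed by w = w - lr · velocity.

```python
import torch

lr, momentum = 0.1, 0.9

# Minimize f(w) = sum(w**2) with torch.optim.SGD
w_torch = torch.tensor([1.0, -2.0], requires_grad=True)
opt = torch.optim.SGD([w_torch], lr=lr, momentum=momentum)

# The same problem with a hand-written momentum update
w_manual = torch.tensor([1.0, -2.0])
velocity = torch.zeros_like(w_manual)

for step in range(5):
    opt.zero_grad()
    loss = (w_torch ** 2).sum()
    loss.backward()
    opt.step()

    grad = 2 * w_manual                    # analytic gradient of sum(w**2)
    velocity = momentum * velocity + grad  # accumulate past gradients
    w_manual = w_manual - lr * velocity    # step along the accumulated direction

print(w_torch.detach(), w_manual)
```

After five steps both versions agree, and the first coordinate has overshot past the minimum at 0 to about -0.58. This happens because the accumulated velocity keeps pushing after the gradient has shrunk. With `momentum=0` the velocity is just the current gradient, and the update reduces to plain SGD.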
The Training Loop
Below, we have a function that performs one training epoch. It enumerates data from the DataLoader, and on each pass of the loop does the following:
Gets a batch of training data from the DataLoader
Zeros the optimizer’s gradients
Performs an inference; that is, gets predictions from the model for an input batch
Calculates the loss for that set of predictions vs. the labels on the dataset
Calculates the backward gradients over the learning weights
Tells the optimizer to perform one learning step; that is, adjust the model’s learning weights based on the observed gradients for this batch, according to the optimization algorithm we chose
Every 1000 batches, it reports the average per-batch loss over those batches.
Finally, it returns the average per-batch loss for the last 1000 batches, for comparison with a validation run.
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.
    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data
        # Zero your gradients for every batch!
        optimizer.zero_grad()
        # Make predictions for this batch
        outputs = model(inputs)
        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()
        # Adjust learning weights
        optimizer.step()
        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print(' batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
    return last_loss
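The "Zero your gradients for every batch!" comment matters because `backward()` adds into each parameter's `.grad` rather than overwriting it. The sketch below uses a hypothetical single scalar parameter `w` to show gradients from two backward passes being summed, and a cleared gradient starting fresh. Since PyTorch 2.0, `optimizer.zero_grad()` clears gradients by setting them to `None` by default, which is what the sketch does by hand.

```python
import torch

w = torch.tensor(3.0, requires_grad=True)

# Two backward passes without clearing: d(2w)/dw = 2 is added into w.grad twice
(w * 2).backward()
(w * 2).backward()
accumulated = w.grad.clone()
print(accumulated)  # tensor(4.)

# Clearing first, as optimizer.zero_grad() does for each parameter it manages
w.grad = None
(w * 2).backward()
print(w.grad)       # tensor(2.)
```

Without the reset, each optimizer step in the loop would use the sum of all gradients seen so far instead of the gradient for the current batch.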
Per-Epoch Activity
There are a couple of things we’ll want to do once per epoch:
Perform validation by checking our relative loss on a set of data that was not used for training, and report this
Save a copy of the model
Here, we’ll do our reporting in TensorBoard. This requires starting TensorBoard from the command line (for example, tensorboard --logdir=runs, since the writer below logs to subdirectories of runs/) and opening it in another browser tab.
# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0
EPOCHS = 5
best_vloss = 1_000_000.
for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))
    # Put the model in training mode (this affects layers such as dropout
    # and batch normalization), then do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)
    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()
    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss
    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()
    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)
    epoch_number += 1
EPOCH 1:
batch 1000 loss: 1.8452949015647173
batch 2000 loss: 0.882945014057681
batch 3000 loss: 0.720454607849475
batch 4000 loss: 0.6669794512391091
batch 5000 loss: 0.6499830534318461
batch 6000 loss: 0.563212547364179
batch 7000 loss: 0.5357104063224979
batch 8000 loss: 0.5023811038926942
batch 9000 loss: 0.4970288263714174
batch 10000 loss: 0.476621840887703
batch 11000 loss: 0.49413163296028506
batch 12000 loss: 0.4508753774598008
batch 13000 loss: 0.4357119759125635
batch 14000 loss: 0.42807682094321353
batch 15000 loss: 0.4479408651069243
LOSS train 0.4479408651069243 valid 0.43124017119407654
EPOCH 2:
batch 1000 loss: 0.4231764554461115
batch 2000 loss: 0.38261955516281887
batch 3000 loss: 0.4019369528376264
batch 4000 loss: 0.3749474552767351
batch 5000 loss: 0.3854320874402765
batch 6000 loss: 0.37020634075585984
batch 7000 loss: 0.3859600237645354
batch 8000 loss: 0.36200138485143546
batch 9000 loss: 0.3617918853437877
batch 10000 loss: 0.3568491576012457
batch 11000 loss: 0.36878882840009464
batch 12000 loss: 0.3715824427801999
batch 13000 loss: 0.35505162963457404
batch 14000 loss: 0.3743606256400526
batch 15000 loss: 0.3615355482221348
LOSS train 0.3615355482221348 valid 0.3869512677192688
EPOCH 3:
batch 1000 loss: 0.3253944525849365
batch 2000 loss: 0.328836536561168
batch 3000 loss: 0.3307937320651399
batch 4000 loss: 0.31718457779462916
batch 5000 loss: 0.3399723867363937
batch 6000 loss: 0.33665812234905024
batch 7000 loss: 0.33489793537881635
batch 8000 loss: 0.2996679348834514
batch 9000 loss: 0.33209410377382304
batch 10000 loss: 0.3273461590104853
batch 11000 loss: 0.32440513863955855
batch 12000 loss: 0.3365717599357886
batch 13000 loss: 0.32703474145127986
batch 14000 loss: 0.31254097997711505
batch 15000 loss: 0.3168364536026056
LOSS train 0.3168364536026056 valid 0.331007719039917
EPOCH 4:
batch 1000 loss: 0.3001225632514197
batch 2000 loss: 0.29959649950818856
batch 3000 loss: 0.305554188158334
batch 4000 loss: 0.2975857935305103
batch 5000 loss: 0.3035833477724227
batch 6000 loss: 0.29808518611676116
batch 7000 loss: 0.29426490620626283
batch 8000 loss: 0.30497063191074997
batch 9000 loss: 0.3241176800803514
batch 10000 loss: 0.3000431436141807
batch 11000 loss: 0.2967678876337013
batch 12000 loss: 0.29103585437865875
batch 13000 loss: 0.2923971119223861
batch 14000 loss: 0.30239957763813435
batch 15000 loss: 0.2951189488622404
LOSS train 0.2951189488622404 valid 0.31961268186569214
EPOCH 5:
batch 1000 loss: 0.2873665358930884
batch 2000 loss: 0.2866385520455806
batch 3000 loss: 0.2682322573630809
batch 4000 loss: 0.2877288664934167
batch 5000 loss: 0.27378226517873555
batch 6000 loss: 0.26666203200165184
batch 7000 loss: 0.2960926393906775
batch 8000 loss: 0.2831014274285844
batch 9000 loss: 0.28433366663263043
batch 10000 loss: 0.27850623577708755
batch 11000 loss: 0.2613752937624522
batch 12000 loss: 0.27932199506648975
batch 13000 loss: 0.27204559067965417
batch 14000 loss: 0.29369351826481216
batch 15000 loss: 0.29557288661976056
LOSS train 0.29557288661976056 valid 0.30977943539619446
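The validation pass above relies on two separate switches: `model.eval()` changes how layers such as dropout behave, and `torch.no_grad()` stops autograd from recording a graph. `GarmentClassifier` has neither dropout nor batch normalization, so here the mode switch is good practice rather than a change in behavior. The sketch below uses a hypothetical toy model with a dropout layer to show both effects. In training mode, dropout zeroes each element with probability p and scales the survivors by 1 / (1 - p).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy model with a dropout layer; GarmentClassifier itself has none
net = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
x = torch.ones(1, 8)

net.train()                 # training mode: dropout is active
train_out = net(x)

net.eval()                  # evaluation mode: dropout passes inputs through
eval_a, eval_b = net(x), net(x)

with torch.no_grad():       # no autograd graph is recorded inside this block
    no_grad_out = net(x)

# In training mode each element was either zeroed or scaled by 1 / (1 - p) = 2
kept_or_dropped = (train_out == 0) | torch.isclose(train_out, 2 * eval_a)
print(kept_or_dropped.all().item())                      # True
print(torch.equal(eval_a, eval_b))                       # True
print(eval_a.requires_grad, no_grad_out.requires_grad)   # True False
```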
To load a saved version of the model:
PATH = model_path  # a checkpoint written by torch.save() in the loop above
saved_model = GarmentClassifier()
saved_model.load_state_dict(torch.load(PATH))
Once you’ve loaded the model, it’s ready for whatever you need it for: more training, inference, or analysis.
Note that if your model has constructor parameters that affect model structure, you’ll need to provide them and configure the model identically to the state in which it was saved.
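The sketch below illustrates why the constructor arguments must match. It uses a hypothetical model, `TinyNet`, whose layer sizes depend on a `hidden` parameter, and saves its `state_dict` to an in-memory buffer (`io.BytesIO`) as a stand-in for a file path. Loading into an identically configured instance reproduces the original outputs. Loading into an instance built with a different `hidden` value makes `load_state_dict` raise a `RuntimeError` for the mismatched tensor sizes.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model whose structure depends on a constructor parameter."""

    def __init__(self, hidden):
        super().__init__()
        self.fc1 = nn.Linear(4, hidden)
        self.fc2 = nn.Linear(hidden, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


original = TinyNet(hidden=16)
buffer = io.BytesIO()                 # in-memory stand-in for a file path
torch.save(original.state_dict(), buffer)

# Reloading works when the constructor arguments match the saved model
buffer.seek(0)
restored = TinyNet(hidden=16)
restored.load_state_dict(torch.load(buffer))

# With a different structure, load_state_dict reports a size mismatch
buffer.seek(0)
mismatched = TinyNet(hidden=32)
try:
    mismatched.load_state_dict(torch.load(buffer))
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = err

x = torch.randn(3, 4)
print(torch.equal(original(x), restored(x)))  # True
print(type(mismatch_error).__name__)          # RuntimeError
```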
Other Resources
Docs on the data utilities, including Dataset and DataLoader, at pytorch.org
A note on the use of pinned memory for GPU training
Documentation on the datasets available in TorchVision, TorchText, and TorchAudio
Documentation on the loss functions available in PyTorch
Documentation on the torch.optim package, which includes optimizers and related tools, such as learning rate scheduling
A detailed tutorial on saving and loading models
The Tutorials section of pytorch.org contains tutorials on a broad variety of training tasks, including classification in different domains, generative adversarial networks, reinforcement learning, and more
Total running time of the script: (2 minutes 56.489 seconds)