Training with PyTorch
Created On: Nov 30, 2021 | Last Updated: May 31, 2023 | Last Verified: Nov 05, 2024
Follow along with the video below or on YouTube.
Introduction
In past videos, we’ve discussed and demonstrated:
Building models with the neural network layers and functions of the torch.nn module
The mechanics of automated gradient computation, which is central to gradient-based model training
Using TensorBoard to visualize training progress and other activities
In this video, we’ll be adding some new tools to your inventory:
We’ll get familiar with the dataset and dataloader abstractions, and how they ease the process of feeding data to your model during a training loop
We’ll discuss specific loss functions and when to use them
We’ll look at PyTorch optimizers, which implement algorithms to adjust model weights based on the outcome of a loss function
Finally, we’ll pull all of these together and see a full PyTorch training loop in action.
Dataset and DataLoader
The Dataset and DataLoader classes encapsulate the process of
pulling your data from storage and exposing it to your training loop in
batches.
The Dataset is responsible for accessing and processing single
instances of data.
The DataLoader pulls instances of data from the Dataset (either
automatically or with a sampler that you define), collects them in
batches, and returns them for consumption by your training loop. The
DataLoader works with all kinds of datasets, regardless of the type
of data they contain.
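To make the division of labor concrete, here is a minimal sketch of a map-style dataset: a subclass of torch.utils.data.Dataset that implements __len__ and __getitem__ to serve single instances, wrapped in a DataLoader that collates them into batches. The SquaresDataset class and its (i, i * i) instances are hypothetical, chosen only so the batches are easy to inspect.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: instance i is (tensor([float(i)]), i * i)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of instances; the DataLoader uses this to plan its batches
        return self.n

    def __getitem__(self, idx):
        # Access and process a single instance
        x = torch.tensor([float(idx)])
        y = idx * idx
        return x, y


dataset = SquaresDataset(10)
# shuffle=False: instances are drawn in index order
loader = DataLoader(dataset, batch_size=4, shuffle=False)

batches = list(loader)
for inputs, targets in batches:
    # The default collate function stacks inputs and turns integer labels into a tensor
    print(inputs.shape, targets)
```

With 10 instances and a batch size of 4, the loader yields two full batches and a final batch of 2.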
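For this tutorial, we’ll be using the Fashion-MNIST dataset provided by
TorchVision. We use transforms.ToTensor() to convert each image to a
tensor with pixel values in [0, 1], then torchvision.transforms.Normalize()
with mean 0.5 and standard deviation 0.5 to rescale those values to
[-1, 1], centering the pixel range on zero. We also download both the
training and validation data splits.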
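Normalize computes (x - mean) / std for each channel, so with mean 0.5 and std 0.5 it maps 0 to -1 and 1 to 1. The sketch below reproduces that arithmetic with plain tensor operations on a few sample pixel values (the values are illustrative assumptions) and checks that img / 2 + 0.5, the "unnormalize" step used by this tutorial's matplotlib_imshow helper, inverts it.

```python
import torch

# Normalize((0.5,), (0.5,)) applies (x - mean) / std to each channel
mean, std = 0.5, 0.5

# ToTensor() yields pixel values in [0, 1]; a few sample values
pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])

normalized = (pixels - mean) / std
print(normalized.tolist())  # [-1.0, -0.5, 0.0, 1.0]

# Inverse mapping, as in the display helper's "unnormalize" line
restored = normalized / 2 + 0.5
print(torch.allclose(restored, pixels))  # True
```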
import torch
import torchvision
import torchvision.transforms as transforms
# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)
# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)
# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')
# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
Training set has 60000 instances
Validation set has 10000 instances
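With batch_size=4, the 60,000 training instances give 15,000 batches per epoch, which is why each epoch in the training log ends at "batch 15000". The number of batches is the dataset size divided by the batch size, rounded up unless drop_last=True discards the final partial batch. This sketch uses a 10-element TensorDataset as a stand-in for Fashion-MNIST.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 toy instances stand in for the 60,000 Fashion-MNIST training images
toy_set = TensorDataset(torch.arange(10.0))

loader = DataLoader(toy_set, batch_size=4)
print(len(loader))  # 3: batches of 4, 4, and 2

# drop_last=True discards the final incomplete batch
loader_drop = DataLoader(toy_set, batch_size=4, drop_last=True)
print(len(loader_drop))  # 2

# The same ceiling arithmetic for the tutorial's training split
print(math.ceil(60000 / 4))  # 15000
```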
As always, let’s visualize the data as a sanity check:
import matplotlib.pyplot as plt
import numpy as np
# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
dataiter = iter(training_loader)
images, labels = next(dataiter)
# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print('  '.join(classes[labels[j]] for j in range(4)))

T-shirt/top Sneaker Bag Trouser
The Model
The model we’ll use in this example is a variant of LeNet-5; it should be familiar if you’ve watched the previous videos in this series.
import torch.nn as nn
import torch.nn.functional as F
# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
model = GarmentClassifier()
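The 16 * 4 * 4 input size of fc1 follows from the image size: each 5x5 convolution without padding shrinks a side by 4, and each 2x2 max pool halves it. A quick way to confirm this is to push a dummy batch through the same convolutional layers and print the shapes. This sketch re-creates the layers standalone and omits the ReLUs, which do not change shapes.

```python
import torch
import torch.nn as nn

# Same layer hyperparameters as GarmentClassifier's convolutional stack
conv1 = nn.Conv2d(1, 6, 5)
pool = nn.MaxPool2d(2, 2)
conv2 = nn.Conv2d(6, 16, 5)

# A dummy batch of 4 single-channel 28x28 images (the Fashion-MNIST size)
x = torch.zeros(4, 1, 28, 28)

x = conv1(x)        # 28 - 5 + 1 = 24
print(x.shape)      # torch.Size([4, 6, 24, 24])
x = pool(x)         # 24 / 2 = 12
print(x.shape)      # torch.Size([4, 6, 12, 12])
x = conv2(x)        # 12 - 5 + 1 = 8
print(x.shape)      # torch.Size([4, 16, 8, 8])
x = pool(x)         # 8 / 2 = 4
print(x.shape)      # torch.Size([4, 16, 4, 4])

# Flattening gives 16 * 4 * 4 = 256 features per image, the input size of fc1
flat = x.view(-1, 16 * 4 * 4)
print(flat.shape)   # torch.Size([4, 256])
```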
Loss Function
For this example, we’ll be using a cross-entropy loss. For demonstration purposes, we’ll create batches of dummy output and label values, run them through the loss function, and examine the result.
loss_fn = torch.nn.CrossEntropyLoss()
# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's raw, unnormalized scores (logits) for each of the 10 classes for a given input
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])
print(dummy_outputs)
print(dummy_labels)
# CrossEntropyLoss applies log-softmax internally and, by default
# (reduction='mean'), averages the per-sample losses over the batch
loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))
tensor([[0.1165, 0.2322, 0.9360, 0.7127, 0.4667, 0.7428, 0.6576, 0.3341, 0.6885,
0.1012],
[0.2776, 0.4531, 0.4777, 0.0832, 0.1309, 0.8242, 0.6120, 0.2191, 0.1570,
0.4848],
[0.1787, 0.5802, 0.5333, 0.4329, 0.3414, 0.8458, 0.9080, 0.0618, 0.4603,
0.3371],
[0.5715, 0.0899, 0.7778, 0.4740, 0.1736, 0.3869, 0.3015, 0.6124, 0.9694,
0.0482]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.2564003467559814
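Cross-entropy loss takes the log-softmax of each row of scores, picks out the entry for the correct class, negates it, and (by default) averages over the batch. The sketch below recomputes this by hand on freshly generated dummy data and compares it with torch.nn.CrossEntropyLoss. It also shows that identical scores for all 10 classes give exactly ln(10) ≈ 2.303, which is why random scores in [0, 1) produce a loss close to that value, as in the output above.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.rand(4, 10)            # dummy scores, like dummy_outputs
targets = torch.tensor([1, 5, 3, 7])  # dummy labels, like dummy_labels

# Log-softmax over classes, then the negative log-probability of each target
log_probs = F.log_softmax(logits, dim=1)
per_sample = -log_probs[torch.arange(4), targets]
manual = per_sample.mean()            # matches the default reduction='mean'

builtin = torch.nn.CrossEntropyLoss()(logits, targets)
print(manual.item(), builtin.item())
print(torch.allclose(manual, builtin))  # True

# Identical scores for every class: each target gets probability 1/10
uniform_loss = torch.nn.CrossEntropyLoss()(torch.zeros(4, 10), targets).item()
print(uniform_loss, math.log(10))
```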
Optimizer
For this example, we’ll be using simple stochastic gradient descent with momentum.
It can be instructive to try some variations on this optimization scheme:
Learning rate determines the size of the steps the optimizer takes. What does a different learning rate do to your training results, in terms of accuracy and convergence time?
Momentum accumulates an exponentially decaying sum of past gradients, nudging the optimizer to keep moving in directions the gradient has consistently pointed over multiple steps. What does changing this value do to your results?
Try some different optimization algorithms, such as averaged SGD, Adagrad, or Adam. How do your results differ?
# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
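To see what the momentum parameter does, the sketch below runs torch.optim.SGD with momentum on a toy one-parameter loss f(w) = w**2 (an assumption chosen for illustration, not the tutorial's model) and repeats the same update by hand. With the default dampening=0 and nesterov=False, each step keeps a velocity v = momentum * v + grad and moves the parameter by lr * v, so the trajectory overshoots zero once velocity has built up.

```python
import torch

lr, momentum = 0.1, 0.9

# Parameter minimized by torch.optim.SGD on the toy loss f(w) = w**2
w = torch.tensor([1.0], requires_grad=True)
opt = torch.optim.SGD([w], lr=lr, momentum=momentum)

# The same update written out by hand
w_manual = 1.0
velocity = 0.0

for step in range(5):
    opt.zero_grad()
    loss = (w ** 2).sum()
    loss.backward()
    opt.step()

    grad = 2 * w_manual                    # derivative of w**2
    velocity = momentum * velocity + grad  # decaying sum of past gradients
    w_manual = w_manual - lr * velocity

    print(step, w.item(), w_manual)
```

Trying the same loop with a different lr or momentum value is a cheap way to build intuition before changing the real training run.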
The Training Loop
Below, we have a function that performs one training epoch. It enumerates data from the DataLoader, and on each pass of the loop does the following:
Gets a batch of training data from the DataLoader
Zeros the optimizer’s gradients
Performs inference, that is, gets predictions from the model for an input batch
Calculates the loss for that set of predictions vs. the labels on the dataset
Calculates the backward gradients over the learning weights
Tells the optimizer to perform one learning step, that is, adjust the model’s learning weights based on the observed gradients for this batch, according to the optimization algorithm we chose
It reports the average per-batch loss every 1000 batches.
Finally, it returns the average per-batch loss for the last 1000 batches, for comparison with a validation run.
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.
    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data
        # Zero your gradients for every batch!
        optimizer.zero_grad()
        # Make predictions for this batch
        outputs = model(inputs)
        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()
        # Adjust learning weights
        optimizer.step()
        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
    return last_loss
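The "Zero your gradients for every batch!" comment matters because backward() adds new gradients to whatever is already stored in each parameter's .grad. This sketch uses a single toy parameter w (an illustrative assumption) to show gradients accumulating across two backward passes, and optimizer.zero_grad() clearing them so the next pass sees only the current batch's gradient.

```python
import torch

w = torch.tensor([3.0], requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)

# First backward pass: d(2w)/dw = 2
(2 * w).sum().backward()
first = w.grad.item()
print(first)        # 2.0

# Second backward pass without zeroing: the new gradient is added to w.grad
(2 * w).sum().backward()
accumulated = w.grad.item()
print(accumulated)  # 4.0

# Clearing gradients first restores the per-batch value
opt.zero_grad()
(2 * w).sum().backward()
after_zero = w.grad.item()
print(after_zero)   # 2.0
```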
Per-Epoch Activity
There are a couple of things we’ll want to do once per epoch:
Perform validation by checking our relative loss on a set of data that was not used for training, and report this
Save a copy of the model
Here, we’ll do our reporting in TensorBoard. This will require going to the command line to start TensorBoard, and opening it in another browser tab.
# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0
EPOCHS = 5
best_vloss = 1_000_000.
for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))
    # Put the model in training mode (enables dropout and batch-norm
    # statistics updates, if present) and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)
    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()
    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss
    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
    # Log the average per-batch loss for training (last 1000 batches)
    # and for validation (the whole validation set)
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()
    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)
    epoch_number += 1
EPOCH 1:
batch 1000 loss: 1.6658620406091214
batch 2000 loss: 0.8516166052296757
batch 3000 loss: 0.7009272245867177
batch 4000 loss: 0.6560533392401412
batch 5000 loss: 0.5826514636648353
batch 6000 loss: 0.5635461741290055
batch 7000 loss: 0.5068506260314025
batch 8000 loss: 0.5113633444093866
batch 9000 loss: 0.4970339813426544
batch 10000 loss: 0.4938779923019465
batch 11000 loss: 0.4505540374596021
batch 12000 loss: 0.4431168988837162
batch 13000 loss: 0.4388320386710693
batch 14000 loss: 0.43843553768773563
batch 15000 loss: 0.41381557586445705
LOSS train 0.41381557586445705 valid 0.4348176121711731
EPOCH 2:
batch 1000 loss: 0.4061069980097527
batch 2000 loss: 0.37211072959128066
batch 3000 loss: 0.42836699860420774
batch 4000 loss: 0.37676631792844273
batch 5000 loss: 0.3822630588122993
batch 6000 loss: 0.37956657542468747
batch 7000 loss: 0.3796073902762146
batch 8000 loss: 0.3727709586660203
batch 9000 loss: 0.3630514784778934
batch 10000 loss: 0.35778193141293013
batch 11000 loss: 0.33365957739864827
batch 12000 loss: 0.3426332841021067
batch 13000 loss: 0.3589221550418879
batch 14000 loss: 0.35045841016148915
batch 15000 loss: 0.35053903643472584
LOSS train 0.35053903643472584 valid 0.3842392563819885
EPOCH 3:
batch 1000 loss: 0.33440972216619413
batch 2000 loss: 0.3110182780860341
batch 3000 loss: 0.33393270239332923
batch 4000 loss: 0.3056600662749697
batch 5000 loss: 0.33455936154890514
batch 6000 loss: 0.3281591029932242
batch 7000 loss: 0.3320447776057408
batch 8000 loss: 0.3014484698399465
batch 9000 loss: 0.326816707367223
batch 10000 loss: 0.32368432882835624
batch 11000 loss: 0.31286812677708076
batch 12000 loss: 0.32168141807281064
batch 13000 loss: 0.3140142848208925
batch 14000 loss: 0.33175899855807073
batch 15000 loss: 0.31974885813826404
LOSS train 0.31974885813826404 valid 0.3411523401737213
EPOCH 4:
batch 1000 loss: 0.2792649882346741
batch 2000 loss: 0.30695529670898397
batch 3000 loss: 0.28483044201592567
batch 4000 loss: 0.31076253978430757
batch 5000 loss: 0.2800331716003075
batch 6000 loss: 0.27861269556287516
batch 7000 loss: 0.2902559871871872
batch 8000 loss: 0.31863024745455915
batch 9000 loss: 0.30168369303795045
batch 10000 loss: 0.3166546563067386
batch 11000 loss: 0.3051652862815099
batch 12000 loss: 0.30055884081017575
batch 13000 loss: 0.3037734321429016
batch 14000 loss: 0.2887943802849186
batch 15000 loss: 0.2984416477219347
LOSS train 0.2984416477219347 valid 0.33507803082466125
EPOCH 5:
batch 1000 loss: 0.26119035481381253
batch 2000 loss: 0.2737489958937615
batch 3000 loss: 0.283918988411373
batch 4000 loss: 0.2724920483358019
batch 5000 loss: 0.28245116227514516
batch 6000 loss: 0.27361438447707404
batch 7000 loss: 0.28149854283319475
batch 8000 loss: 0.2817577930703628
batch 9000 loss: 0.299952932070355
batch 10000 loss: 0.28111812591029045
batch 11000 loss: 0.2771015756965207
batch 12000 loss: 0.28317676019544885
batch 13000 loss: 0.29391766782395645
batch 14000 loss: 0.25625949264538483
batch 15000 loss: 0.278206290835853
LOSS train 0.278206290835853 valid 0.3171177804470062
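The validation pass above relies on two separate switches: model.eval() changes the behavior of layers such as dropout and batch normalization, while torch.no_grad() stops autograd from recording operations. GarmentClassifier contains neither dropout nor batch norm, so for this particular model the mode switch is a good habit rather than a behavioral change. The sketch below uses a hypothetical model with a dropout layer to show both effects.

```python
import torch
import torch.nn as nn

# Hypothetical model with a dropout layer (GarmentClassifier has none)
net = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
x = torch.ones(2, 8)

net.eval()             # dropout becomes a no-op
with torch.no_grad():  # no autograd graph is recorded
    out1 = net(x)
    out2 = net(x)

print(torch.equal(out1, out2))  # True: eval-mode outputs are deterministic
print(out1.requires_grad)       # False: results are not tracked by autograd

net.train()            # back to training mode before the next epoch
print(net.training)             # True
```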
To load a saved version of the model:
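# PATH should point to a saved state dict, e.g. the model_path written by the training loop above
PATH = model_path
saved_model = GarmentClassifier()
saved_model.load_state_dict(torch.load(PATH))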
Once you’ve loaded the model, it’s ready for whatever you need it for: more training, inference, or analysis.
Note that if your model has constructor parameters that affect model structure, you’ll need to provide them and configure the model identically to the state in which it was saved.
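A minimal save/load round trip illustrates this. The sketch assumes a hypothetical TinyNet whose hidden-layer width is a constructor argument and writes its state dict to a temporary file. Reloading into an identically configured instance reproduces the original outputs, while a differently configured instance raises a RuntimeError from load_state_dict because the tensor shapes do not match.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model whose structure depends on a constructor argument."""

    def __init__(self, hidden=16):
        super().__init__()
        self.fc1 = nn.Linear(4, hidden)
        self.fc2 = nn.Linear(hidden, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


tiny = TinyNet(hidden=16)
path = os.path.join(tempfile.mkdtemp(), "tiny_net.pt")
torch.save(tiny.state_dict(), path)

# Matching constructor arguments: loading succeeds
restored = TinyNet(hidden=16)
restored.load_state_dict(torch.load(path))
restored.eval()  # evaluation mode before inference

x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.allclose(tiny(x), restored(x))
print(same)  # True

# Different structure: the saved tensors no longer fit
try:
    TinyNet(hidden=32).load_state_dict(torch.load(path))
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)  # True
```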
Other Resources
Docs on the data utilities, including Dataset and DataLoader, at pytorch.org
A note on the use of pinned memory for GPU training
Documentation on the datasets available in TorchVision, TorchText, and TorchAudio
Documentation on the loss functions available in PyTorch
Documentation on the torch.optim package, which includes optimizers and related tools, such as learning rate scheduling
A detailed tutorial on saving and loading models
The Tutorials section of pytorch.org contains tutorials on a broad variety of training tasks, including classification in different domains, generative adversarial networks, reinforcement learning, and more
Total running time of the script: (2 minutes 54.545 seconds)