Training with PyTorch
Created On: Nov 30, 2021 | Last Updated: May 31, 2023 | Last Verified: Nov 05, 2024
Follow along with the video below or on YouTube.
Introduction
In past videos, we’ve discussed and demonstrated:
Building models with the neural network layers and functions of the torch.nn module
The mechanics of automated gradient computation, which is central to gradient-based model training
Using TensorBoard to visualize training progress and other activities
In this video, we’ll be adding some new tools to your inventory:
We’ll get familiar with the dataset and dataloader abstractions, and how they ease the process of feeding data to your model during a training loop
We’ll discuss specific loss functions and when to use them
We’ll look at PyTorch optimizers, which implement algorithms to adjust model weights based on the outcome of a loss function
Finally, we’ll pull all of these together and see a full PyTorch training loop in action.
Dataset and DataLoader
The Dataset and DataLoader classes encapsulate the process of pulling your data from storage and exposing it to your training loop in batches.
The Dataset is responsible for accessing and processing single instances of data.
The DataLoader pulls instances of data from the Dataset (either automatically or with a sampler that you define), collects them in batches, and returns them for consumption by your training loop. The DataLoader works with all kinds of datasets, regardless of the type of data they contain.
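The sketch below shows the two roles side by side. The hypothetical SquaresDataset implements the two methods that a map-style Dataset needs, __len__ and __getitem__, and leaves all batching to a DataLoader. The dataset and its contents are invented for illustration. Only Dataset and DataLoader are real PyTorch classes.

```python
import math

import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: item i is the pair (tensor([i]), i * i)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Tells the DataLoader's sampler how many indices exist
        return self.n

    def __getitem__(self, idx):
        # Access and process one instance; batching is the DataLoader's job
        x = torch.tensor([float(idx)])
        y = idx * idx
        return x, y


dataset = SquaresDataset(10)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

for xb, yb in loader:
    print(xb.shape, yb.tolist())
# torch.Size([4, 1]) [0, 1, 4, 9]
# torch.Size([4, 1]) [16, 25, 36, 49]
# torch.Size([2, 1]) [64, 81]

# Batches per pass: ceil(len(dataset) / batch_size) = ceil(10 / 4) = 3
print(len(loader), math.ceil(len(dataset) / 4))

# drop_last=True discards the final, incomplete batch
print(len(DataLoader(dataset, batch_size=4, drop_last=True)))   # 2
```

The same arithmetic applies to the Fashion-MNIST loaders below. 60,000 training images with batch_size=4 give 15,000 batches per epoch, which matches the last batch number reported in the training log.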
For this tutorial, we’ll be using the Fashion-MNIST dataset provided by TorchVision. We use torchvision.transforms.Normalize() to zero-center and normalize the distribution of the image tile content, and download both training and validation data splits.
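Normalize((0.5,), (0.5,)) applies (x - mean) / std to each channel. ToTensor() has already scaled pixel values into [0, 1], so the result lands in [-1, 1]. The sketch below shows this with plain tensor arithmetic rather than a call into torchvision. It also checks that img / 2 + 0.5, used by the display helper later on, undoes the mapping. Note that 0.5 is a convenient fixed value here, not the measured mean and standard deviation of Fashion-MNIST.

```python
import torch

# Normalize((0.5,), (0.5,)) computes (x - mean) / std for the single channel
mean, std = 0.5, 0.5

pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])   # sample values after ToTensor()
normalized = (pixels - mean) / std              # values: -1.0, -0.5, 0.0, 1.0
print(normalized)

# The display helper's "unnormalize" step is the inverse mapping
restored = normalized / 2 + 0.5
print(torch.allclose(restored, pixels))         # True
```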
import torch
import torchvision
import torchvision.transforms as transforms
# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)
# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)
# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')
# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
Training set has 60000 instances
Validation set has 10000 instances
As always, let’s visualize the data as a sanity check:
import matplotlib.pyplot as plt
import numpy as np
# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
dataiter = iter(training_loader)
images, labels = next(dataiter)
# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print(' '.join(classes[labels[j]] for j in range(4)))

T-shirt/top Ankle Boot T-shirt/top Ankle Boot
The Model
The model we’ll use in this example is a variant of LeNet-5 - it should be familiar if you’ve watched the previous videos in this series.
import torch.nn as nn
import torch.nn.functional as F
# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = GarmentClassifier()
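The 16 * 4 * 4 input size of fc1 comes from tracking the spatial size of a 28 x 28 image through the two convolution and pooling stages. The sketch below works this out with the standard output-size formula, floor((size + 2 * padding - kernel) / stride) + 1. The conv_out helper is ours, not a PyTorch function. The sketch then confirms the result by running a dummy batch through layers configured like those in GarmentClassifier.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a square convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                        # Fashion-MNIST images are 28 x 28
side = conv_out(side, 5)         # conv1 (5x5 kernel): 24
side = conv_out(side, 2, 2)      # max pool (2x2, stride 2): 12
side = conv_out(side, 5)         # conv2 (5x5 kernel): 8
side = conv_out(side, 2, 2)      # max pool: 4
print(16 * side * side)          # 256, i.e. 16 * 4 * 4

# Confirm with layers configured like GarmentClassifier's
conv1, conv2 = nn.Conv2d(1, 6, 5), nn.Conv2d(6, 16, 5)
pool = nn.MaxPool2d(2, 2)
x = torch.zeros(4, 1, 28, 28)    # a dummy batch of 4 single-channel images
x = pool(F.relu(conv1(x)))
x = pool(F.relu(conv2(x)))
print(x.shape)                   # torch.Size([4, 16, 4, 4])
```

If you change kernel sizes or add layers, repeat this calculation (or print the shape once) before sizing fc1. Doing so avoids a shape-mismatch error at the first linear layer.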
Loss Function
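For this example, we’ll be using a cross-entropy loss. For demonstration purposes, we’ll create batches of dummy output and label values, run them through the loss function, and examine the result. By default, CrossEntropyLoss averages the per-sample losses, so the value printed below is the mean loss over the batch of 4.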
loss_fn = torch.nn.CrossEntropyLoss()
# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's raw, unnormalized scores (logits) for each of the 10 classes, per input
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])
print(dummy_outputs)
print(dummy_labels)
loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))
tensor([[0.4136, 0.5780, 0.6322, 0.5602, 0.9690, 0.6201, 0.1045, 0.2775, 0.3988,
0.2362],
[0.4783, 0.0956, 0.0695, 0.2914, 0.5784, 0.6135, 0.2122, 0.4925, 0.8731,
0.5130],
[0.1394, 0.5479, 0.4386, 0.3213, 0.0195, 0.1422, 0.4464, 0.0376, 0.3112,
0.7655],
[0.4853, 0.8517, 0.1313, 0.8095, 0.1210, 0.6882, 0.2715, 0.6303, 0.7685,
0.7147]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.2371749877929688
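CrossEntropyLoss expects raw scores and applies log-softmax internally. The sketch below computes the same quantity by hand, -log softmax(scores)[correct class] for each sample. It then checks the result against the built-in loss with both the default reduction='mean' and reduction='sum'. The scores are generated with a fixed seed here, so they differ from the dummy values above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
scores = torch.rand(4, 10)                 # raw scores, like dummy_outputs above
labels = torch.tensor([1, 5, 3, 7])

# Per-sample cross-entropy: -log(softmax(scores)[correct class])
log_probs = F.log_softmax(scores, dim=1)
per_sample = -log_probs[torch.arange(4), labels]

mean_loss = torch.nn.CrossEntropyLoss()(scores, labels)                # default reduction='mean'
sum_loss = torch.nn.CrossEntropyLoss(reduction='sum')(scores, labels)

print(per_sample)
print(mean_loss.item(), per_sample.mean().item())   # the two values agree
```

With scores drawn uniformly from [0, 1), the softmax is close to uniform over the 10 classes. The loss therefore sits near -ln(0.1) ≈ 2.30, which is why the dummy batch above produces a value of about 2.24.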
Optimizer
For this example, we’ll be using simple stochastic gradient descent with momentum.
It can be instructive to try some variations on this optimization scheme:
Learning rate determines the size of the steps the optimizer takes. What does a different learning rate do to your training results, in terms of accuracy and convergence time?
Momentum accumulates gradients from previous steps, nudging the optimizer to keep moving in the direction that has been consistently downhill over multiple steps. What does changing this value do to your results?
Try some different optimization algorithms, such as averaged SGD, Adagrad, or Adam. How do your results differ?
# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
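To make the momentum idea concrete, the sketch below takes three optimizer steps on a hypothetical one-parameter loss, w². It repeats each update by hand using the rule torch.optim.SGD applies with its defaults (dampening=0, nesterov=False): velocity = momentum * velocity + gradient, then w = w - lr * velocity. The final loop runs the same toy problem through ASGD, Adagrad, and Adam to show that swapping algorithms only changes the constructor line. The lr=0.1 used there is an arbitrary choice for illustration.

```python
import torch

lr, momentum = 0.001, 0.9          # the values used for the tutorial's optimizer

# Hypothetical one-parameter problem: loss = w ** 2, so the gradient is 2 * w
w = torch.tensor([1.0], requires_grad=True)
opt = torch.optim.SGD([w], lr=lr, momentum=momentum)

w_manual, velocity = 1.0, 0.0
for step in range(3):
    opt.zero_grad()
    (w ** 2).sum().backward()
    opt.step()

    # The same update by hand (dampening=0, nesterov=False)
    grad = 2 * w_manual
    velocity = momentum * velocity + grad
    w_manual -= lr * velocity
    print(step, w.item(), w_manual)

# Other algorithms share the same interface; only the constructor changes
final = {}
for opt_cls in (torch.optim.ASGD, torch.optim.Adagrad, torch.optim.Adam):
    p = torch.tensor([1.0], requires_grad=True)
    o = opt_cls([p], lr=0.1)
    for _ in range(5):
        o.zero_grad()
        (p ** 2).sum().backward()
        o.step()
    final[opt_cls.__name__] = p.item()
print(final)
```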
The Training Loop
Below, we have a function that performs one training epoch. It enumerates data from the DataLoader, and on each pass of the loop does the following:
Gets a batch of training data from the DataLoader
Zeros the optimizer’s gradients
Performs an inference - that is, gets predictions from the model for an input batch
Calculates the loss for that set of predictions vs. the labels on the dataset
Calculates the backward gradients over the learning weights
Tells the optimizer to perform one learning step - that is, adjust the model’s learning weights based on the observed gradients for this batch, according to the optimization algorithm we chose
It reports the average per-batch loss every 1000 batches.
Finally, it returns the average per-batch loss for the last 1000 batches, for comparison with a validation run.
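The zero-gradients step matters because backward() adds new gradients into each parameter's .grad rather than overwriting it. The small sketch below uses a hypothetical single parameter to show this accumulation, then clears it with zero_grad(). The name toy_optimizer is ours, chosen so that it does not replace the tutorial's optimizer.

```python
import torch

w = torch.tensor([3.0], requires_grad=True)
toy_optimizer = torch.optim.SGD([w], lr=0.1)

# Two backward passes without zeroing: d(2w)/dw = 2 is added each time
for _ in range(2):
    (2 * w).sum().backward()
accumulated = w.grad.clone()      # 2 + 2 = 4

# zero_grad() clears the stored gradients (PyTorch 2.x sets them to None by default)
toy_optimizer.zero_grad()
(2 * w).sum().backward()
fresh = w.grad.clone()            # back to 2

print(accumulated.item(), fresh.item())   # 4.0 2.0
```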
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss
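Below is a minimal sketch for trying the loop without downloading Fashion-MNIST or starting TensorBoard. It runs the same steps on a hypothetical synthetic task: random 8-dimensional features labelled by whether they sum to a positive number, with a single linear layer as the model. The toy_ names and the report_every parameter are our own. They are chosen so the sketch does not overwrite the tutorial's model, loss_fn, and optimizer.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical synthetic task: label 1 if the 8 features sum to a positive number
features = torch.randn(512, 8)
targets = (features.sum(dim=1) > 0).long()
toy_loader = DataLoader(TensorDataset(features, targets), batch_size=16, shuffle=True)

toy_model = torch.nn.Linear(8, 2)
toy_loss_fn = torch.nn.CrossEntropyLoss()
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1, momentum=0.9)


def toy_train_one_epoch(report_every=8):
    """Same steps as train_one_epoch, without TensorBoard logging."""
    toy_model.train()
    running_loss, last_loss = 0.0, 0.0
    for i, (inputs, labels) in enumerate(toy_loader):
        toy_optimizer.zero_grad()                # zero the gradients
        outputs = toy_model(inputs)              # predictions for this batch
        loss = toy_loss_fn(outputs, labels)      # loss vs. the labels
        loss.backward()                          # backward gradients
        toy_optimizer.step()                     # one learning step
        running_loss += loss.item()
        if i % report_every == report_every - 1:
            last_loss = running_loss / report_every   # mean loss per batch
            running_loss = 0.0
    return last_loss


history = [toy_train_one_epoch() for _ in range(5)]
print(history)   # mean loss over each epoch's last 8 batches; should trend downward
```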
Per-Epoch Activity
There are a couple of things we’ll want to do once per epoch:
Perform validation by checking our relative loss on a set of data that was not used for training, and report this
Save a copy of the model
Here, we’ll do our reporting in TensorBoard. This will require going to the command line to start TensorBoard, and opening it in another browser tab.
# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0
EPOCHS = 5
best_vloss = 1_000_000.
for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Put the model in training mode (this affects layers such as dropout
    # and batch normalization), then do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)

    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1
EPOCH 1:
batch 1000 loss: 1.7758210811205208
batch 2000 loss: 0.8684514756985009
batch 3000 loss: 0.7269933776920662
batch 4000 loss: 0.6264328583646566
batch 5000 loss: 0.5713831033515744
batch 6000 loss: 0.5716565410976764
batch 7000 loss: 0.523069411519682
batch 8000 loss: 0.5059755043592304
batch 9000 loss: 0.46764329217243356
batch 10000 loss: 0.4756787621000549
batch 11000 loss: 0.45533774535809063
batch 12000 loss: 0.434188809020794
batch 13000 loss: 0.4432140271391254
batch 14000 loss: 0.433044406996778
batch 15000 loss: 0.4146991293230094
LOSS train 0.4146991293230094 valid 0.4503612220287323
EPOCH 2:
batch 1000 loss: 0.39955700305808567
batch 2000 loss: 0.41123968731809873
batch 3000 loss: 0.3859535540444776
batch 4000 loss: 0.3470500040081679
batch 5000 loss: 0.3668477054270334
batch 6000 loss: 0.36335744826548033
batch 7000 loss: 0.38008951136260294
batch 8000 loss: 0.3750837282325665
batch 9000 loss: 0.36501676612862505
batch 10000 loss: 0.35058384872134774
batch 11000 loss: 0.36770042545941395
batch 12000 loss: 0.3568097979603335
batch 13000 loss: 0.3509265405783517
batch 14000 loss: 0.3462942300561117
batch 15000 loss: 0.3561768907784135
LOSS train 0.3561768907784135 valid 0.3701411485671997
EPOCH 3:
batch 1000 loss: 0.3243762938613363
batch 2000 loss: 0.3167930108278233
batch 3000 loss: 0.3454286119421013
batch 4000 loss: 0.3484262980411295
batch 5000 loss: 0.32934922833560265
batch 6000 loss: 0.3102569465649849
batch 7000 loss: 0.32708222679484605
batch 8000 loss: 0.3257536272257348
batch 9000 loss: 0.3355292390795657
batch 10000 loss: 0.30939457686892274
batch 11000 loss: 0.3203176606516936
batch 12000 loss: 0.31400518341270073
batch 13000 loss: 0.32281278402548197
batch 14000 loss: 0.31659204009777747
batch 15000 loss: 0.3119617188770499
LOSS train 0.3119617188770499 valid 0.34070858359336853
EPOCH 4:
batch 1000 loss: 0.29988487648974843
batch 2000 loss: 0.29513267787147196
batch 3000 loss: 0.29591830945741093
batch 4000 loss: 0.29502063494139114
batch 5000 loss: 0.2923672856846097
batch 6000 loss: 0.29842070014849015
batch 7000 loss: 0.32958982204912174
batch 8000 loss: 0.29485688532363563
batch 9000 loss: 0.2765466377943376
batch 10000 loss: 0.31751291519972435
batch 11000 loss: 0.29047842672202384
batch 12000 loss: 0.30218106864616856
batch 13000 loss: 0.28686269072255893
batch 14000 loss: 0.30787017451071735
batch 15000 loss: 0.2951680972841132
LOSS train 0.2951680972841132 valid 0.3243757486343384
EPOCH 5:
batch 1000 loss: 0.28259117489257185
batch 2000 loss: 0.2710777925652437
batch 3000 loss: 0.2820857833838163
batch 4000 loss: 0.28647030706787335
batch 5000 loss: 0.2911521529511374
batch 6000 loss: 0.28676370393694744
batch 7000 loss: 0.3009271204024071
batch 8000 loss: 0.27809979932467105
batch 9000 loss: 0.2641369462262956
batch 10000 loss: 0.281442152015843
batch 11000 loss: 0.2740388377367817
batch 12000 loss: 0.2635333917985927
batch 13000 loss: 0.28544877841679406
batch 14000 loss: 0.28253569301004244
batch 15000 loss: 0.2722428723026424
LOSS train 0.2722428723026424 valid 0.3119978904724121
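The validation pass above reports only loss. Below is a minimal sketch of a reusable evaluation helper that also counts correct predictions. The evaluate function and its accuracy output are our extension, not part of the tutorial's code. GarmentClassifier has no dropout or batch normalization layers, so model.eval() does not change its outputs. Calling it still keeps the loop correct if such layers are added later. The check at the end feeds hand-made scores through nn.Identity so the expected numbers can be worked out by hand: 3 of 4 predictions are correct, and the mean loss is (3 × ln(1 + e⁻²) + ln(1 + e²)) / 4 ≈ 0.627.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader, loss_fn):
    """Return (mean per-batch loss, accuracy) over a validation loader."""
    model.eval()                     # evaluation behaviour for dropout / batch norm layers
    total_loss, correct, seen, batches = 0.0, 0, 0, 0
    with torch.no_grad():            # no autograd graph is built, which saves memory
        for inputs, labels in loader:
            outputs = model(inputs)
            total_loss += loss_fn(outputs, labels).item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
            batches += 1
    return total_loss / batches, correct / seen


# Hypothetical check: the "model" passes hand-made scores straight through
scores = torch.tensor([[2.0, 0.0], [0.0, 2.0], [2.0, 0.0], [0.0, 2.0]])
targets = torch.tensor([0, 1, 1, 1])          # the third prediction is wrong
loader = DataLoader(TensorDataset(scores, targets), batch_size=2)

val_loss, val_acc = evaluate(nn.Identity(), loader, nn.CrossEntropyLoss())
print(val_loss, val_acc)    # about 0.627 and 0.75
```

In the tutorial's setting, this could be called as evaluate(model, validation_loader, loss_fn) in place of the inline validation loop.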
To load a saved version of the model:
PATH = model_path  # the best checkpoint written by torch.save() in the loop above
saved_model = GarmentClassifier()
saved_model.load_state_dict(torch.load(PATH))
Once you’ve loaded the model, it’s ready for whatever you need it for - more training, inference, or analysis.
Note that if your model has constructor parameters that affect model structure, you’ll need to provide them and configure the model identically to the state in which it was saved.
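The round trip below is a minimal sketch of saving and restoring a state_dict. It uses an in-memory io.BytesIO buffer in place of a file path, so it leaves nothing on disk. TinyClassifier is a hypothetical model whose layer sizes depend on a constructor argument. Restoring into an instance built with the same argument reproduces the original outputs. A mismatched argument makes load_state_dict raise an error instead of loading silently.

```python
import io

import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical model whose structure depends on a constructor argument."""

    def __init__(self, hidden=32):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(8, hidden), nn.ReLU(), nn.Linear(hidden, 2))

    def forward(self, x):
        return self.net(x)


torch.manual_seed(0)
original = TinyClassifier(hidden=16)
buffer = io.BytesIO()                       # stands in for a file path
torch.save(original.state_dict(), buffer)

buffer.seek(0)
state = torch.load(buffer)

restored = TinyClassifier(hidden=16)        # same constructor arguments as when saved
restored.load_state_dict(state)
restored.eval()

x = torch.randn(3, 8)
same = torch.allclose(original(x), restored(x))
print(same)                                 # True

# A model built with a different structure cannot accept these weights
try:
    TinyClassifier(hidden=32).load_state_dict(state)
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = err
print(type(mismatch_error).__name__)        # RuntimeError
```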
Other Resources
Docs on the data utilities, including Dataset and DataLoader, at pytorch.org
A note on the use of pinned memory for GPU training
Documentation on the datasets available in TorchVision, TorchText, and TorchAudio
Documentation on the loss functions available in PyTorch
Documentation on the torch.optim package, which includes optimizers and related tools, such as learning rate scheduling
A detailed tutorial on saving and loading models
The Tutorials section of pytorch.org contains tutorials on a broad variety of training tasks, including classification in different domains, generative adversarial networks, reinforcement learning, and more