Training with PyTorch
Created On: Nov 30, 2021 | Last Updated: May 31, 2023 | Last Verified: Nov 05, 2024
Follow along with the video below or on YouTube.
Introduction
In past videos, we’ve discussed and demonstrated:
Building models with the neural network layers and functions of the torch.nn module
The mechanics of automated gradient computation, which is central to gradient-based model training
Using TensorBoard to visualize training progress and other activities
In this video, we’ll be adding some new tools to your inventory:
We’ll get familiar with the dataset and dataloader abstractions, and how they ease the process of feeding data to your model during a training loop
We’ll discuss specific loss functions and when to use them
We’ll look at PyTorch optimizers, which implement algorithms to adjust model weights based on the outcome of a loss function
Finally, we’ll pull all of these together and see a full PyTorch training loop in action.
Dataset and DataLoader
The Dataset and DataLoader classes encapsulate the process of pulling your data from storage and exposing it to your training loop in batches.
The Dataset is responsible for accessing and processing single instances of data.
The DataLoader pulls instances of data from the Dataset (either automatically or with a sampler that you define), collects them in batches, and returns them for consumption by your training loop. The DataLoader works with all kinds of datasets, regardless of the type of data they contain.
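A minimal sketch of that division of labor, assuming a hypothetical toy dataset called SquaresDataset (it is not part of PyTorch): the Dataset only implements __len__ and __getitem__ for a single instance, while the DataLoader does the batching. It also shows that len() of a DataLoader is the number of batches, ceil(N / batch_size), which is why 60,000 Fashion-MNIST images with batch_size=4 give 15,000 batches per epoch.

```python
import math

import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: item i is (tensor([i]), i * i)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # The DataLoader uses this to know how many instances exist
        return self.n

    def __getitem__(self, idx):
        # Access and process exactly one instance
        x = torch.tensor([float(idx)])
        y = idx * idx
        return x, y


dataset = SquaresDataset(10)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

# The default collate function stacks instances into batch tensors
batch_sizes = [xb.shape[0] for xb, yb in loader]
print(batch_sizes)  # [4, 4, 2]: the last batch holds the leftovers
print(len(loader), math.ceil(len(dataset) / 4))  # 3 3

# drop_last=True discards the incomplete final batch
loader_drop = DataLoader(dataset, batch_size=4, drop_last=True)
print(len(loader_drop))  # 2
```

Passing shuffle=True (as the training loader below does) changes only the order in which indices are requested from __getitem__, not the batch sizes.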
For this tutorial, we’ll be using the Fashion-MNIST dataset provided by TorchVision. We use torchvision.transforms.Normalize() to zero-center and normalize the pixel values of each image (ToTensor() yields values in [0, 1], which Normalize((0.5,), (0.5,)) maps to [-1, 1]), and we download both the training and validation data splits.
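To see what Normalize((0.5,), (0.5,)) does to a single channel, here is the same arithmetic written out with plain tensor operations. This is a sketch that mirrors the documented (x - mean) / std formula rather than calling torchvision; the sample pixel values are arbitrary. The inverse step is the "img / 2 + 0.5" unnormalization used by the matplotlib_imshow helper later in this tutorial.

```python
import torch

# ToTensor() produces pixel values in [0, 1]; Normalize(mean, std) then
# computes (x - mean) / std per channel. Here mean = std = 0.5.
mean, std = 0.5, 0.5
pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])  # arbitrary sample values

normalized = (pixels - mean) / std
print(normalized.tolist())  # [-1.0, -0.5, 0.0, 1.0]

# Undo the normalization for display: x * std + mean == x / 2 + 0.5
restored = normalized / 2 + 0.5
print(torch.equal(restored, pixels))  # True
```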
import torch
import torchvision
import torchvision.transforms as transforms
# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])
# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)
# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)
# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')
# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
Training set has 60000 instances
Validation set has 10000 instances
As always, let’s visualize the data as a sanity check:
import matplotlib.pyplot as plt
import numpy as np
# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
dataiter = iter(training_loader)
images, labels = next(dataiter)
# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print(' '.join(classes[labels[j]] for j in range(4)))

Trouser Bag Sneaker T-shirt/top
The Model
The model we’ll use in this example is a variant of LeNet-5; it should be familiar if you’ve watched the previous videos in this series.
import torch.nn as nn
import torch.nn.functional as F
# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)         # 1x28x28 -> 6x24x24
        self.pool = nn.MaxPool2d(2, 2)          # halves height and width
        self.conv2 = nn.Conv2d(6, 16, 5)        # 6x12x12 -> 16x8x8
        self.fc1 = nn.Linear(16 * 4 * 4, 120)   # 16x4x4 after the second pooling
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)            # one output score per class
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)              # flatten to (batch, 256)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
model = GarmentClassifier()
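The 16 * 4 * 4 input size of fc1 comes from tracing a 28x28 Fashion-MNIST image through the convolutional stack: with stride 1 and no padding, a 5x5 convolution produces (28 - 5) + 1 = 24 pixels per side, and each 2x2 max pool halves that. The sketch below checks the arithmetic by running a dummy batch through standalone layers with the same hyperparameters as GarmentClassifier (ReLU is omitted because it does not change shapes).

```python
import torch
import torch.nn as nn

# Same hyperparameters as GarmentClassifier's convolutional layers
conv1 = nn.Conv2d(1, 6, 5)
conv2 = nn.Conv2d(6, 16, 5)
pool = nn.MaxPool2d(2, 2)

x = torch.zeros(4, 1, 28, 28)  # a dummy batch of 4 single-channel 28x28 images
shapes = []
for name, layer in [("conv1", conv1), ("pool", pool), ("conv2", conv2), ("pool", pool)]:
    x = layer(x)
    shapes.append((name, tuple(x.shape)))
    print(name, tuple(x.shape))

# Flatten everything except the batch dimension, as x.view(-1, 16 * 4 * 4) does
flat = x.view(-1, 16 * 4 * 4)
print(tuple(flat.shape))  # (4, 256)
```

If you change the input resolution or kernel sizes, the first Linear layer's input size has to be recomputed the same way.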
Loss Function
For this example, we’ll be using a cross-entropy loss. For demonstration purposes, we’ll create batches of dummy output and label values, run them through the loss function, and examine the result.
loss_fn = torch.nn.CrossEntropyLoss()
# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's raw, unnormalized scores (logits) for each of the 10 classes
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])
print(dummy_outputs)
print(dummy_labels)
loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))
tensor([[0.1726, 0.5519, 0.5408, 0.3087, 0.5926, 0.1322, 0.3521, 0.1185, 0.8817,
0.7792],
[0.3095, 0.1187, 0.0439, 0.9088, 0.4309, 0.8192, 0.1736, 0.2536, 0.1780,
0.2569],
[0.5299, 0.8088, 0.2074, 0.7737, 0.7397, 0.5035, 0.0671, 0.1186, 0.0174,
0.9317],
[0.8866, 0.2105, 0.7979, 0.4938, 0.5239, 0.8345, 0.7334, 0.6203, 0.1057,
0.7327]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.1142544746398926
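CrossEntropyLoss expects raw scores (logits), not probabilities: it applies a log-softmax to each row, takes the negative log-probability of the correct class, and by default averages over the batch. The sketch below recomputes the built-in result by hand, using a fixed seed, so its numbers will differ from the output above. It also explains why a batch of random scores lands near 2.3: if a model assigns equal scores to all 10 classes, each class gets probability 1/10 and the loss is ln(10), about 2.303.

```python
import math

import torch

torch.manual_seed(0)
logits = torch.rand(4, 10)           # raw, unnormalized scores per class
labels = torch.tensor([1, 5, 3, 7])  # index of the correct class per row

builtin = torch.nn.CrossEntropyLoss()(logits, labels)

# By hand: log-softmax each row, pick the correct class's log-probability,
# negate, and average over the batch (the default reduction='mean')
log_probs = logits - torch.logsumexp(logits, dim=1, keepdim=True)
manual = -log_probs[torch.arange(4), labels].mean()
print(builtin.item(), manual.item())

# Equal scores for every class give a loss of ln(10)
uniform = torch.nn.CrossEntropyLoss()(torch.zeros(4, 10), labels)
print(uniform.item(), math.log(10))
```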
Optimizer
For this example, we’ll be using simple stochastic gradient descent with momentum.
It can be instructive to try some variations on this optimization scheme:
Learning rate determines the size of the steps the optimizer takes. What does a different learning rate do to your training results, in terms of accuracy and convergence time?
Momentum accumulates a running combination of past gradients, nudging the optimizer to keep moving in directions where the gradient has pointed consistently over multiple steps. What does changing this value do to your results?
Try some different optimization algorithms, such as averaged SGD, Adagrad, or Adam. How do your results differ?
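To make the momentum idea concrete, here is a sketch that reproduces the update rule documented for torch.optim.SGD with momentum and its default settings (dampening=0, nesterov=False): velocity = momentum * velocity + grad, then param = param - lr * velocity. The three-parameter quadratic objective and the target values are illustrative assumptions; the point is that the hand-written update tracks the built-in optimizer step for step.

```python
import torch

torch.manual_seed(0)
target = torch.tensor([1.0, -2.0, 0.5])  # hypothetical minimum of the toy objective
lr, momentum = 0.1, 0.9

# Parameter updated by torch.optim.SGD
w = torch.randn(3, requires_grad=True)
sgd = torch.optim.SGD([w], lr=lr, momentum=momentum)

# Independent copy updated by hand with the same rule
w_manual = w.detach().clone()
velocity = torch.zeros_like(w_manual)

for step in range(5):
    sgd.zero_grad()
    objective = ((w - target) ** 2).sum()
    objective.backward()
    sgd.step()

    # Gradient of sum((w - target)^2) is 2 * (w - target)
    grad = 2 * (w_manual - target)
    velocity = momentum * velocity + grad  # accumulate past gradients
    w_manual = w_manual - lr * velocity    # step along the accumulated direction

print(w.detach())
print(w_manual)
```

With momentum=0 the velocity is just the current gradient, and the loop reduces to plain gradient descent.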
# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
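One way to try the alternative algorithms listed above is to keep the training code fixed and swap only the optimizer constructor. The sketch below does that on a hypothetical three-parameter quadratic rather than on GarmentClassifier, so it runs almost instantly. The learning rates are illustrative assumptions, not tuned recommendations, and the printed losses say nothing about how these optimizers would rank on Fashion-MNIST.

```python
import torch

target = torch.tensor([1.0, -2.0, 0.5])  # hypothetical minimum of the toy objective


def run(make_optimizer, steps=100):
    """Minimize sum((w - target)^2) from a fixed start; return (initial, final) loss."""
    torch.manual_seed(0)  # same starting point for every optimizer
    w = torch.randn(3, requires_grad=True)
    opt = make_optimizer([w])
    initial = ((w - target) ** 2).sum().item()
    for _ in range(steps):
        opt.zero_grad()
        objective = ((w - target) ** 2).sum()
        objective.backward()
        opt.step()
    final = ((w - target) ** 2).sum().item()
    return initial, final


# Learning rates here are illustrative guesses, not tuned values
candidates = {
    "SGD+momentum": lambda p: torch.optim.SGD(p, lr=0.01, momentum=0.9),
    "ASGD": lambda p: torch.optim.ASGD(p, lr=0.01),
    "Adagrad": lambda p: torch.optim.Adagrad(p, lr=0.5),
    "Adam": lambda p: torch.optim.Adam(p, lr=0.1),
}
results = {name: run(make) for name, make in candidates.items()}
for name, (initial, final) in results.items():
    print(f"{name:>13}: loss {initial:.4f} -> {final:.6f}")
```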
The Training Loop
Below, we have a function that performs one training epoch. It enumerates data from the DataLoader, and on each pass of the loop does the following:
Gets a batch of training data from the DataLoader
Zeros the optimizer’s gradients
Performs an inference: gets predictions from the model for an input batch
Calculates the loss for that set of predictions vs. the labels on the dataset
Calculates the backward gradients over the learning weights
Tells the optimizer to perform one learning step: adjusts the model’s learning weights based on the observed gradients for this batch, according to the optimization algorithm we chose
It reports the loss every 1000 batches.
Finally, it returns the average per-batch loss for the last 1000 batches, for comparison with a validation run.
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.
    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data
        # Zero your gradients for every batch!
        optimizer.zero_grad()
        # Make predictions for this batch
        outputs = model(inputs)
        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()
        # Adjust learning weights
        optimizer.step()
        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
    return last_loss
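The same five steps can be exercised without downloading Fashion-MNIST. The sketch below is a self-contained miniature of train_one_epoch, assuming hypothetical synthetic data (three classes derived from the signs of two random features) and a small toy_model; the toy_ prefixes keep it from overwriting the tutorial's model, loss_fn, and optimizer if you run it in the same session. It drops the TensorBoard reporting and returns the mean loss over the whole epoch rather than over the last 1000 batches.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical synthetic data: class = number of positive coordinates (0, 1 or 2)
features = torch.randn(600, 2)
targets = (features[:, 0] > 0).long() + (features[:, 1] > 0).long()
toy_loader = DataLoader(TensorDataset(features, targets), batch_size=16, shuffle=True)

toy_model = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 3))
toy_loss_fn = nn.CrossEntropyLoss()
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.05, momentum=0.9)


def toy_train_one_epoch():
    toy_model.train()
    total, batches = 0.0, 0
    for inputs, labels in toy_loader:
        toy_optimizer.zero_grad()            # zero the gradients
        outputs = toy_model(inputs)          # forward pass: predictions
        loss = toy_loss_fn(outputs, labels)  # loss vs. the labels
        loss.backward()                      # backward pass: gradients
        toy_optimizer.step()                 # one optimizer step
        total += loss.item()
        batches += 1
    return total / batches                   # mean per-batch loss this epoch


epoch_losses = [toy_train_one_epoch() for _ in range(10)]
print([round(v, 3) for v in epoch_losses])
```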
Per-Epoch Activity
There are a couple of things we’ll want to do once per epoch:
Perform validation by checking our relative loss on a set of data that was not used for training, and report this
Save a copy of the model
Here, we’ll do our reporting in TensorBoard. This will require going to the command line to start TensorBoard (for example, tensorboard --logdir=runs, since the writer below logs under runs/) and opening it in another browser tab.
# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0
EPOCHS = 5
best_vloss = 1_000_000.
for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))
    # Put the model in training mode (this affects layers such as dropout and
    # batch normalization), then do a pass over the training data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)
    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()
    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss.item()
    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                       { 'Training' : avg_loss, 'Validation' : avg_vloss },
                       epoch_number + 1)
    writer.flush()
    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)
    epoch_number += 1
EPOCH 1:
batch 1000 loss: 2.0368239536881445
batch 2000 loss: 0.9447713082060217
batch 3000 loss: 0.7512045184113085
batch 4000 loss: 0.672041746229399
batch 5000 loss: 0.6404219433853868
batch 6000 loss: 0.5876974876336754
batch 7000 loss: 0.5642247462056111
batch 8000 loss: 0.5238452655504807
batch 9000 loss: 0.5183932892873417
batch 10000 loss: 0.48882931664795615
batch 11000 loss: 0.4777649353200104
batch 12000 loss: 0.4593426596976351
batch 13000 loss: 0.47381507632660214
batch 14000 loss: 0.4397967921900563
batch 15000 loss: 0.4349743271954358
LOSS train 0.4349743271954358 valid 0.43963736295700073
EPOCH 2:
batch 1000 loss: 0.41540096975956115
batch 2000 loss: 0.43442853027596723
batch 3000 loss: 0.4131184138438548
batch 4000 loss: 0.4021981662045582
batch 5000 loss: 0.3914233583531459
batch 6000 loss: 0.387463094651117
batch 7000 loss: 0.38125657067866997
batch 8000 loss: 0.37205292124988043
batch 9000 loss: 0.36828389775776305
batch 10000 loss: 0.38511645725462584
batch 11000 loss: 0.3565248569853429
batch 12000 loss: 0.3866269055126468
batch 13000 loss: 0.3703528124356526
batch 14000 loss: 0.3600774475395447
batch 15000 loss: 0.34374926845679876
LOSS train 0.34374926845679876 valid 0.3787817656993866
EPOCH 3:
batch 1000 loss: 0.342162811011789
batch 2000 loss: 0.34274276750002175
batch 3000 loss: 0.3274373286683112
batch 4000 loss: 0.3456361875798939
batch 5000 loss: 0.34847399987274547
batch 6000 loss: 0.33520299241051543
batch 7000 loss: 0.3367152308954974
batch 8000 loss: 0.35517711870645874
batch 9000 loss: 0.3547138352394104
batch 10000 loss: 0.3318984882144505
batch 11000 loss: 0.32318603456309936
batch 12000 loss: 0.3174168571334158
batch 13000 loss: 0.32815200396331057
batch 14000 loss: 0.33950525665492753
batch 15000 loss: 0.30512224544102357
LOSS train 0.30512224544102357 valid 0.34092164039611816
EPOCH 4:
batch 1000 loss: 0.2952195990138716
batch 2000 loss: 0.29075468799601367
batch 3000 loss: 0.3111583511932258
batch 4000 loss: 0.32493375527830176
batch 5000 loss: 0.30718161875175426
batch 6000 loss: 0.3095859206026362
batch 7000 loss: 0.30101370186725396
batch 8000 loss: 0.3157856839992419
batch 9000 loss: 0.30477857304843203
batch 10000 loss: 0.3091353414664045
batch 11000 loss: 0.30009497354355696
batch 12000 loss: 0.2978665009691904
batch 13000 loss: 0.3133462607739348
batch 14000 loss: 0.31493555417069
batch 15000 loss: 0.30068287946861527
LOSS train 0.30068287946861527 valid 0.33888134360313416
EPOCH 5:
batch 1000 loss: 0.2790463499578473
batch 2000 loss: 0.28388226319586646
batch 3000 loss: 0.28616013707439925
batch 4000 loss: 0.2801440930921235
batch 5000 loss: 0.29005118502126426
batch 6000 loss: 0.29498661262870884
batch 7000 loss: 0.2799844486888178
batch 8000 loss: 0.2823787230397502
batch 9000 loss: 0.2747978572060092
batch 10000 loss: 0.30346416945975213
batch 11000 loss: 0.29198151576622333
batch 12000 loss: 0.2743332377418847
batch 13000 loss: 0.3042091254772495
batch 14000 loss: 0.28699204632262443
batch 15000 loss: 0.27387973593545756
LOSS train 0.27387973593545756 valid 0.32991060614585876
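The validation pass above reports only loss. If you also want accuracy (one of the measures suggested in the optimizer experiments), the same eval() / no_grad() pattern extends naturally. The sketch below wraps it in a hypothetical evaluate() helper and runs it on random toy data with an untrained toy_net, so the numbers it prints are meaningless; the point is the structure. Like the loop above, it averages per-batch losses, which slightly overweights a smaller final batch.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(net, loader, loss_fn):
    """Return (mean per-batch loss, accuracy) without tracking gradients."""
    net.eval()  # dropout off, batch norm uses running statistics
    total_loss, correct, seen, batches = 0.0, 0, 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            outputs = net(inputs)
            total_loss += loss_fn(outputs, targets).item()  # plain Python float
            correct += (outputs.argmax(dim=1) == targets).sum().item()
            seen += targets.numel()
            batches += 1
    return total_loss / batches, correct / seen


# Usage with hypothetical random data and an untrained model
torch.manual_seed(0)
val_data = TensorDataset(torch.randn(100, 2), torch.randint(0, 3, (100,)))
val_loader = DataLoader(val_data, batch_size=32)
toy_net = nn.Sequential(nn.Linear(2, 3), nn.Dropout(0.5))
vloss, vacc = evaluate(toy_net, val_loader, nn.CrossEntropyLoss())
print(vloss, vacc)
print(toy_net.training)  # False: evaluate() leaves the model in eval mode
```

Because dropout is disabled in eval mode, calling evaluate() twice on the same model and data gives identical results.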
To load a saved version of the model:
saved_model = GarmentClassifier()
# PATH is the file written by torch.save() above, e.g. model_path from the training loop
saved_model.load_state_dict(torch.load(PATH))
Once you’ve loaded the model, it’s ready for whatever you need it for: more training, inference, or analysis.
Note that if your model has constructor parameters that affect model structure, you’ll need to provide them and configure the model identically to the state in which it was saved.
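That caveat about constructor parameters can be seen directly. In this sketch, a hypothetical make_net(hidden) builder stands in for a model whose constructor argument changes layer sizes: the saved state dict loads cleanly into an identically configured instance but raises a RuntimeError when the sizes differ. It writes to a temporary directory and passes weights_only=True to torch.load, which is appropriate when the file holds only a state dict.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_net(hidden):
    # Hypothetical builder: the constructor argument changes the architecture
    return nn.Sequential(nn.Linear(4, hidden), nn.ReLU(), nn.Linear(hidden, 2))


torch.manual_seed(0)
trained = make_net(hidden=8)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_model.pt")
    torch.save(trained.state_dict(), path)

    restored = make_net(hidden=8)  # same configuration as when saved
    restored.load_state_dict(torch.load(path, weights_only=True))

    mismatched = make_net(hidden=16)  # different configuration
    try:
        mismatched.load_state_dict(torch.load(path, weights_only=True))
        mismatch_error = None
    except RuntimeError as err:  # raised for size mismatches
        mismatch_error = err

restored.eval()  # switch to eval mode before inference
x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.equal(trained(x), restored(x))
print(same)                        # True
print(mismatch_error is not None)  # True
```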
Other Resources
Docs on the data utilities, including Dataset and DataLoader, at pytorch.org
A note on the use of pinned memory for GPU training
Documentation on the datasets available in TorchVision, TorchText, and TorchAudio
Documentation on the loss functions available in PyTorch
Documentation on the torch.optim package, which includes optimizers and related tools, such as learning rate scheduling
A detailed tutorial on saving and loading models
The Tutorials section of pytorch.org contains tutorials on a broad variety of training tasks, including classification in different domains, generative adversarial networks, reinforcement learning, and more
Total running time of the script: (3 minutes 2.614 seconds)