Training with PyTorch
Created On: Nov 30, 2021 | Last Updated: May 31, 2023 | Last Verified: Nov 05, 2024
Introduction
In past videos, we’ve discussed and demonstrated:
Building models with the neural network layers and functions of the torch.nn module
The mechanics of automated gradient computation, which is central to gradient-based model training
Using TensorBoard to visualize training progress and other activities
In this video, we’ll be adding some new tools to your inventory:
We’ll get familiar with the Dataset and DataLoader abstractions, and how they ease the process of feeding data to your model during a training loop
We’ll discuss specific loss functions and when to use them
We’ll look at PyTorch optimizers, which implement algorithms to adjust model weights based on the outcome of a loss function
Finally, we’ll pull all of these together and see a full PyTorch training loop in action.
Dataset and DataLoader
The Dataset and DataLoader classes encapsulate the process of
pulling your data from storage and exposing it to your training loop in
batches.
The Dataset is responsible for accessing and processing single
instances of data.
The DataLoader pulls instances of data from the Dataset (either
automatically or with a sampler that you define), collects them in
batches, and returns them for consumption by your training loop. The
DataLoader works with all kinds of datasets, regardless of the type
of data they contain.
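The following pure-Python sketch shows how the work is split between the two roles. `SquaresDataset` and `simple_loader` are hypothetical names created for illustration. The real `torch.utils.data.DataLoader` also handles samplers, collation into tensors, and worker processes. The dataset class follows the map-style protocol (`__len__` and `__getitem__`), so an object like it should also work with the real `DataLoader`.

```python
import random


class SquaresDataset:
    """Hypothetical map-style dataset: item i is the pair (i, i * i)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        # Access (and "process") a single instance of data
        return idx, idx * idx


def simple_loader(dataset, batch_size, shuffle=False, seed=None):
    """Yield (inputs, labels) batches, like a stripped-down DataLoader."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch = [dataset[i] for i in indices[start:start + batch_size]]
        # Collate: turn a list of (input, label) pairs into (inputs, labels)
        inputs = [x for x, _ in batch]
        labels = [y for _, y in batch]
        yield inputs, labels


for inputs, labels in simple_loader(SquaresDataset(10), batch_size=4):
    print(inputs, labels)
```

The last batch holds only two items because 10 is not a multiple of 4. The real DataLoader behaves the same way unless you pass `drop_last=True`.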
For this tutorial, we’ll be using the Fashion-MNIST dataset provided by
TorchVision. We use torchvision.transforms.Normalize() with a mean and
standard deviation of 0.5 to rescale pixel values from the [0, 1] range
produced by ToTensor() to [-1, 1], centered on zero, and we download both
the training and validation data splits.
import torch
import torchvision
import torchvision.transforms as transforms
# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])
# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)
# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)
# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')
# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
Training set has 60000 instances
Validation set has 10000 instances
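With `batch_size=4`, each epoch splits into a predictable number of batches. The sketch below uses a hypothetical helper, `num_batches`, to reproduce how `len()` of a `DataLoader` is computed for a map-style dataset with the default sampler. It uses ceiling division, or floor division when `drop_last=True`.

```python
def num_batches(n_samples, batch_size, drop_last=False):
    """Batches per epoch for a map-style dataset (hypothetical helper)."""
    if drop_last:
        return n_samples // batch_size
    # Ceiling division: a final partial batch still counts
    return (n_samples + batch_size - 1) // batch_size


train_batches = num_batches(60000, 4)      # 15000
valid_batches = num_batches(10000, 4)      # 2500
reports_per_epoch = train_batches // 1000  # the training loop prints every 1000 batches
print(train_batches, valid_batches, reports_per_epoch)  # 15000 2500 15
```

This arithmetic explains why each epoch in the training log below ends at batch 15000 after 15 loss reports.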
As always, let’s visualize the data as a sanity check:
import matplotlib.pyplot as plt
import numpy as np
# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        # make_grid() copies single-channel images into 3 channels; average them back to 1
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
dataiter = iter(training_loader)
images, labels = next(dataiter)
# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print(' '.join(classes[labels[j]] for j in range(4)))

Coat Sneaker Dress Dress
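The `img / 2 + 0.5` step in `matplotlib_imshow` reverses `Normalize((0.5,), (0.5,))`, which computes `(x - mean) / std` for each channel. This pure-Python sketch applies the round trip to a few pixel values; the function names are hypothetical. It shows that `ToTensor()` output in [0, 1] maps into [-1, 1] and then returns unchanged.

```python
MEAN, STD = 0.5, 0.5  # the values passed to transforms.Normalize above


def normalize(x):
    # What Normalize does to each pixel of a single-channel image
    return (x - MEAN) / STD


def unnormalize(y):
    # Inverse mapping; with MEAN = STD = 0.5 this equals y / 2 + 0.5
    return y * STD + MEAN


pixels = [0.0, 0.25, 0.5, 1.0]  # ToTensor() scales 8-bit pixels into [0, 1]
normed = [normalize(p) for p in pixels]
print(normed)                            # [-1.0, -0.5, 0.0, 1.0]
print([unnormalize(v) for v in normed])  # [0.0, 0.25, 0.5, 1.0]
```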
The Model
The model we’ll use in this example is a variant of LeNet-5 - it should be familiar if you’ve watched the previous videos in this series.
import torch.nn as nn
import torch.nn.functional as F
# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
model = GarmentClassifier()
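The `16 * 4 * 4` input size of `fc1` comes from the 28 x 28 Fashion-MNIST images. The sketch below applies the standard output-size formula for convolution and pooling layers, floor((H + 2P - K) / S) + 1 (assuming dilation 1), to each layer in turn. `out_size` is a hypothetical helper, not a PyTorch function.

```python
def out_size(h, kernel, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer (dilation 1)."""
    return (h + 2 * padding - kernel) // stride + 1


h = 28                               # Fashion-MNIST images are 28 x 28
h = out_size(h, kernel=5)            # conv1, 5x5 kernel       -> 24
h = out_size(h, kernel=2, stride=2)  # max pool, 2x2, stride 2 -> 12
h = out_size(h, kernel=5)            # conv2, 5x5 kernel       -> 8
h = out_size(h, kernel=2, stride=2)  # max pool, 2x2, stride 2 -> 4
flat_features = 16 * h * h           # conv2 has 16 output channels
print(h, flat_features)              # 4 256
```

If you change the input resolution, the flattened size changes with it. You would then need to update both `fc1` and the `view()` call.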
Loss Function
For this example, we’ll be using a cross-entropy loss. For demonstration purposes, we’ll create batches of dummy output and label values, run them through the loss function, and examine the result.
loss_fn = torch.nn.CrossEntropyLoss()
# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's raw, unnormalized scores (logits) for each of the 10 classes for a given input
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])
print(dummy_outputs)
print(dummy_labels)
loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))
tensor([[0.4271, 0.6384, 0.3452, 0.7453, 0.4921, 0.3817, 0.5209, 0.0885, 0.9970,
0.9892],
[0.8800, 0.9021, 0.6677, 0.9624, 0.0349, 0.5844, 0.1598, 0.6857, 0.1892,
0.1251],
[0.2240, 0.7840, 0.8938, 0.4519, 0.1508, 0.3038, 0.8100, 0.1413, 0.4139,
0.4835],
[0.9473, 0.6924, 0.9306, 0.2292, 0.6798, 0.8082, 0.4483, 0.2938, 0.5292,
0.7377]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.3937389850616455
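`CrossEntropyLoss` expects raw, unnormalized scores (logits). For each sample, it computes the log-softmax over the classes, selects the entry for the correct label, and negates it. By default, it then averages the results over the batch. The pure-Python sketch below recomputes the loss from the printed `dummy_outputs`, using a hypothetical helper, `cross_entropy`. Because the printed values are rounded to four decimals, the result may differ slightly from the exact value.

```python
import math


def cross_entropy(logits_batch, labels):
    """Mean over the batch of -log(softmax(logits)[label])."""
    total = 0.0
    for logits, label in zip(logits_batch, labels):
        # log-sum-exp, shifted by the max for numerical stability
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_sum_exp - logits[label]
    return total / len(labels)


# Values copied from the printed dummy_outputs / dummy_labels above
printed_outputs = [
    [0.4271, 0.6384, 0.3452, 0.7453, 0.4921, 0.3817, 0.5209, 0.0885, 0.9970, 0.9892],
    [0.8800, 0.9021, 0.6677, 0.9624, 0.0349, 0.5844, 0.1598, 0.6857, 0.1892, 0.1251],
    [0.2240, 0.7840, 0.8938, 0.4519, 0.1508, 0.3038, 0.8100, 0.1413, 0.4139, 0.4835],
    [0.9473, 0.6924, 0.9306, 0.2292, 0.6798, 0.8082, 0.4483, 0.2938, 0.5292, 0.7377],
]
printed_labels = [1, 5, 3, 7]

print(cross_entropy(printed_outputs, printed_labels))  # about 2.394
# With identical scores for all 10 classes, the loss is exactly log(10)
print(cross_entropy([[0.0] * 10], [3]), math.log(10))
```

Random scores in [0, 1) are nearly uniform across classes. This is why the loss lands near log(10) ≈ 2.303, the loss of a model that spreads its guesses evenly over 10 classes.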
Optimizer
For this example, we’ll be using simple stochastic gradient descent with momentum.
It can be instructive to try some variations on this optimization scheme:
Learning rate determines the size of the steps the optimizer takes. What does a different learning rate do to your training results, in terms of accuracy and convergence time?
Momentum accumulates past gradients, so the optimizer keeps moving in directions where the gradient has pointed consistently over multiple steps. What does changing this value do to your results?
Try some different optimization algorithms, such as averaged SGD, Adagrad, or Adam. How do your results differ?
# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
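The following pure-Python sketch illustrates what `lr` and `momentum` control, using SGD with momentum on the one-parameter function f(w) = w². It assumes the update rule from the `torch.optim.SGD` documentation with the defaults `dampening=0` and `nesterov=False`: the velocity is `v = momentum * v + grad`, then the update is `w = w - lr * v`. The function `sgd_momentum` is hypothetical and serves only as an illustration.

```python
def sgd_momentum(w, lr, momentum, steps):
    """Minimize f(w) = w**2 with SGD + momentum; return the trajectory of w."""
    v = 0.0
    history = [w]
    for _ in range(steps):
        grad = 2 * w             # df/dw for f(w) = w**2
        v = momentum * v + grad  # velocity: running accumulation of gradients
        w = w - lr * v           # parameter update
        history.append(w)
    return history


plain = sgd_momentum(5.0, lr=0.01, momentum=0.0, steps=100)
with_momentum = sgd_momentum(5.0, lr=0.01, momentum=0.9, steps=100)
print(abs(plain[-1]), abs(with_momentum[-1]))
```

Both runs use the same learning rate. With momentum, the parameter overshoots zero and oscillates (some entries of `with_momentum` are negative). Even so, it ends much closer to the minimum after 100 steps than plain SGD, which shrinks `w` by only 2% per step.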
The Training Loop
Below, we have a function that performs one training epoch. It enumerates data from the DataLoader, and on each pass of the loop does the following:
Gets a batch of training data from the DataLoader
Zeros the optimizer’s gradients
Performs an inference - that is, gets predictions from the model for an input batch
Calculates the loss for that set of predictions vs. the labels on the dataset
Calculates the backward gradients over the learning weights
Tells the optimizer to perform one learning step - that is, adjust the model’s learning weights based on the observed gradients for this batch, according to the optimization algorithm we chose
It also reports the average per-batch loss every 1000 batches.
Finally, it returns the average per-batch loss for the last 1000 batches, for comparison with a validation run.
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.
    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data
        # Zero your gradients for every batch!
        optimizer.zero_grad()
        # Make predictions for this batch
        outputs = model(inputs)
        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()
        # Adjust learning weights
        optimizer.step()
        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
    return last_loss
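Any gradient-based training loop follows these same six steps. The framework-free sketch below trains a one-parameter linear regression, y ≈ w * x, with mini-batch SGD. Every name in it is hypothetical, and hand-written code stands in for `model`, `loss_fn`, `loss.backward()`, and `optimizer.step()`. Like `train_one_epoch`, it reports the average loss at fixed intervals, here every 5 batches instead of every 1000.

```python
import random


def train_one_epoch_toy(w, data, lr=0.5, batch_size=4, report_every=5):
    """One epoch of mini-batch SGD for y = w * x with a mean-squared-error loss."""
    running_loss, last_loss = 0.0, 0.0
    for i, start in enumerate(range(0, len(data), batch_size)):
        batch = data[start:start + batch_size]        # 1. get a batch of data
        grad = 0.0                                    # 2. zero the gradient
        preds = [w * x for x, _ in batch]             # 3. forward pass (inference)
        errors = [p - y for p, (_, y) in zip(preds, batch)]
        loss = sum(e * e for e in errors) / len(batch)  # 4. loss vs. labels
        # 5. backward pass: accumulate d(loss)/dw, as PyTorch accumulates into .grad
        grad += sum(2 * e * x for e, (x, _) in zip(errors, batch)) / len(batch)
        w -= lr * grad                                # 6. optimizer step
        running_loss += loss
        if i % report_every == report_every - 1:
            last_loss = running_loss / report_every   # average loss per batch
            print(f"  batch {i + 1} loss: {last_loss:.6f}")
            running_loss = 0.0
    return w, last_loss


rng = random.Random(0)
xs = [rng.uniform(-1, 1) for _ in range(40)]
data = [(x, 3.0 * x + rng.gauss(0, 0.01)) for x in xs]  # true slope is 3

w = 0.0
for epoch in range(3):
    rng.shuffle(data)  # like shuffle=True: a fresh order every epoch
    print(f"EPOCH {epoch + 1}:")
    w, avg_loss = train_one_epoch_toy(w, data)
print(f"learned w = {w:.3f}")
```

Step 2 matters because the gradient is accumulated with `+=`. Without the reset, each batch's gradient would be added on top of the previous ones, which is exactly what `optimizer.zero_grad()` prevents in the real loop.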
Per-Epoch Activity
There are a couple of things we’ll want to do once per epoch:
Perform validation by checking our relative loss on a set of data that was not used for training, and report this
Save a copy of the model
Here, we’ll do our reporting in TensorBoard. This requires starting TensorBoard from the command line (for example, tensorboard --logdir=runs, since the SummaryWriter below writes under runs/) and opening it in another browser tab.
# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0
EPOCHS = 5
best_vloss = 1_000_000.
for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))
    # Put the model in training mode (this affects layers such as dropout and
    # batch normalization), then do a pass over the training data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)
    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()
    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss
    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()
    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)
    epoch_number += 1
EPOCH 1:
batch 1000 loss: 1.784976828455925
batch 2000 loss: 0.8263223152449355
batch 3000 loss: 0.6916487076813355
batch 4000 loss: 0.6096141297444702
batch 5000 loss: 0.569686681009829
batch 6000 loss: 0.5379622910583858
batch 7000 loss: 0.5362045865270775
batch 8000 loss: 0.5039491483343299
batch 9000 loss: 0.4627635038787266
batch 10000 loss: 0.46716841775504875
batch 11000 loss: 0.42577900271408725
batch 12000 loss: 0.4293474618118489
batch 13000 loss: 0.41360994360793846
batch 14000 loss: 0.4013340015343856
batch 15000 loss: 0.394765554038866
LOSS train 0.394765554038866 valid 0.422852098941803
EPOCH 2:
batch 1000 loss: 0.38278596836034556
batch 2000 loss: 0.38892397124494893
batch 3000 loss: 0.3666358025862137
batch 4000 loss: 0.3774275084128603
batch 5000 loss: 0.3842072506857658
batch 6000 loss: 0.3682632900656899
batch 7000 loss: 0.3571386761087924
batch 8000 loss: 0.3769625986551691
batch 9000 loss: 0.34585123410601226
batch 10000 loss: 0.349092586322964
batch 11000 loss: 0.332665092885989
batch 12000 loss: 0.33922079248694353
batch 13000 loss: 0.3508776856112818
batch 14000 loss: 0.3217826355858124
batch 15000 loss: 0.3316082783928578
LOSS train 0.3316082783928578 valid 0.3533305525779724
EPOCH 3:
batch 1000 loss: 0.3101922840481275
batch 2000 loss: 0.32035058495154956
batch 3000 loss: 0.32936868814887565
batch 4000 loss: 0.3129736438259133
batch 5000 loss: 0.3316463240184603
batch 6000 loss: 0.33540282097583984
batch 7000 loss: 0.30903379454429525
batch 8000 loss: 0.3482079836736375
batch 9000 loss: 0.29847264505916976
batch 10000 loss: 0.3172011704890465
batch 11000 loss: 0.3049195747375343
batch 12000 loss: 0.297806606514172
batch 13000 loss: 0.30823753245500846
batch 14000 loss: 0.32453707323713754
batch 15000 loss: 0.30613208849704826
LOSS train 0.30613208849704826 valid 0.3252071142196655
EPOCH 4:
batch 1000 loss: 0.3054257575027441
batch 2000 loss: 0.298001502304327
batch 3000 loss: 0.3014335927081065
batch 4000 loss: 0.29491765104489603
batch 5000 loss: 0.2821206657881139
batch 6000 loss: 0.2905257716884571
batch 7000 loss: 0.2931791045982227
batch 8000 loss: 0.308601964775251
batch 9000 loss: 0.304633750305773
batch 10000 loss: 0.2945715122923066
batch 11000 loss: 0.2919909181561852
batch 12000 loss: 0.3059344952407264
batch 13000 loss: 0.28693745871676946
batch 14000 loss: 0.26961352903507757
batch 15000 loss: 0.2736693370296089
LOSS train 0.2736693370296089 valid 0.3376033306121826
EPOCH 5:
batch 1000 loss: 0.26626920011397304
batch 2000 loss: 0.2704936440042184
batch 3000 loss: 0.2808960044778214
batch 4000 loss: 0.27183746823063937
batch 5000 loss: 0.27171938311338456
batch 6000 loss: 0.27300490664795507
batch 7000 loss: 0.2686056153687168
batch 8000 loss: 0.26906107153145875
batch 9000 loss: 0.28043227519997527
batch 10000 loss: 0.2656501787790912
batch 11000 loss: 0.27493628575908213
batch 12000 loss: 0.28932940587129635
batch 13000 loss: 0.27727131318160353
batch 14000 loss: 0.28784957085908175
batch 15000 loss: 0.28634940097837536
LOSS train 0.28634940097837536 valid 0.3038865923881531
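Because of the `if avg_vloss < best_vloss` check, a checkpoint is written only when the validation loss improves. The pure-Python sketch below replays the five validation losses from the log above through the same rule to show which epochs produced a saved file. The list is copied from that output and rounded to four decimals.

```python
# Validation losses per epoch, copied (rounded) from the training log above
valid_losses = [0.4229, 0.3533, 0.3252, 0.3376, 0.3039]

best_vloss = 1_000_000.
saved_epochs = []
for epoch_number, avg_vloss in enumerate(valid_losses):
    if avg_vloss < best_vloss:  # same rule as the training loop
        best_vloss = avg_vloss
        saved_epochs.append(epoch_number)  # model_<timestamp>_<epoch_number> is written

print(saved_epochs, best_vloss)  # [0, 1, 2, 4] 0.3039
```

No checkpoint is saved for EPOCH 4 (epoch_number 3) because its validation loss rose even though its training loss fell. When the loop finishes, `model_path` names the EPOCH 5 checkpoint, which has the lowest validation loss.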
To load a saved version of the model:
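# PATH is a file written by torch.save() above; model_path holds the best checkpoint
PATH = model_path
saved_model = GarmentClassifier()
saved_model.load_state_dict(torch.load(PATH))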
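Once you’ve loaded the model, it’s ready for whatever you need it for: more training, inference, or analysis. Before running inference, call saved_model.eval() so that any dropout or batch normalization layers behave as they did during validation.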
Note that if your model has constructor parameters that affect model structure, you’ll need to provide them and configure the model identically to the state in which it was saved.
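`load_state_dict()` matches parameters by name. In its default strict mode, it raises an error for missing or unexpected keys, and mismatched tensor shapes also cause an error. The pure-Python sketch below lists the parameter shapes `GarmentClassifier` produces. Conv2d weights have the shape out_channels x in_channels x kH x kW, and Linear weights have the shape out_features x in_features. The sketch uses a hypothetical `conv1_out` constructor argument to show how a structural change breaks compatibility. `garment_shapes`, `num_params`, and `incompatible_keys` are illustrative helpers, not PyTorch APIs.

```python
import math


def garment_shapes(conv1_out=6, conv2_out=16, fc1_out=120, fc2_out=84, n_classes=10):
    """Parameter shapes of a GarmentClassifier-style model, keyed like state_dict().

    The defaults reproduce the model defined above; the arguments themselves
    are hypothetical, since GarmentClassifier's constructor takes none.
    """
    flat = conv2_out * 4 * 4  # 28x28 input -> 4x4 feature maps
    return {
        "conv1.weight": (conv1_out, 1, 5, 5), "conv1.bias": (conv1_out,),
        "conv2.weight": (conv2_out, conv1_out, 5, 5), "conv2.bias": (conv2_out,),
        "fc1.weight": (fc1_out, flat), "fc1.bias": (fc1_out,),
        "fc2.weight": (fc2_out, fc1_out), "fc2.bias": (fc2_out,),
        "fc3.weight": (n_classes, fc2_out), "fc3.bias": (n_classes,),
    }


def num_params(shapes):
    """Total number of learnable values across all parameter tensors."""
    return sum(math.prod(shape) for shape in shapes.values())


def incompatible_keys(saved, target):
    """Keys whose shapes differ, or that exist on only one side."""
    return sorted(k for k in saved.keys() | target.keys()
                  if saved.get(k) != target.get(k))


saved = garment_shapes()
print(num_params(saved))                                      # 44426
print(incompatible_keys(saved, garment_shapes()))             # []
print(incompatible_keys(saved, garment_shapes(conv1_out=8)))  # ['conv1.bias', 'conv1.weight', 'conv2.weight']
```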
Other Resources
Docs on the data utilities, including Dataset and DataLoader, at pytorch.org
A note on the use of pinned memory for GPU training
Documentation on the datasets available in TorchVision, TorchText, and TorchAudio
Documentation on the loss functions available in PyTorch
Documentation on the torch.optim package, which includes optimizers and related tools, such as learning rate scheduling
A detailed tutorial on saving and loading models
The Tutorials section of pytorch.org contains tutorials on a broad variety of training tasks, including classification in different domains, generative adversarial networks, reinforcement learning, and more
Total running time of the script: (2 minutes 56.288 seconds)