Training with PyTorch¶
Created On: Nov 30, 2021 | Last Updated: May 31, 2023 | Last Verified: Nov 05, 2024
Follow along with the video below or on YouTube.
Introduction¶
In past videos, we’ve discussed and demonstrated:
Building models with the neural network layers and functions of the torch.nn module
The mechanics of automated gradient computation, which is central to gradient-based model training
Using TensorBoard to visualize training progress and other activities
In this video, we’ll be adding some new tools to your inventory:
We’ll get familiar with the dataset and dataloader abstractions, and how they ease the process of feeding data to your model during a training loop
We’ll discuss specific loss functions and when to use them
We’ll look at PyTorch optimizers, which implement algorithms to adjust model weights based on the outcome of a loss function
Finally, we’ll pull all of these together and see a full PyTorch training loop in action.
Dataset and DataLoader¶
The Dataset and DataLoader classes encapsulate the process of pulling your data from storage and exposing it to your training loop in batches.
The Dataset is responsible for accessing and processing single instances of data.
The DataLoader pulls instances of data from the Dataset (either automatically or with a sampler that you define), collects them in batches, and returns them for consumption by your training loop. The DataLoader works with all kinds of datasets, regardless of the type of data they contain.
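To make the division of labor concrete, here is a torch-free sketch of the two roles. ToySquaresDataset and simple_loader are hypothetical names, not PyTorch APIs. The dataset follows the map-style protocol that torch.utils.data.Dataset uses (__len__ and __getitem__), and the loader shuffles indices and groups instances into batches. Unlike the real DataLoader, it returns plain Python lists instead of collating them into tensors, and it has no sampler, worker, or pinned-memory options.

```python
import random


class ToySquaresDataset:
    """Hypothetical map-style dataset: one (input, label) pair per index."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of instances, as a loader would query it
        return self.n

    def __getitem__(self, idx):
        # Access and "process" a single instance: the label is idx squared
        return idx, idx * idx


def simple_loader(dataset, batch_size, shuffle=False, seed=None):
    """Hypothetical stand-in for DataLoader: yields (inputs, labels) batches."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        # Pull instances from the dataset; the last batch may be smaller
        batch = [dataset[i] for i in indices[start:start + batch_size]]
        inputs = [x for x, _ in batch]
        labels = [y for _, y in batch]
        yield inputs, labels


ds = ToySquaresDataset(10)
batches = list(simple_loader(ds, batch_size=4, shuffle=True, seed=0))
print([len(inputs) for inputs, _ in batches])  # [4, 4, 2]
```

Every instance appears exactly once per pass, and only the final batch is short; the real DataLoader behaves the same way by default (drop_last=False).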
For this tutorial, we’ll be using the Fashion-MNIST dataset provided by TorchVision. ToTensor() converts each image to a tensor with pixel values in [0, 1], and torchvision.transforms.Normalize() with a mean and standard deviation of 0.5 then shifts and rescales those values into [-1, 1]. We also download both the training and validation data splits.
import torch
import torchvision
import torchvision.transforms as transforms
# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)
# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)
# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')
# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
Training set has 60000 instances
Validation set has 10000 instances
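The Normalize((0.5,), (0.5,)) transform computes (x - mean) / std for each channel. Because ToTensor() has already scaled pixels into [0.0, 1.0], the plain-Python sketch below (normalize and unnormalize are hypothetical helpers) shows where the endpoints land, and why img / 2 + 0.5 in the display helper undoes the transform.

```python
def normalize(x, mean=0.5, std=0.5):
    # Per-channel arithmetic applied by Normalize(mean, std)
    return (x - mean) / std


def unnormalize(y, mean=0.5, std=0.5):
    # Inverse transform; with mean = std = 0.5 this is y / 2 + 0.5
    return y * std + mean


for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize(pixel))
# 0.0 -> -1.0
# 0.5 -> 0.0
# 1.0 -> 1.0
```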
As always, let’s visualize the data as a sanity check:
import matplotlib.pyplot as plt
import numpy as np
# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
dataiter = iter(training_loader)
images, labels = next(dataiter)
# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print(' '.join(classes[labels[j]] for j in range(4)))

Ankle Boot Sandal Bag Sneaker
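The helper averages over the channel dimension when one_channel=True because, as we understand torchvision.utils.make_grid, it repeats single-channel images into three channels and lays them out using its defaults nrow=8 and padding=2. The sketch below reproduces that size arithmetic in plain Python; grid_shape is a hypothetical helper, and it assumes a batch of more than one image.

```python
import math


def grid_shape(batch, channels, height, width, nrow=8, padding=2):
    """Hypothetical helper: (C, H, W) of the grid make_grid would build.

    Assumes a batch of more than one image.
    """
    out_channels = 3 if channels == 1 else channels  # grayscale repeated to 3 channels
    xmaps = min(nrow, batch)            # images per row
    ymaps = math.ceil(batch / xmaps)    # number of rows
    return (out_channels,
            (height + padding) * ymaps + padding,
            (width + padding) * xmaps + padding)


print(grid_shape(4, 1, 28, 28))  # (3, 32, 122)
```

For the batch of four 28x28 images shown above, that is a single row of four tiles with 2-pixel borders.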
The Model¶
The model we’ll use in this example is a variant of LeNet-5, which should be familiar if you’ve watched the previous videos in this series.
import torch.nn as nn
import torch.nn.functional as F
# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
model = GarmentClassifier()
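The 16 * 4 * 4 input size of fc1 follows from the layer shapes. Here is a quick plain-Python check, assuming 28x28 single-channel inputs, unpadded stride-1 convolutions, and 2x2 max pooling with stride 2, as configured above. conv_out, pool_out, conv_params, and linear_params are hypothetical helpers.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Spatial output size of a convolution along one dimension
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2, stride=2):
    # Spatial output size of max pooling along one dimension
    return (size - kernel) // stride + 1


def conv_params(in_ch, out_ch, k):
    # k x k weights per (input, output) channel pair, plus one bias per output channel
    return in_ch * out_ch * k * k + out_ch


def linear_params(in_features, out_features):
    # Weight matrix plus one bias per output feature
    return in_features * out_features + out_features


side = 28                           # Fashion-MNIST images are 28 x 28
side = pool_out(conv_out(side, 5))  # conv1: 28 -> 24, pool: 24 -> 12
side = pool_out(conv_out(side, 5))  # conv2: 12 -> 8, pool: 8 -> 4
flat = 16 * side * side             # conv2 has 16 output channels
print(side, flat)                   # 4 256

total = (conv_params(1, 6, 5) + conv_params(6, 16, 5)
         + linear_params(flat, 120) + linear_params(120, 84) + linear_params(84, 10))
print(total)                        # 44426
```

On the GarmentClassifier instance, sum(p.numel() for p in model.parameters()) should report the same total.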
Loss Function¶
For this example, we’ll be using a cross-entropy loss. For demonstration purposes, we’ll create batches of dummy output and label values, run them through the loss function, and examine the result.
loss_fn = torch.nn.CrossEntropyLoss()
# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's confidence in each of the 10 classes for a given input
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])
print(dummy_outputs)
print(dummy_labels)
loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))
tensor([[0.3697, 0.1103, 0.2389, 0.2846, 0.6983, 0.8068, 0.9028, 0.4928, 0.9992,
0.7777],
[0.6622, 0.9821, 0.6700, 0.9801, 0.6631, 0.0178, 0.3756, 0.7354, 0.4474,
0.0722],
[0.3523, 0.8827, 0.5199, 0.3776, 0.9723, 0.9845, 0.7209, 0.8129, 0.5450,
0.1174],
[0.8556, 0.6397, 0.4264, 0.5046, 0.0734, 0.5366, 0.2868, 0.6521, 0.8654,
0.4339]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.62253475189209
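CrossEntropyLoss treats each row of its input as raw, unnormalized scores (logits), applies a log-softmax, and averages the negative log-probability of the correct class over the batch (its default reduction is 'mean'). The torch-free sketch below, built around a hypothetical cross_entropy helper, recomputes the loss from the dummy values printed in this particular run. Because those values are rounded to four decimal places, the result should match the printed loss only to about three decimal places.

```python
import math


def cross_entropy(logits_batch, labels):
    """Mean over the batch of -log(softmax(logits)[label])."""
    total = 0.0
    for logits, label in zip(logits_batch, labels):
        m = max(logits)  # subtract the max before exponentiating, for stability
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_sum_exp - logits[label]
    return total / len(labels)


# Dummy outputs and labels as printed above (outputs rounded to 4 places)
printed_outputs = [
    [0.3697, 0.1103, 0.2389, 0.2846, 0.6983, 0.8068, 0.9028, 0.4928, 0.9992, 0.7777],
    [0.6622, 0.9821, 0.6700, 0.9801, 0.6631, 0.0178, 0.3756, 0.7354, 0.4474, 0.0722],
    [0.3523, 0.8827, 0.5199, 0.3776, 0.9723, 0.9845, 0.7209, 0.8129, 0.5450, 0.1174],
    [0.8556, 0.6397, 0.4264, 0.5046, 0.0734, 0.5366, 0.2868, 0.6521, 0.8654, 0.4339],
]
printed_labels = [1, 5, 3, 7]
manual_loss = cross_entropy(printed_outputs, printed_labels)
print(manual_loss)
```

With identical logits for all 10 classes, the same formula gives log(10), about 2.30; random scores in [0, 1) land near that value, which is why the dummy loss is close to it.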
Optimizer¶
For this example, we’ll be using simple stochastic gradient descent with momentum.
It can be instructive to try some variations on this optimization scheme:
Learning rate determines the size of the steps the optimizer takes. What does a different learning rate do to your training results, in terms of accuracy and convergence time?
Momentum accumulates an exponentially decaying sum of past gradients, nudging the optimizer along directions in which the gradient has been consistent over multiple steps. What does changing this value do to your results?
Try some different optimization algorithms, such as averaged SGD, Adagrad, or Adam. How do your results differ?
# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
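The effect of momentum is easy to see on a one-dimensional toy problem. This sketch, an illustration rather than a benchmark, minimizes f(x) = x**2 using the update rule torch.optim.SGD documents for dampening=0 and nesterov=False: v = momentum * v + grad, then x = x - lr * v. The function name minimize_quadratic and the chosen learning rate, step count, and starting point are hypothetical.

```python
def minimize_quadratic(lr, momentum, steps, x0=1.0):
    """Minimize f(x) = x**2 (gradient 2x) with SGD, optionally with momentum."""
    x, v = x0, 0.0
    for _ in range(steps):
        grad = 2 * x
        v = momentum * v + grad  # momentum buffer (equals grad on the first step)
        x = x - lr * v           # parameter update
    return x


plain = minimize_quadratic(lr=0.01, momentum=0.0, steps=100)
heavy = minimize_quadratic(lr=0.01, momentum=0.9, steps=100)
print(f"plain SGD: {plain:.4f}, with momentum: {heavy:.4f}")
```

With this small learning rate, plain SGD only shrinks x by a factor of 0.98 per step and ends near 0.1326, while momentum lets the consistent gradient build up speed and ends much closer to the minimum at 0. With a larger learning rate or momentum, the same update can overshoot and oscillate, which is one reason to experiment with both values.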
The Training Loop¶
Below, we have a function that performs one training epoch. It enumerates data from the DataLoader, and on each pass of the loop does the following:
Gets a batch of training data from the DataLoader
Zeros the optimizer’s gradients
Performs an inference: gets predictions from the model for an input batch
Calculates the loss for that set of predictions vs. the labels on the dataset
Calculates the backward gradients over the learning weights
Tells the optimizer to perform one learning step: adjust the model’s learning weights based on the observed gradients for this batch, according to the optimization algorithm we chose
It reports the average per-batch loss every 1000 batches.
Finally, it returns the average per-batch loss for the last 1000 batches, for comparison with a validation run.
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.
    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data
        # Zero your gradients for every batch!
        optimizer.zero_grad()
        # Make predictions for this batch
        outputs = model(inputs)
        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()
        # Adjust learning weights
        optimizer.step()
        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
    return last_loss
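The same five steps can be mirrored without PyTorch on a hypothetical toy problem: fitting y = 2x + 1 with a two-parameter linear model and a mean-squared-error loss. In this sketch the gradients are derived by hand where PyTorch would call loss.backward(). PyTorch accumulates gradients into each parameter's .grad across backward() calls, which is why optimizer.zero_grad() is needed; here, recomputing the gradients from zero on every batch plays that role. train_one_epoch_toy and its parameters are illustrative names, and reporting every 2 batches stands in for the tutorial's 1000.

```python
import random


def train_one_epoch_toy(params, data, rng, lr=0.1, batch_size=4, report_every=2):
    """One epoch of mini-batch SGD on a linear model y_hat = w * x + b."""
    rng.shuffle(data)  # reshuffle every epoch, like DataLoader(shuffle=True)
    running_loss, last_loss = 0.0, 0.0
    for i, start in enumerate(range(0, len(data), batch_size)):
        batch = data[start:start + batch_size]
        n = len(batch)
        # Zero the gradients for this batch
        grad_w = grad_b = 0.0
        # Forward pass: prediction errors against the labels
        errors = [(params["w"] * x + params["b"]) - y for x, y in batch]
        # Mean-squared-error loss for the batch
        loss = sum(e * e for e in errors) / n
        # Backward pass: analytic gradients of the loss w.r.t. w and b
        for (x, _), e in zip(batch, errors):
            grad_w += 2 * e * x / n
            grad_b += 2 * e / n
        # Optimizer step: plain SGD update
        params["w"] -= lr * grad_w
        params["b"] -= lr * grad_b
        # Report the average per-batch loss every `report_every` batches
        running_loss += loss
        if i % report_every == report_every - 1:
            last_loss = running_loss / report_every
            running_loss = 0.0
    return last_loss


data = [(x / 10, 2 * (x / 10) + 1) for x in range(-10, 11)]  # y = 2x + 1
params = {"w": 0.0, "b": 0.0}
rng = random.Random(0)
for _ in range(200):
    last = train_one_epoch_toy(params, data, rng)
print(round(params["w"], 3), round(params["b"], 3))  # 2.0 1.0
```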
Per-Epoch Activity¶
There are a couple of things we’ll want to do once per epoch:
Perform validation by checking our relative loss on a set of data that was not used for training, and report this
Save a copy of the model
Here, we’ll do our reporting in TensorBoard. This requires starting TensorBoard from the command line (for example, tensorboard --logdir=runs, since the writer below logs under runs/) and opening it in another browser tab.
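TensorBoard plots each scalar against the step value passed to add_scalar. train_one_epoch computes that step as tb_x = epoch_index * len(training_loader) + i + 1, so the training-loss curve keeps advancing across epochs instead of restarting at zero. The plain-Python sketch below (batches_per_epoch and global_step are hypothetical helpers) reproduces the arithmetic for this dataset, assuming the default drop_last=False, under which len(training_loader) also counts a final partial batch.

```python
import math


def batches_per_epoch(num_instances, batch_size, drop_last=False):
    # Number of batches a loader yields per pass over the data
    if drop_last:
        return num_instances // batch_size
    return math.ceil(num_instances / batch_size)


def global_step(epoch_index, batch_index, num_batches):
    # Same formula as tb_x in train_one_epoch
    return epoch_index * num_batches + batch_index + 1


n = batches_per_epoch(60000, 4)
print(n)                       # 15000
print(global_step(0, 999, n))  # 1000
print(global_step(1, 999, n))  # 16000
```

The per-epoch comparison logged with add_scalars uses epoch_number + 1 as its step instead, so it lives on a separate, coarser x-axis.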
# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0
EPOCHS = 5
best_vloss = 1_000_000.
for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))
    # Put the model in training mode (this affects layers such as dropout and
    # batch normalization; gradient tracking is on by default), then do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)
    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()
    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss
    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()
    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)
    epoch_number += 1
EPOCH 1:
batch 1000 loss: 2.078445885926485
batch 2000 loss: 0.9853184329010546
batch 3000 loss: 0.7737890719217249
batch 4000 loss: 0.6859524527182802
batch 5000 loss: 0.6186456295056268
batch 6000 loss: 0.5717601936603897
batch 7000 loss: 0.5374739476540126
batch 8000 loss: 0.5170167652661912
batch 9000 loss: 0.4631443543983623
batch 10000 loss: 0.47267150053571094
batch 11000 loss: 0.4634251589034684
batch 12000 loss: 0.4522177288719686
batch 13000 loss: 0.43082623432995754
batch 14000 loss: 0.42096412543102635
batch 15000 loss: 0.43448481692990754
LOSS train 0.43448481692990754 valid 0.4238249659538269
EPOCH 2:
batch 1000 loss: 0.3991063394710072
batch 2000 loss: 0.4053242630013847
batch 3000 loss: 0.39313117962612887
batch 4000 loss: 0.40031174979098433
batch 5000 loss: 0.37515016218155506
batch 6000 loss: 0.3772752701257559
batch 7000 loss: 0.37141740097344156
batch 8000 loss: 0.34740131870593177
batch 9000 loss: 0.37146376077127935
batch 10000 loss: 0.3786556771292235
batch 11000 loss: 0.3510120540350035
batch 12000 loss: 0.3691932062833221
batch 13000 loss: 0.3275497554614267
batch 14000 loss: 0.3545808146480704
batch 15000 loss: 0.3416661732759385
LOSS train 0.3416661732759385 valid 0.36629074811935425
EPOCH 3:
batch 1000 loss: 0.34619883446252786
batch 2000 loss: 0.33732505791306905
batch 3000 loss: 0.3218378667862271
batch 4000 loss: 0.3382402218070056
batch 5000 loss: 0.30498756019500434
batch 6000 loss: 0.32820077318514085
batch 7000 loss: 0.3115696356798289
batch 8000 loss: 0.32188789568343784
batch 9000 loss: 0.3139526225671871
batch 10000 loss: 0.3200232334123284
batch 11000 loss: 0.3184244549650912
batch 12000 loss: 0.32671575834829125
batch 13000 loss: 0.32073535430858463
batch 14000 loss: 0.31405857127984926
batch 15000 loss: 0.32439225951711703
LOSS train 0.32439225951711703 valid 0.3356749415397644
EPOCH 4:
batch 1000 loss: 0.2872203590600111
batch 2000 loss: 0.301946465279485
batch 3000 loss: 0.29448101510899144
batch 4000 loss: 0.3018223494642589
batch 5000 loss: 0.30555817463664425
batch 6000 loss: 0.3044548723964253
batch 7000 loss: 0.29134536752272744
batch 8000 loss: 0.29482983102928484
batch 9000 loss: 0.2826406414659941
batch 10000 loss: 0.29214115747269537
batch 11000 loss: 0.29499360246781997
batch 12000 loss: 0.29350314297291696
batch 13000 loss: 0.2856313340542665
batch 14000 loss: 0.3208040994987605
batch 15000 loss: 0.3008514843018347
LOSS train 0.3008514843018347 valid 0.35053306818008423
EPOCH 5:
batch 1000 loss: 0.273013430654104
batch 2000 loss: 0.2721975499899418
batch 3000 loss: 0.2700386013448442
batch 4000 loss: 0.2823152292621453
batch 5000 loss: 0.26805072244314215
batch 6000 loss: 0.28127112275993565
batch 7000 loss: 0.2902480135697697
batch 8000 loss: 0.2787070024538261
batch 9000 loss: 0.27521627708787855
batch 10000 loss: 0.2811208541248252
batch 11000 loss: 0.2936549342118406
batch 12000 loss: 0.2634534901033694
batch 13000 loss: 0.28787386487085315
batch 14000 loss: 0.28128184450815386
batch 15000 loss: 0.2877686134627966
LOSS train 0.2877686134627966 valid 0.3017216622829437
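The checkpointing rule saves a model only when the validation loss improves. Applying the same rule to the validation losses printed for this run (rounded to four decimal places) shows which files were written. Note that the log's "EPOCH 4" corresponds to epoch_number 3, whose validation loss (0.3505) was worse than the previous best, so no checkpoint was saved for it. Your numbers will differ from run to run; valid_losses, best, and saved_epoch_numbers are illustrative names chosen so as not to overwrite the tutorial's own variables.

```python
# Validation losses printed above for EPOCH 1 to EPOCH 5 (rounded to 4 places)
valid_losses = [0.4238, 0.3663, 0.3357, 0.3505, 0.3017]

best = 1_000_000.0
saved_epoch_numbers = []
for epoch_idx, vloss_value in enumerate(valid_losses):
    # Same rule as the training loop: save only on a new best validation loss
    if vloss_value < best:
        best = vloss_value
        saved_epoch_numbers.append(epoch_idx)

print(saved_epoch_numbers)  # [0, 1, 2, 4]
print(best)                 # 0.3017
```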
To load a saved version of the model:
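# PATH is the checkpoint file to load; model_path holds the most recent (best) one saved above
PATH = model_path
saved_model = GarmentClassifier()
saved_model.load_state_dict(torch.load(PATH, weights_only=True))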
Once you’ve loaded the model, it’s ready for whatever you need it for: more training, inference, or analysis.
Note that if your model has constructor parameters that affect model structure, you’ll need to provide them and configure the model identically to the state in which it was saved.
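One way to satisfy that requirement is to store the constructor arguments next to the weights and rebuild the model from them. The torch-free sketch below illustrates the idea with a hypothetical TinyModel whose parameter count depends on a hidden argument, serialized with json. With PyTorch, the analogous approach would be to torch.save a dictionary containing both the configuration and model.state_dict(). The size check mimics, in spirit, the error nn.Module.load_state_dict raises on shape mismatches, although PyTorch raises RuntimeError rather than ValueError.

```python
import json


class TinyModel:
    """Hypothetical model whose parameter count depends on a constructor argument."""

    def __init__(self, hidden):
        self.hidden = hidden
        self.weights = [0.0] * hidden  # stands in for a layer of width `hidden`

    def state_dict(self):
        return {"weights": list(self.weights)}

    def load_state_dict(self, state):
        # Refuse weights that don't fit this model's structure
        got = len(state["weights"])
        if got != self.hidden:
            raise ValueError(f"size mismatch: expected {self.hidden}, got {got}")
        self.weights = list(state["weights"])


tiny = TinyModel(hidden=3)
tiny.weights = [0.1, 0.2, 0.3]  # pretend these were learned

# Save the constructor arguments alongside the weights
checkpoint = json.dumps({"config": {"hidden": tiny.hidden},
                         "state": tiny.state_dict()})

# Later: rebuild an identically configured model, then load the weights
loaded = json.loads(checkpoint)
restored = TinyModel(**loaded["config"])
restored.load_state_dict(loaded["state"])
print(restored.weights)  # [0.1, 0.2, 0.3]
```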
Other Resources¶
Docs on the data utilities, including Dataset and DataLoader, at pytorch.org
A note on the use of pinned memory for GPU training
Documentation on the datasets available in TorchVision, TorchText, and TorchAudio
Documentation on the loss functions available in PyTorch
Documentation on the torch.optim package, which includes optimizers and related tools, such as learning rate scheduling
A detailed tutorial on saving and loading models
The Tutorials section of pytorch.org contains tutorials on a broad variety of training tasks, including classification in different domains, generative adversarial networks, reinforcement learning, and more
Total running time of the script: (2 minutes 59.585 seconds)