Training with PyTorch
Created On: Nov 30, 2021 | Last Updated: May 31, 2023 | Last Verified: Nov 05, 2024
Follow along with the accompanying video on YouTube.
Introduction
In past videos, we’ve discussed and demonstrated:
Building models with the neural network layers and functions of the torch.nn module
The mechanics of automated gradient computation, which is central to gradient-based model training
Using TensorBoard to visualize training progress and other activities
In this video, we’ll be adding some new tools to your inventory:
We’ll get familiar with the dataset and dataloader abstractions, and how they ease the process of feeding data to your model during a training loop
We’ll discuss specific loss functions and when to use them
We’ll look at PyTorch optimizers, which implement algorithms to adjust model weights based on the outcome of a loss function
Finally, we’ll pull all of these together and see a full PyTorch training loop in action.
Dataset and DataLoader
The Dataset and DataLoader classes encapsulate the process of
pulling your data from storage and exposing it to your training loop in
batches.
The Dataset is responsible for accessing and processing single
instances of data.
The DataLoader pulls instances of data from the Dataset (either
automatically or with a sampler that you define), collects them in
batches, and returns them for consumption by your training loop. The
DataLoader works with all kinds of datasets, regardless of the type
of data they contain.
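As a minimal sketch of that division of labor, a map-style Dataset only has to implement __len__ and __getitem__, and DataLoader takes care of sampling indices and collating single instances into batches. The SquaresDataset class below is a hypothetical example made up for illustration; it is not part of PyTorch or of this tutorial.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: instance i is the pair (i, i**2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # DataLoader uses this to know how many instances exist
        return self.n

    def __getitem__(self, idx):
        # Access and process a single instance; here it is simply computed
        x = torch.tensor([float(idx)])
        y = torch.tensor([float(idx) ** 2])
        return x, y


dataset = SquaresDataset(10)
# The DataLoader requests indices, calls __getitem__, and stacks the results
loader = DataLoader(dataset, batch_size=4, shuffle=False)

for xb, yb in loader:
    print(xb.shape, yb.shape)
```

With 10 instances and batch_size=4, the loader yields two full batches and a final batch of 2; passing drop_last=True to DataLoader discards that partial batch instead.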
For this tutorial, we’ll be using the Fashion-MNIST dataset provided by
TorchVision. We use torchvision.transforms.Normalize() to shift and scale
pixel values from the [0, 1] range produced by ToTensor() to [-1, 1], roughly
zero-centering them, and we download both the training and validation splits.
import torch
import torchvision
import torchvision.transforms as transforms
# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)
# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)
# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')
# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
Training set has 60000 instances
Validation set has 10000 instances
As always, let’s visualize the data as a sanity check:
import matplotlib.pyplot as plt
import numpy as np
# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
dataiter = iter(training_loader)
images, labels = next(dataiter)
# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print(' '.join(classes[labels[j]] for j in range(4)))

Sandal Sneaker Shirt Bag
The Model
The model we’ll use in this example is a variant of LeNet-5; it should be familiar if you’ve watched the previous videos in this series.
import torch.nn as nn
import torch.nn.functional as F
# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
model = GarmentClassifier()
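The 16 * 4 * 4 input size of fc1 follows from how the convolution and pooling layers shrink a 28x28 Fashion-MNIST image. This sketch rebuilds the same layer settings outside the class and prints the shape after each stage:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Same layer settings as GarmentClassifier, applied one stage at a time
conv1 = nn.Conv2d(1, 6, 5)
conv2 = nn.Conv2d(6, 16, 5)
pool = nn.MaxPool2d(2, 2)

x = torch.zeros(4, 1, 28, 28)   # a batch of 4 single-channel 28x28 images

x = F.relu(conv1(x))
print(x.shape)                  # torch.Size([4, 6, 24, 24]): 5x5 kernel, no padding
x = pool(x)
print(x.shape)                  # torch.Size([4, 6, 12, 12]): 2x2 pooling halves H and W
x = F.relu(conv2(x))
print(x.shape)                  # torch.Size([4, 16, 8, 8])
x = pool(x)
print(x.shape)                  # torch.Size([4, 16, 4, 4])

print(x[0].numel())             # 256 == 16 * 4 * 4 features per image
```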
Loss Function
For this example, we’ll be using a cross-entropy loss. For demonstration purposes, we’ll create batches of dummy output and label values, run them through the loss function, and examine the result.
loss_fn = torch.nn.CrossEntropyLoss()
# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's raw, unnormalized scores (logits) for each of the 10 classes for a given input
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])
print(dummy_outputs)
print(dummy_labels)
loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))
tensor([[0.5981, 0.7205, 0.4472, 0.4691, 0.1565, 0.5347, 0.4308, 0.1182, 0.9646,
0.4539],
[0.6230, 0.4794, 0.2207, 0.2924, 0.7148, 0.8645, 0.5875, 0.5251, 0.6756,
0.0916],
[0.0501, 0.7904, 0.7441, 0.5225, 0.3061, 0.6760, 0.3924, 0.6372, 0.5151,
0.8732],
[0.2018, 0.5311, 0.8389, 0.1922, 0.0745, 0.7502, 0.9822, 0.4657, 0.7697,
0.1901]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.2027742862701416
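CrossEntropyLoss takes raw scores (logits) and class indices. One way to see what it computes is to reproduce it by hand: apply log-softmax over the class dimension, take the log-probability of each input's correct class, negate, and average over the batch (the default reduction='mean'). The inputs below are random, like dummy_outputs above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.rand(4, 10)              # random scores for 4 inputs, 10 classes
labels = torch.tensor([1, 5, 3, 7])     # correct class index for each input

builtin = torch.nn.CrossEntropyLoss()(logits, labels)

# Manual version: log-softmax over classes, pick each row's correct-class entry,
# negate, and average over the batch
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(labels.size(0)), labels].mean()

print(builtin.item(), manual.item())    # the two values agree up to float rounding
```

For near-uniform scores like these, the loss lands roughly near ln(10) ≈ 2.30, the loss of a uniform guess over 10 classes; the 2.20 printed above is in that neighborhood.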
Optimizer
For this example, we’ll be using simple stochastic gradient descent with momentum.
It can be instructive to try some variations on this optimization scheme:
Learning rate determines the size of the steps the optimizer takes. What does a different learning rate do to your training results, in terms of accuracy and convergence time?
Momentum accumulates a decaying sum of past gradients, nudging the optimizer to keep moving in directions that have been consistently downhill over multiple steps. What does changing this value do to your results?
Try some different optimization algorithms, such as averaged SGD, Adagrad, or Adam. How do your results differ?
# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
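To make the momentum term concrete, this sketch checks torch.optim.SGD (with its default dampening=0 and nesterov=False) against a hand-written version of its update rule, velocity = momentum * velocity + grad followed by w = w - lr * velocity. The objective f(w) = sum(w**2) and the starting values are hypothetical, chosen only because the gradient is easy to write by hand (2 * w).

```python
import torch

lr, momentum = 0.1, 0.9

# Parameter updated by PyTorch's optimizer
w_torch = torch.tensor([1.0, -2.0], requires_grad=True)
opt = torch.optim.SGD([w_torch], lr=lr, momentum=momentum)

# Parameter updated by hand with the same rule
w_manual = torch.tensor([1.0, -2.0])
velocity = torch.zeros_like(w_manual)

for step in range(3):
    # PyTorch step: gradient of sum(w**2) via autograd
    opt.zero_grad()
    loss = (w_torch ** 2).sum()
    loss.backward()
    opt.step()

    # Manual step: gradient is 2 * w; velocity starts at zero,
    # so the first update uses the plain gradient, as SGD does
    grad = 2 * w_manual
    velocity = momentum * velocity + grad
    w_manual = w_manual - lr * velocity

print(w_torch.detach(), w_manual)
```

Swapping in a different optimizer, such as torch.optim.Adam(model.parameters(), lr=0.001), changes only the optimizer line; the training loop itself stays the same.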
The Training Loop
Below, we have a function that performs one training epoch. It enumerates data from the DataLoader, and on each pass of the loop does the following:
Gets a batch of training data from the DataLoader
Zeros the optimizer’s gradients
Performs inference, that is, gets predictions from the model for the input batch
Calculates the loss for those predictions against the batch’s labels
Calculates the gradients of the loss with respect to the learning weights (the backward pass)
Tells the optimizer to perform one learning step, that is, adjust the model’s learning weights based on the observed gradients for this batch, according to the optimization algorithm we chose
Every 1000 batches, it reports the average per-batch loss over those batches, both to the console and to TensorBoard.
Finally, it returns the average per-batch loss for the last 1000 batches, for comparison with a validation run.
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.
    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data
        # Zero your gradients for every batch!
        optimizer.zero_grad()
        # Make predictions for this batch
        outputs = model(inputs)
        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()
        # Adjust learning weights
        optimizer.step()
        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
    return last_loss
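The same per-batch steps can be exercised in isolation on a tiny, self-contained problem. In this sketch the synthetic data, the one-layer model, and names such as toy_model and mean_loss are hypothetical stand-ins, not part of the tutorial; on this easy, linearly separable task, one pass over the DataLoader should lower the average loss.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical synthetic data: 2 classes, label is 1 when x0 + x1 > 0
features = torch.randn(256, 2)
targets = (features.sum(dim=1) > 0).long()
loader = DataLoader(TensorDataset(features, targets), batch_size=16, shuffle=True)

toy_model = nn.Linear(2, 2)
toy_loss_fn = nn.CrossEntropyLoss()
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1, momentum=0.9)


def mean_loss():
    # Loss over the whole dataset, without recording gradients
    with torch.no_grad():
        return toy_loss_fn(toy_model(features), targets).item()


before = mean_loss()
for inputs, labels in loader:                # get a batch
    toy_optimizer.zero_grad()                # zero the gradients
    outputs = toy_model(inputs)              # forward pass
    loss = toy_loss_fn(outputs, labels)      # loss for this batch
    loss.backward()                          # gradients w.r.t. the weights
    toy_optimizer.step()                     # one learning step
after = mean_loss()

print(f"loss before: {before:.4f}, after one epoch: {after:.4f}")
```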
Per-Epoch Activity
There are a couple of things we’ll want to do once per epoch:
Perform validation by computing the loss on a set of data that was not used for training, and report it alongside the training loss
Save a copy of the model whenever the validation loss improves on the best value seen so far
Here, we’ll do our reporting in TensorBoard. This will require going to the command line to start TensorBoard (for example, tensorboard --logdir=runs), and opening it in another browser tab.
# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0
EPOCHS = 5
best_vloss = 1_000_000.
for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))
    # Put the model in training mode (dropout active, batch-norm statistics
    # updated) and do a pass over the data; gradient tracking is on by default
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)
    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()
    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss
    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()
    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)
    epoch_number += 1
EPOCH 1:
batch 1000 loss: 1.666409859918058
batch 2000 loss: 0.8216810051053762
batch 3000 loss: 0.693145078105852
batch 4000 loss: 0.6443965511168354
batch 5000 loss: 0.6123742864592933
batch 6000 loss: 0.5695766103928909
batch 7000 loss: 0.5409413252712693
batch 8000 loss: 0.5383153433622793
batch 9000 loss: 0.48026449825975576
batch 10000 loss: 0.4591459574009059
batch 11000 loss: 0.45217856835146086
batch 12000 loss: 0.431060717097309
batch 13000 loss: 0.41652981538244055
batch 14000 loss: 0.435001613863511
batch 15000 loss: 0.4117226452493924
LOSS train 0.4117226452493924 valid 0.42531269788742065
EPOCH 2:
batch 1000 loss: 0.3943638932242757
batch 2000 loss: 0.39510620032442967
batch 3000 loss: 0.40187308340048183
batch 4000 loss: 0.41561483964993384
batch 5000 loss: 0.37135440780574575
batch 6000 loss: 0.3847427979120985
batch 7000 loss: 0.3660853395376471
batch 8000 loss: 0.3599262051352125
batch 9000 loss: 0.36613601676898544
batch 10000 loss: 0.34619443843280895
batch 11000 loss: 0.3421523532573119
batch 12000 loss: 0.37944928950941537
batch 13000 loss: 0.3445565646337418
batch 14000 loss: 0.3472710616480363
batch 15000 loss: 0.3482665803800919
LOSS train 0.3482665803800919 valid 0.37191668152809143
EPOCH 3:
batch 1000 loss: 0.35298623689083614
batch 2000 loss: 0.31526475692175154
batch 3000 loss: 0.354445223361603
batch 4000 loss: 0.31954076824391087
batch 5000 loss: 0.30167409399730966
batch 6000 loss: 0.32178128572105563
batch 7000 loss: 0.31245879809299365
batch 8000 loss: 0.3102076395740296
batch 9000 loss: 0.3193566365780716
batch 10000 loss: 0.3245317395089805
batch 11000 loss: 0.32724233834208283
batch 12000 loss: 0.3273154704665576
batch 13000 loss: 0.3198279506397084
batch 14000 loss: 0.3135476417306054
batch 15000 loss: 0.31637832522210374
LOSS train 0.31637832522210374 valid 0.3359675407409668
EPOCH 4:
batch 1000 loss: 0.27941275065656734
batch 2000 loss: 0.2823940862530035
batch 3000 loss: 0.2894134281675447
batch 4000 loss: 0.3015546747631597
batch 5000 loss: 0.28293730535544453
batch 6000 loss: 0.2941953631842043
batch 7000 loss: 0.3244606865464448
batch 8000 loss: 0.2946359610656218
batch 9000 loss: 0.3051677185113658
batch 10000 loss: 0.2765467494608965
batch 11000 loss: 0.31629972430641645
batch 12000 loss: 0.3217379439852521
batch 13000 loss: 0.2986907337167504
batch 14000 loss: 0.2571812377775459
batch 15000 loss: 0.2835259589429043
LOSS train 0.2835259589429043 valid 0.33261457085609436
EPOCH 5:
batch 1000 loss: 0.27721858343805705
batch 2000 loss: 0.2762320360558242
batch 3000 loss: 0.2741196601182746
batch 4000 loss: 0.27815906952939257
batch 5000 loss: 0.2765891040311908
batch 6000 loss: 0.28914274197602935
batch 7000 loss: 0.27360277335835415
batch 8000 loss: 0.2811103402964691
batch 9000 loss: 0.2858065232049412
batch 10000 loss: 0.25068630761879285
batch 11000 loss: 0.2620843322443907
batch 12000 loss: 0.29563811091540265
batch 13000 loss: 0.2757980781155493
batch 14000 loss: 0.27335850994923383
batch 15000 loss: 0.2731545760410836
LOSS train 0.2731545760410836 valid 0.3064103424549103
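The validation pass above relies on two separate switches: model.eval() changes how layers such as dropout behave, while torch.no_grad() stops autograd from recording operations. This sketch uses a hypothetical two-layer model with dropout to show that each switch does its own job, and that model.train() by itself does not turn gradient tracking on or off.

```python
import torch
import torch.nn as nn

# Hypothetical model with dropout, to show what train()/eval() change
net = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
x = torch.ones(1, 8)

net.eval()                    # dropout becomes a no-op
with torch.no_grad():         # no autograd graph is recorded
    a = net(x)
    b = net(x)
print(torch.equal(a, b), a.requires_grad)   # True False

net.train()                   # dropout is active again
c = net(x)
print(c.requires_grad)        # True: autograd is on because no_grad() is not in effect
```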
To load a saved version of the model:
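# PATH: a checkpoint written by torch.save() in the loop above;
# model_path holds the best-scoring one
PATH = model_path
saved_model = GarmentClassifier()
saved_model.load_state_dict(torch.load(PATH))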
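Once you’ve loaded the model, it’s ready for whatever you need it for: more training, inference, or analysis. Call saved_model.eval() before running inference so that layers such as dropout and batch normalization behave as they did during validation.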
Note that if your model has constructor parameters that affect model structure, you’ll need to provide them and configure the model identically to the state in which it was saved.
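As a sketch of that caveat, the hypothetical SizedClassifier below takes its hidden-layer size as a constructor argument. Its saved state_dict loads cleanly into a model built with the same argument, while load_state_dict() raises a RuntimeError when the layer shapes do not match. The class and the temporary file path are assumptions made for this example.

```python
import os
import tempfile

import torch
import torch.nn as nn


class SizedClassifier(nn.Module):
    """Hypothetical model whose structure depends on a constructor argument."""

    def __init__(self, hidden):
        super().__init__()
        self.fc1 = nn.Linear(784, hidden)
        self.fc2 = nn.Linear(hidden, 10)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x.flatten(1))))


original = SizedClassifier(hidden=64)
path = os.path.join(tempfile.mkdtemp(), "sized_classifier.pt")
torch.save(original.state_dict(), path)

# Rebuild with the SAME constructor argument, then load the weights
restored = SizedClassifier(hidden=64)
restored.load_state_dict(torch.load(path))

# A different structure fails: the saved tensors have the wrong shapes
try:
    SizedClassifier(hidden=32).load_state_dict(torch.load(path))
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = err

print(mismatch_error is not None)   # True
```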
Other Resources
Docs on the data utilities, including Dataset and DataLoader, at pytorch.org
A note on the use of pinned memory for GPU training
Documentation on the datasets available in TorchVision, TorchText, and TorchAudio
Documentation on the loss functions available in PyTorch
Documentation on the torch.optim package, which includes optimizers and related tools, such as learning rate scheduling
A detailed tutorial on saving and loading models
The Tutorials section of pytorch.org contains tutorials on a broad variety of training tasks, including classification in different domains, generative adversarial networks, reinforcement learning, and more
Total running time of the script: (3 minutes 3.170 seconds)