Training with PyTorch
Created On: Nov 30, 2021 | Last Updated: May 31, 2023 | Last Verified: Nov 05, 2024
Follow along with the video on YouTube.
Introduction
In past videos, we’ve discussed and demonstrated:
Building models with the neural network layers and functions of the torch.nn module
The mechanics of automated gradient computation, which is central to gradient-based model training
Using TensorBoard to visualize training progress and other activities
In this video, we’ll be adding some new tools to your inventory:
We’ll get familiar with the dataset and dataloader abstractions, and how they ease the process of feeding data to your model during a training loop
We’ll discuss specific loss functions and when to use them
We’ll look at PyTorch optimizers, which implement algorithms to adjust model weights based on the outcome of a loss function
Finally, we’ll pull all of these together and see a full PyTorch training loop in action.
Dataset and DataLoader
The Dataset and DataLoader classes encapsulate the process of pulling your data from storage and exposing it to your training loop in batches.
The Dataset is responsible for accessing and processing single instances of data.
The DataLoader pulls instances of data from the Dataset (either automatically or with a sampler that you define), collects them in batches, and returns them for consumption by your training loop. The DataLoader works with all kinds of datasets, regardless of the type of data they contain.
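As a minimal sketch of these two roles, the toy dataset below (a hypothetical SquaresDataset, not part of PyTorch) implements the two methods a map-style Dataset needs, __len__ and __getitem__, and a DataLoader with an explicit SequentialSampler groups its items into batches of 4:

```python
import torch
from torch.utils.data import DataLoader, Dataset, SequentialSampler


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: item i is the pair (input=[i], label=i**2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of instances; the DataLoader uses this to plan its batches
        return self.n

    def __getitem__(self, idx):
        # Access and process a single instance
        x = torch.tensor([float(idx)])
        y = torch.tensor(float(idx) ** 2)
        return x, y


dataset = SquaresDataset(10)
# The sampler decides the order of indices; the DataLoader collates them into batches
loader = DataLoader(dataset, batch_size=4, sampler=SequentialSampler(dataset))

batches = list(loader)
for xb, yb in batches:
    print(tuple(xb.shape), yb.tolist())
```

The last batch holds only 2 items because 10 instances do not divide evenly into batches of 4; passing drop_last=True to the DataLoader would discard that partial batch.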
For this tutorial, we’ll be using the Fashion-MNIST dataset provided by TorchVision. We use torchvision.transforms.Normalize() to zero-center the pixel values produced by ToTensor(), shifting and scaling them from the range [0, 1] to [-1, 1], and we download both the training and validation data splits.
import torch
import torchvision
import torchvision.transforms as transforms
# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)
# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)
# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')
# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
Training set has 60000 instances
Validation set has 10000 instances
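With batch_size=4, these split sizes fix how many batches a DataLoader yields per epoch: ceil(60000 / 4) = 15000 for training and ceil(10000 / 4) = 2500 for validation, which is why the training log later in this tutorial counts up to batch 15000. A small check, using dummy TensorDatasets of the same sizes as stand-ins for the real splits:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def batches_per_epoch(n_instances, batch_size, drop_last=False):
    # Dummy stand-in dataset: only its length matters for counting batches
    dataset = TensorDataset(torch.zeros(n_instances))
    return len(DataLoader(dataset, batch_size=batch_size, drop_last=drop_last))


print(batches_per_epoch(60000, 4))  # 15000 training batches
print(batches_per_epoch(10000, 4))  # 2500 validation batches
print(batches_per_epoch(10, 4), batches_per_epoch(10, 4, drop_last=True))  # 3 2
```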
As always, let’s visualize the data as a sanity check:
import matplotlib.pyplot as plt
import numpy as np
# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
dataiter = iter(training_loader)
images, labels = next(dataiter)
# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print(' '.join(classes[labels[j]] for j in range(4)))

Coat Ankle Boot Trouser Trouser
The Model
The model we’ll use in this example is a variant of LeNet-5; it should be familiar if you’ve watched the previous videos in this series.
import torch.nn as nn
import torch.nn.functional as F

# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = GarmentClassifier()
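The input size of fc1, 16 * 4 * 4, follows from the 28x28 Fashion-MNIST images: each unpadded 5x5 convolution trims 4 pixels from each spatial dimension and each 2x2 max-pool halves it, so 28 → 24 → 12 → 8 → 4. The sketch below traces a dummy batch through layers configured like the model's convolutional stack to confirm the shapes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Layers configured like GarmentClassifier's convolutional stack
conv1 = nn.Conv2d(1, 6, 5)
pool = nn.MaxPool2d(2, 2)
conv2 = nn.Conv2d(6, 16, 5)

x = torch.zeros(4, 1, 28, 28)  # dummy batch: 4 single-channel 28x28 images
print('input  ', tuple(x.shape))   # (4, 1, 28, 28)

x = pool(F.relu(conv1(x)))
print('stage 1', tuple(x.shape))   # (4, 6, 12, 12): conv 28 -> 24, pool 24 -> 12

x = pool(F.relu(conv2(x)))
print('stage 2', tuple(x.shape))   # (4, 16, 4, 4): conv 12 -> 8, pool 8 -> 4

flat = x.view(-1, 16 * 4 * 4)
print('flat   ', tuple(flat.shape))  # (4, 256), matching fc1's in_features
```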
Loss Function
For this example, we’ll be using a cross-entropy loss. For demonstration purposes, we’ll create batches of dummy output and label values, run them through the loss function, and examine the result.
loss_fn = torch.nn.CrossEntropyLoss()
# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's raw (unnormalized) scores for each of the 10 classes for a given input
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])
print(dummy_outputs)
print(dummy_labels)
loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))
tensor([[0.5779, 0.8862, 0.4191, 0.0032, 0.2113, 0.9228, 0.4309, 0.5445, 0.1012,
0.9414],
[0.6555, 0.3088, 0.1644, 0.0469, 0.6615, 0.5882, 0.3748, 0.1855, 0.8543,
0.5478],
[0.7276, 0.8717, 0.1798, 0.8854, 0.5145, 0.7271, 0.3891, 0.2517, 0.4869,
0.4351],
[0.5138, 0.1533, 0.4548, 0.5427, 0.3486, 0.0081, 0.1881, 0.1080, 0.9436,
0.8507]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.199453830718994
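CrossEntropyLoss treats its inputs as raw, unnormalized scores: it applies log-softmax across the classes, takes the negative log-probability of each correct label, and by default averages over the batch. The sketch below reproduces that by hand on random dummy scores. For random scores like these, the loss sits near ln(10) ≈ 2.303, the loss of a uniform guess over 10 classes, which is consistent with the value printed above.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
scores = torch.rand(4, 10)           # dummy raw scores (logits), like dummy_outputs
labels = torch.tensor([1, 5, 3, 7])  # correct class for each of the 4 instances

builtin = torch.nn.CrossEntropyLoss()(scores, labels)

# By hand: log-softmax over the class dimension, pick each row's correct class,
# negate, and average over the batch (the default reduction='mean')
log_probs = F.log_softmax(scores, dim=1)
manual = -log_probs[torch.arange(len(labels)), labels].mean()

print(builtin.item(), manual.item())  # the two values agree
print(math.log(10))                   # 2.302585..., loss of a uniform guess
```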
Optimizer
For this example, we’ll be using simple stochastic gradient descent with momentum.
It can be instructive to try some variations on this optimization scheme:
Learning rate determines the size of the steps the optimizer takes. What does a different learning rate do to your training results, in terms of accuracy and convergence time?
Momentum accumulates an exponentially decaying sum of past gradients, so the optimizer keeps moving in directions that stay consistent across multiple steps. What does changing this value do to your results?
Try some different optimization algorithms, such as averaged SGD, Adagrad, or Adam. How do your results differ?
# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
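One way to try the variations suggested above is to keep the training code fixed and swap only the optimizer constructor. The sketch below does this on a hypothetical toy problem (fitting the line y = 3x + 1 with a single nn.Linear layer), not on Fashion-MNIST; the learning rates are illustrative guesses, not tuned values:

```python
import torch
import torch.nn as nn


def train_with(make_optimizer, steps=200):
    torch.manual_seed(0)  # same initial weights for every optimizer
    x = torch.linspace(-1, 1, 64).unsqueeze(1)
    y = 3 * x + 1         # hypothetical target: a straight line
    model = nn.Linear(1, 1)
    loss_fn = nn.MSELoss()
    optimizer = make_optimizer(model.parameters())

    initial = loss_fn(model(x), y).item()
    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
    final = loss_fn(model(x), y).item()
    return initial, final


candidates = {
    'SGD+momentum': lambda p: torch.optim.SGD(p, lr=0.1, momentum=0.9),
    'ASGD': lambda p: torch.optim.ASGD(p, lr=0.1),
    'Adagrad': lambda p: torch.optim.Adagrad(p, lr=0.5),
    'Adam': lambda p: torch.optim.Adam(p, lr=0.1),
}

results = {name: train_with(make) for name, make in candidates.items()}
for name, (initial, final) in results.items():
    print(f'{name:>13}: loss {initial:.4f} -> {final:.6f}')
```

On a problem this small every optimizer should make progress; the more interesting differences in accuracy, convergence speed, and sensitivity to lr and momentum tend to show up on the real model.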
The Training Loop
Below, we have a function that performs one training epoch. It enumerates data from the DataLoader, and on each pass of the loop does the following:
Gets a batch of training data from the DataLoader
Zeros the optimizer’s gradients
Performs an inference; that is, gets predictions from the model for an input batch
Calculates the loss for that set of predictions vs. the labels on the dataset
Calculates the backward gradients over the learning weights
Tells the optimizer to perform one learning step; that is, adjusts the model’s learning weights based on the observed gradients for this batch, according to the optimization algorithm we chose
It reports the average per-batch loss every 1000 batches.
Finally, it returns the average per-batch loss for the last 1000 batches, for comparison with a validation run.
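The loop below logs each 1000-batch average to TensorBoard at a global step tb_x = epoch_index * len(training_loader) + i + 1, that is, the number of batches processed since training began, so the x-axis keeps increasing across epochs instead of restarting at zero. The arithmetic, pulled out into a hypothetical helper for illustration:

```python
def global_step(epoch_index, batches_per_epoch, i):
    # Same formula as tb_x in train_one_epoch; i is the 0-based batch index
    return epoch_index * batches_per_epoch + i + 1


# With 15000 batches per epoch and a report every 1000 batches:
print(global_step(0, 15000, 999))    # 1000  (first report, epoch 1)
print(global_step(1, 15000, 999))    # 16000 (first report, epoch 2)
print(global_step(4, 15000, 14999))  # 75000 (last report, epoch 5)
```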
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print(' batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss
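The reminder to zero gradients on every batch matters because backward() adds into each parameter's .grad rather than overwriting it. A minimal sketch with a single scalar parameter shows the accumulation:

```python
import torch

w = torch.tensor(2.0, requires_grad=True)


def compute_loss():
    # d/dw of (3w)^2 is 18w, which is 36 at w = 2
    return (3.0 * w) ** 2


compute_loss().backward()
first = w.grad.item()        # 36.0

compute_loss().backward()    # no reset in between: 36 is added to the stored 36
accumulated = w.grad.item()  # 72.0

w.grad = None                # what optimizer.zero_grad() does by default in PyTorch 2.x
compute_loss().backward()
after_reset = w.grad.item()  # 36.0

print(first, accumulated, after_reset)
```

In recent PyTorch versions, optimizer.zero_grad() sets each .grad to None by default rather than filling it with zeros; either way, the next backward() starts from a clean slate.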
Per-Epoch Activity
There are a couple of things we’ll want to do once per epoch:
Perform validation by checking our relative loss on a set of data that was not used for training, and report this
Save a copy of the model
Here, we’ll do our reporting in TensorBoard. This requires starting TensorBoard from the command line (for example, tensorboard --logdir=runs, since the writer logs under runs/) and opening it in another browser tab.
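The validation pass relies on two independent switches: model.eval() changes how layers such as dropout and batch normalization behave, while torch.no_grad() stops autograd from recording operations. The sketch below shows both on a hypothetical small network with a dropout layer (GarmentClassifier itself has neither dropout nor batch norm, so eval() does not change its outputs):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical network with a dropout layer
net = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
x = torch.ones(2, 8)

net.train()               # dropout active: each call zeroes a random subset of outputs
train_out = net(x)

net.eval()                # dropout becomes a no-op, so repeated calls agree
eval_a, eval_b = net(x), net(x)

with torch.no_grad():     # no autograd graph is recorded inside this block
    nograd_out = net(x)

print(torch.equal(eval_a, eval_b))                     # True
print(eval_a.requires_grad, nograd_out.requires_grad)  # True False
```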
# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Put the model in training mode (enables dropout and batch norm updates),
    # then do a pass over the data; gradient tracking itself is on by default
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)

    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1
EPOCH 1:
batch 1000 loss: 1.786096847176552
batch 2000 loss: 0.859170969389379
batch 3000 loss: 0.7190974873229862
batch 4000 loss: 0.6598526890622451
batch 5000 loss: 0.5838459139754996
batch 6000 loss: 0.5521430970021757
batch 7000 loss: 0.5351364195225761
batch 8000 loss: 0.539474343089154
batch 9000 loss: 0.5124938601325266
batch 10000 loss: 0.49151815713290126
batch 11000 loss: 0.4708568763874937
batch 12000 loss: 0.4611044364885893
batch 13000 loss: 0.4133654390580487
batch 14000 loss: 0.4402244464850519
batch 15000 loss: 0.4309590762491571
LOSS train 0.4309590762491571 valid 0.41458258032798767
EPOCH 2:
batch 1000 loss: 0.39941240309749265
batch 2000 loss: 0.4068603223847458
batch 3000 loss: 0.35848508665140255
batch 4000 loss: 0.3750318140688032
batch 5000 loss: 0.38891105032191264
batch 6000 loss: 0.3916742497025989
batch 7000 loss: 0.35125666220588025
batch 8000 loss: 0.3696044623917551
batch 9000 loss: 0.3618686895615247
batch 10000 loss: 0.3662503977565502
batch 11000 loss: 0.3596696388687415
batch 12000 loss: 0.36236808000149906
batch 13000 loss: 0.33359414375250346
batch 14000 loss: 0.3566772072846652
batch 15000 loss: 0.35831537446996664
LOSS train 0.35831537446996664 valid 0.3961953818798065
EPOCH 3:
batch 1000 loss: 0.326778023144132
batch 2000 loss: 0.3371908660522713
batch 3000 loss: 0.34427384962077484
batch 4000 loss: 0.3293056495572964
batch 5000 loss: 0.3102431019931537
batch 6000 loss: 0.30466453904053925
batch 7000 loss: 0.32623241228199
batch 8000 loss: 0.3249970707918692
batch 9000 loss: 0.3111401640658296
batch 10000 loss: 0.31191050394101694
batch 11000 loss: 0.3347715679293324
batch 12000 loss: 0.32604350845052976
batch 13000 loss: 0.31874223645444
batch 14000 loss: 0.32521826910832896
batch 15000 loss: 0.30239877930857256
LOSS train 0.30239877930857256 valid 0.33625584840774536
EPOCH 4:
batch 1000 loss: 0.3080627405752166
batch 2000 loss: 0.28084719391765245
batch 3000 loss: 0.29544671591219956
batch 4000 loss: 0.2860185973097123
batch 5000 loss: 0.27990782612843756
batch 6000 loss: 0.2890156138100192
batch 7000 loss: 0.321462659977944
batch 8000 loss: 0.28849853136204184
batch 9000 loss: 0.270612389343758
batch 10000 loss: 0.3026917371701129
batch 11000 loss: 0.2939817950121651
batch 12000 loss: 0.31887937623170726
batch 13000 loss: 0.3085348536039291
batch 14000 loss: 0.2730141546119776
batch 15000 loss: 0.277164637510281
LOSS train 0.277164637510281 valid 0.3341556787490845
EPOCH 5:
batch 1000 loss: 0.2676803221049486
batch 2000 loss: 0.27730777416354246
batch 3000 loss: 0.26299332616944593
batch 4000 loss: 0.28413415247360535
batch 5000 loss: 0.2799531579449431
batch 6000 loss: 0.2690072998491196
batch 7000 loss: 0.26681686138529037
batch 8000 loss: 0.2753645258527431
batch 9000 loss: 0.290746773472536
batch 10000 loss: 0.26916193184114307
batch 11000 loss: 0.27618473538951366
batch 12000 loss: 0.28250575076439055
batch 13000 loss: 0.2712034014643359
batch 14000 loss: 0.276366593589757
batch 15000 loss: 0.2559231778119738
LOSS train 0.2559231778119738 valid 0.3214151859283447
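The validation loss fell in every epoch of this run, so the avg_vloss < best_vloss test fires each time and the loop writes five checkpoints, model_<timestamp>_0 through model_<timestamp>_4. The sketch below replays that bookkeeping on the validation losses from the log above, plus a made-up non-monotonic sequence for contrast:

```python
# Validation losses copied from the log above (epochs 1-5)
valid_losses = [0.41458258032798767, 0.3961953818798065, 0.33625584840774536,
                0.3341556787490845, 0.3214151859283447]


def saved_epochs(losses):
    best = 1_000_000.
    saved = []
    for epoch_number, vloss in enumerate(losses):
        if vloss < best:  # same test as the training loop
            best = vloss
            saved.append(epoch_number)
    return saved


print(saved_epochs(valid_losses))            # [0, 1, 2, 3, 4]
print(saved_epochs([0.5, 0.4, 0.45, 0.3]))   # [0, 1, 3] (made-up losses)
```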
To load a saved version of the model:
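# PATH is a file written by torch.save() in the loop above; here, the most recent checkpoint
PATH = model_path
saved_model = GarmentClassifier()
saved_model.load_state_dict(torch.load(PATH))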
Once you’ve loaded the model, it’s ready for whatever you need it for: more training, inference, or analysis.
Note that if your model has constructor parameters that affect model structure, you’ll need to provide them and configure the model identically to the state in which it was saved.
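A minimal sketch of that caveat, using a hypothetical TinyNet whose layer sizes depend on a constructor argument: a state_dict saved from TinyNet(hidden=16) loads cleanly into another TinyNet(hidden=16), but load_state_dict() raises a RuntimeError for TinyNet(hidden=32) because the parameter shapes no longer match.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model whose structure depends on a constructor argument."""

    def __init__(self, hidden=16):
        super().__init__()
        self.fc1 = nn.Linear(4, hidden)
        self.fc2 = nn.Linear(hidden, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


original = TinyNet(hidden=16)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'tiny_state.pt')
    torch.save(original.state_dict(), path)

    # Same constructor arguments: loading succeeds and every tensor matches
    restored = TinyNet(hidden=16)
    restored.load_state_dict(torch.load(path))
    weights_match = all(
        torch.equal(a, b)
        for a, b in zip(original.state_dict().values(), restored.state_dict().values())
    )

    # Different constructor arguments: parameter shapes differ, so loading fails
    try:
        TinyNet(hidden=32).load_state_dict(torch.load(path))
        mismatch_raised = False
    except RuntimeError:
        mismatch_raised = True

print(weights_match, mismatch_raised)  # True True
```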
Other Resources
Docs on the data utilities, including Dataset and DataLoader, at pytorch.org
A note on the use of pinned memory for GPU training
Documentation on the datasets available in TorchVision, TorchText, and TorchAudio
Documentation on the loss functions available in PyTorch
Documentation on the torch.optim package, which includes optimizers and related tools, such as learning rate scheduling
A detailed tutorial on saving and loading models
The Tutorials section of pytorch.org contains tutorials on a broad variety of training tasks, including classification in different domains, generative adversarial networks, reinforcement learning, and more
Total running time of the script: (2 minutes 56.289 seconds)