Training with PyTorch
Created On: Nov 30, 2021 | Last Updated: May 31, 2023 | Last Verified: Nov 05, 2024
Follow along with the video below or on YouTube.
Introduction
In past videos, we’ve discussed and demonstrated:
Building models with the neural network layers and functions of the torch.nn module
The mechanics of automated gradient computation, which is central to gradient-based model training
Using TensorBoard to visualize training progress and other activities
In this video, we’ll be adding some new tools to your inventory:
We’ll get familiar with the dataset and dataloader abstractions, and how they ease the process of feeding data to your model during a training loop
We’ll discuss specific loss functions and when to use them
We’ll look at PyTorch optimizers, which implement algorithms to adjust model weights based on the outcome of a loss function
Finally, we’ll pull all of these together and see a full PyTorch training loop in action.
Dataset and DataLoader
The Dataset and DataLoader classes encapsulate the process of pulling your data from storage and exposing it to your training loop in batches.
The Dataset is responsible for accessing and processing single instances of data.
The DataLoader pulls instances of data from the Dataset (either automatically or with a sampler that you define), collects them in batches, and returns them for consumption by your training loop. The DataLoader works with all kinds of datasets, regardless of the type of data they contain.
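To make this division of labor concrete, here is a minimal sketch of a custom map-style Dataset. The class SquaresDataset is hypothetical (it is not part of PyTorch or this tutorial): it implements the two methods such a Dataset needs, __len__ and __getitem__, and a DataLoader then collects its instances into batches of 4.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: instance i is the (input, label) pair (i, i**2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of instances; the DataLoader uses this to plan its batches
        return self.n

    def __getitem__(self, idx):
        # Access and process a single instance
        x = torch.tensor(float(idx))
        y = torch.tensor(float(idx) ** 2)
        return x, y


squares = SquaresDataset(10)
# shuffle=False keeps the order predictable for this illustration
square_loader = DataLoader(squares, batch_size=4, shuffle=False)

for batch_inputs, batch_labels in square_loader:
    print(batch_inputs.tolist(), batch_labels.tolist())
```

With 10 instances and batch_size=4, the last batch holds only 2 instances: the DataLoader keeps a short final batch unless you pass drop_last=True.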
For this tutorial, we’ll be using the Fashion-MNIST dataset provided by TorchVision. We use torchvision.transforms.Normalize() to zero-center and rescale the image tile content (with a mean and standard deviation of 0.5, pixel values that ToTensor() places in [0, 1] end up in [-1, 1]), and download both training and validation data splits.
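As a quick arithmetic check, here is a pure-Python sketch (not torchvision code) of the per-channel formula Normalize applies, (value - mean) / std, with the mean and std of 0.5 used here. The helper names normalize and unnormalize are hypothetical; unnormalize mirrors the img / 2 + 0.5 step in the display helper further down.

```python
def normalize(value, mean=0.5, std=0.5):
    # Same per-channel formula Normalize applies: (value - mean) / std
    return (value - mean) / std


def unnormalize(value):
    # Inverse used for display: multiply by std (0.5), then add the mean (0.5)
    return value / 2 + 0.5


for pixel in (0.0, 0.25, 0.5, 1.0):
    z = normalize(pixel)
    print(pixel, "->", z, "->", unnormalize(z))
```

Black (0.0) maps to -1.0, mid-gray (0.5) to 0.0, and white (1.0) to 1.0, so the inputs the model sees are centered on zero.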
```python
import torch
import torchvision
import torchvision.transforms as transforms

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)

# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)

# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
```
Training set has 60000 instances
Validation set has 10000 instances
As always, let’s visualize the data as a sanity check:
```python
import matplotlib.pyplot as plt
import numpy as np

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(training_loader)
images, labels = next(dataiter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print(' '.join(classes[labels[j]] for j in range(4)))
```

Pullover Dress Dress Ankle Boot
The Model
The model we’ll use in this example is a variant of LeNet-5; it should be familiar if you’ve watched the previous videos in this series.
```python
import torch.nn as nn
import torch.nn.functional as F

# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = GarmentClassifier()
```
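Where does 16 * 4 * 4 come from? This pure-Python sketch traces the spatial size of a 28x28 Fashion-MNIST image through the layers, using the standard output-size formula for convolution and pooling with dilation 1 and no padding (the settings used above). The helper out_size is hypothetical, written only for this illustration.

```python
def out_size(size, kernel, stride=1, padding=0):
    # Output width/height of a Conv2d or MaxPool2d layer (dilation 1)
    return (size + 2 * padding - kernel) // stride + 1


size = 28                                   # Fashion-MNIST images are 28x28, 1 channel
size = out_size(size, kernel=5)             # conv1 (5x5): 28 -> 24, 6 channels
size = out_size(size, kernel=2, stride=2)   # pool (2x2, stride 2): 24 -> 12
size = out_size(size, kernel=5)             # conv2 (5x5): 12 -> 8, 16 channels
size = out_size(size, kernel=2, stride=2)   # pool: 8 -> 4

flat_features = 16 * size * size            # 16 channels of 4x4 feature maps
print(size, flat_features)
```

The result, 256, is the in_features of fc1 and the size used in x.view(-1, 16 * 4 * 4); if you change the input resolution or the convolution sizes, this number has to change with them.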
Loss Function
For this example, we’ll be using a cross-entropy loss. For demonstration purposes, we’ll create batches of dummy output and label values, run them through the loss function, and examine the result.
```python
loss_fn = torch.nn.CrossEntropyLoss()

# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's raw (unnormalized) score for each of the 10 classes for a given input
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])

print(dummy_outputs)
print(dummy_labels)

# With the default reduction='mean', this is the loss averaged over the 4 samples
loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))
```
tensor([[0.5725, 0.0701, 0.7984, 0.7248, 0.6974, 0.7434, 0.6711, 0.4446, 0.0455,
0.7280],
[0.8428, 0.2333, 0.2198, 0.9558, 0.8582, 0.3675, 0.5048, 0.7045, 0.5341,
0.2858],
[0.7080, 0.1596, 0.6911, 0.0654, 0.7256, 0.5347, 0.0221, 0.6075, 0.6316,
0.1355],
[0.9750, 0.6543, 0.0877, 0.7354, 0.2457, 0.5542, 0.5423, 0.2040, 0.1801,
0.0383]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.6504311561584473
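To see what CrossEntropyLoss computes, the sketch below reimplements it in plain Python for the default settings only (class-index targets, reduction='mean', no class weights or label smoothing): for each sample, the negative log of the softmax probability of the correct class, averaged over the batch. The function cross_entropy is hypothetical. Fed the dummy values printed above, it should reproduce the reported loss to within rounding, since the printed tensor shows only four decimals.

```python
import math


def cross_entropy(batch_logits, batch_labels):
    # Mean over the batch of -log(softmax(logits)[label]), using a stable log-sum-exp
    total = 0.0
    for logits, label in zip(batch_logits, batch_labels):
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
        total += log_sum_exp - logits[label]
    return total / len(batch_labels)


# Values copied from the printed dummy_outputs and dummy_labels above
printed_outputs = [
    [0.5725, 0.0701, 0.7984, 0.7248, 0.6974, 0.7434, 0.6711, 0.4446, 0.0455, 0.7280],
    [0.8428, 0.2333, 0.2198, 0.9558, 0.8582, 0.3675, 0.5048, 0.7045, 0.5341, 0.2858],
    [0.7080, 0.1596, 0.6911, 0.0654, 0.7256, 0.5347, 0.0221, 0.6075, 0.6316, 0.1355],
    [0.9750, 0.6543, 0.0877, 0.7354, 0.2457, 0.5542, 0.5423, 0.2040, 0.1801, 0.0383],
]
printed_labels = [1, 5, 3, 7]

print(cross_entropy(printed_outputs, printed_labels))
```

A useful reference point: if all ten scores are equal, the loss is log(10), about 2.303. The first training report further down (about 2.25 averaged over the first 1000 batches) is close to that value.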
Optimizer
For this example, we’ll be using simple stochastic gradient descent with momentum.
It can be instructive to try some variations on this optimization scheme:
Learning rate determines the size of the steps the optimizer takes. What does a different learning rate do to your training results, in terms of accuracy and convergence time?
Momentum keeps an exponentially decaying sum of past gradients, so the optimizer keeps moving in directions that stay consistent over multiple steps. What does changing this value do to your results?
Try some different optimization algorithms, such as averaged SGD, Adagrad, or Adam. How do your results differ?
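To build intuition for the learning-rate and momentum questions without a full training run, here is a pure-Python sketch of the momentum update that torch.optim.SGD documents for its defaults (dampening=0, nesterov=False): a velocity buffer accumulates gradients as v = momentum * v + grad, and the parameter moves by -lr * v. The one-parameter objective f(w) = w**2 and the function sgd_on_quadratic are hypothetical, chosen because the gradient 2w is easy to compute.

```python
def sgd_on_quadratic(w, lr, momentum, steps):
    """Minimize f(w) = w**2 with SGD plus momentum; return the final w."""
    v = 0.0                        # velocity (momentum buffer)
    for _ in range(steps):
        grad = 2 * w               # derivative of w**2
        v = momentum * v + grad    # accumulate past gradients
        w = w - lr * v             # step against the accumulated direction
    return w


for lr in (0.01, 0.1):
    for momentum in (0.0, 0.9):
        w_final = sgd_on_quadratic(5.0, lr=lr, momentum=momentum, steps=50)
        print(f"lr={lr:<5} momentum={momentum}: w after 50 steps = {w_final:.6f}")
```

On this toy problem the same momentum value speeds convergence at lr=0.01 but overshoots and slows it at lr=0.1, which is one reason the two settings are usually tuned together.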
```python
# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
```
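Swapping algorithms is a one-line change, because every torch.optim optimizer takes the parameters plus its own hyperparameters and exposes the same zero_grad()/step() interface. The sketch below compares SGD with momentum, ASGD (averaged SGD), Adagrad, and Adam on a hypothetical toy linear-regression problem rather than on GarmentClassifier. The learning rates are illustrative guesses, not tuned values, and the sketch uses its own variable names so it does not overwrite model or optimizer.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy regression task standing in for the garment classifier
toy_x = torch.randn(64, 3)
toy_y = toy_x @ torch.tensor([[1.0], [-2.0], [0.5]])
base_model = nn.Linear(3, 1)
mse = nn.MSELoss()

# Each factory builds an optimizer for the parameters it is given;
# the learning rates are illustrative guesses, not tuned values
factories = {
    "SGD+momentum": lambda params: torch.optim.SGD(params, lr=0.01, momentum=0.9),
    "ASGD": lambda params: torch.optim.ASGD(params, lr=0.01),
    "Adagrad": lambda params: torch.optim.Adagrad(params, lr=0.1),
    "Adam": lambda params: torch.optim.Adam(params, lr=0.01),
}

with torch.no_grad():
    initial_loss = mse(base_model(toy_x), toy_y).item()

final_losses = {}
for name, make_optimizer in factories.items():
    toy_model = copy.deepcopy(base_model)    # identical starting weights for each run
    toy_optimizer = make_optimizer(toy_model.parameters())
    for _ in range(200):
        toy_optimizer.zero_grad()
        toy_loss = mse(toy_model(toy_x), toy_y)
        toy_loss.backward()
        toy_optimizer.step()
    final_losses[name] = toy_loss.item()     # loss measured at the last step

print(f"initial loss: {initial_loss:.4f}")
for name, value in final_losses.items():
    print(f"{name:>13}: {value:.4f}")
```

Only the factory line differs between runs; the training loop itself is untouched, which is the same property that lets you swap the optimizer in this tutorial without editing train_one_epoch.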
The Training Loop
Below, we have a function that performs one training epoch. It enumerates data from the DataLoader, and on each pass of the loop does the following:
Gets a batch of training data from the DataLoader
Zeros the optimizer’s gradients
Performs inference, that is, gets predictions from the model for an input batch
Calculates the loss for that set of predictions vs. the labels on the dataset
Calculates the backward gradients over the learning weights
Tells the optimizer to perform one learning step, that is, adjust the model’s learning weights based on the observed gradients for this batch, according to the optimization algorithm we chose
It reports the average per-batch loss every 1000 batches.
Finally, it returns the average per-batch loss for the last 1000 batches, for comparison with a validation run.
```python
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss
```
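A quick check of the bookkeeping in train_one_epoch, as a pure-Python sketch (the helper global_step is hypothetical and simply mirrors the tb_x expression): with 60,000 training instances and batch_size=4, the loader yields 15,000 batches per epoch, so each epoch produces 15 loss reports, and tb_x keeps counting across epochs so that TensorBoard plots every epoch on one x-axis.

```python
import math

num_instances = 60000   # size of the Fashion-MNIST training split
batch_size = 4
# What len(training_loader) returns here (drop_last defaults to False)
batches_per_epoch = math.ceil(num_instances / batch_size)


def global_step(epoch_index, i):
    # Same expression as tb_x: number of batches seen so far across all epochs
    return epoch_index * batches_per_epoch + i + 1


# Batch indices at which train_one_epoch prints and logs (i % 1000 == 999)
report_indices = [i for i in range(batches_per_epoch) if i % 1000 == 999]

print(batches_per_epoch, len(report_indices))
print(global_step(0, 999), global_step(1, 999), global_step(4, 14999))
```

Because the last report of each epoch lands exactly on batch 15,000, the value train_one_epoch returns covers the final 1000 batches. With a batch count that is not a multiple of 1000, the trailing batches would never appear in a report.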
Per-Epoch Activity
There are a couple of things we’ll want to do once per epoch:
Perform validation by checking our relative loss on a set of data that was not used for training, and report this
Save a copy of the model
Here, we’ll do our reporting in TensorBoard. This will require going to the command line to start TensorBoard, and opening it in another browser tab.
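The validation pass relies on two separate switches: model.eval() changes how layers such as dropout behave, and torch.no_grad() stops autograd from recording operations. The sketch below shows each one on a hypothetical toy model with a dropout layer. GarmentClassifier has no dropout or batch normalization, so for it the train()/eval() switch has no numerical effect, though calling it is still good practice.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy model; the dropout layer is what makes train/eval differ
toy = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
toy_input = torch.ones(2, 8)

toy.train()                   # training mode: dropout zeroes some activations, scales the rest by 2
train_out = toy(toy_input)

toy.eval()                    # evaluation mode: dropout passes activations through unchanged
eval_out_1 = toy(toy_input)
eval_out_2 = toy(toy_input)
print("eval outputs identical:", torch.equal(eval_out_1, eval_out_2))

# Outside no_grad, the output is part of an autograd graph
print("requires_grad without no_grad:", toy(toy_input).requires_grad)

with torch.no_grad():         # no graph is recorded, which saves memory
    print("requires_grad inside no_grad:", toy(toy_input).requires_grad)
```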
```python
# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Put the model in training mode (this affects layers such as dropout and
    # batch normalization; gradient tracking is on by default) and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)

    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1
```
EPOCH 1:
batch 1000 loss: 2.2521325249671937
batch 2000 loss: 0.9854257740704343
batch 3000 loss: 0.6851327550658025
batch 4000 loss: 0.6135343031922821
batch 5000 loss: 0.5866948459818959
batch 6000 loss: 0.5511749164620414
batch 7000 loss: 0.5276467088288628
batch 8000 loss: 0.4859999405753333
batch 9000 loss: 0.47852722579229157
batch 10000 loss: 0.4600267232111655
batch 11000 loss: 0.45475304189126475
batch 12000 loss: 0.4281130772959441
batch 13000 loss: 0.42768952836247626
batch 14000 loss: 0.41485296625754564
batch 15000 loss: 0.3940568761495815
LOSS train 0.3940568761495815 valid 0.3956470787525177
EPOCH 2:
batch 1000 loss: 0.38715169560920915
batch 2000 loss: 0.3838005330477026
batch 3000 loss: 0.3631052174879587
batch 4000 loss: 0.37617891634977424
batch 5000 loss: 0.39103491323449996
batch 6000 loss: 0.36566069163393694
batch 7000 loss: 0.3667216809855308
batch 8000 loss: 0.35283140427662874
batch 9000 loss: 0.3527607740295061
batch 10000 loss: 0.34925047580266255
batch 11000 loss: 0.37366600868930255
batch 12000 loss: 0.34449519611988216
batch 13000 loss: 0.3459661583202251
batch 14000 loss: 0.33271378462779105
batch 15000 loss: 0.3584383065847287
LOSS train 0.3584383065847287 valid 0.3456021845340729
EPOCH 3:
batch 1000 loss: 0.3251506279369933
batch 2000 loss: 0.3095294443309249
batch 3000 loss: 0.34288687394879525
batch 4000 loss: 0.32768870879523454
batch 5000 loss: 0.32413441178604263
batch 6000 loss: 0.32304036018396437
batch 7000 loss: 0.3266114277824272
batch 8000 loss: 0.3210621049029942
batch 9000 loss: 0.3104973297845281
batch 10000 loss: 0.30057946301011546
batch 11000 loss: 0.31575276248247247
batch 12000 loss: 0.3107170149516714
batch 13000 loss: 0.30591063810196645
batch 14000 loss: 0.3338698188686976
batch 15000 loss: 0.3056864722693572
LOSS train 0.3056864722693572 valid 0.33726003766059875
EPOCH 4:
batch 1000 loss: 0.3011234230146365
batch 2000 loss: 0.2842651428124882
batch 3000 loss: 0.29927633204415727
batch 4000 loss: 0.30857997825050554
batch 5000 loss: 0.2945819073669918
batch 6000 loss: 0.29021309380311866
batch 7000 loss: 0.2970107369068719
batch 8000 loss: 0.306313495246055
batch 9000 loss: 0.30915809172308945
batch 10000 loss: 0.27482260235446304
batch 11000 loss: 0.2918123505726398
batch 12000 loss: 0.30168602007618756
batch 13000 loss: 0.28233931581465005
batch 14000 loss: 0.2691737981650222
batch 15000 loss: 0.2885731645249907
LOSS train 0.2885731645249907 valid 0.31354424357414246
EPOCH 5:
batch 1000 loss: 0.272097287611592
batch 2000 loss: 0.2746666179396416
batch 3000 loss: 0.2689908763434651
batch 4000 loss: 0.26544726403536467
batch 5000 loss: 0.27658389077588436
batch 6000 loss: 0.27193892943436365
batch 7000 loss: 0.26244054821689494
batch 8000 loss: 0.30141668145651784
batch 9000 loss: 0.28003143241122236
batch 10000 loss: 0.2728267575435306
batch 11000 loss: 0.264361325072352
batch 12000 loss: 0.2864540826071679
batch 13000 loss: 0.2513803336042256
batch 14000 loss: 0.28345379100979246
batch 15000 loss: 0.26518901752027796
LOSS train 0.26518901752027796 valid 0.3259730935096741
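Reading the log above against the checkpoint rule: a state dict is saved whenever an epoch’s validation loss beats the best seen so far. This pure-Python replay uses the printed validation losses (rounded to four decimals) and hypothetical variable names. In the fifth epoch the validation loss rose even though the training loss kept falling, so no file is written for it, and the file saved after the fourth epoch (epoch_number 3) remains the best one.

```python
# Validation losses printed at the end of each epoch above, rounded to 4 decimals
printed_vlosses = [0.3956, 0.3456, 0.3373, 0.3135, 0.3260]

replay_best = 1_000_000.
replay_saved = []
for epoch_index, vloss_value in enumerate(printed_vlosses):
    if vloss_value < replay_best:        # same comparison as the training loop
        replay_best = vloss_value
        replay_saved.append(epoch_index)  # a file model_<timestamp>_<epoch_index> would be written

print(replay_saved, replay_best)
```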
To load a saved version of the model:
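```python
saved_model = GarmentClassifier()
# PATH is a placeholder for a file saved during training (for example, the value of model_path above)
saved_model.load_state_dict(torch.load(PATH))
```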
Once you’ve loaded the model, it’s ready for whatever you need it for: more training, inference, or analysis.
Note that if your model has constructor parameters that affect model structure, you’ll need to provide them and configure the model identically to the state in which it was saved.
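A self-contained sketch of the round trip, including the constructor-parameter caveat. It assumes a hypothetical make_model(hidden) builder whose layer sizes depend on its argument, and uses an in-memory io.BytesIO buffer in place of a file path (torch.save and torch.load both accept file-like objects). Loading into a model built with a different hidden size fails with a size-mismatch RuntimeError.

```python
import io

import torch
import torch.nn as nn


def make_model(hidden):
    # Hypothetical builder: the constructor argument changes the layer shapes
    return nn.Sequential(nn.Linear(4, hidden), nn.ReLU(), nn.Linear(hidden, 2))


original = make_model(hidden=8)
buffer = io.BytesIO()                    # stands in for a path such as model_path
torch.save(original.state_dict(), buffer)

buffer.seek(0)
restored = make_model(hidden=8)          # same constructor arguments as at save time
restored.load_state_dict(torch.load(buffer))
restored.eval()                          # evaluation mode before inference

identical = all(torch.equal(a, b) for a, b in
                zip(original.state_dict().values(), restored.state_dict().values()))
print("weights identical:", identical)

buffer.seek(0)
try:
    make_model(hidden=16).load_state_dict(torch.load(buffer))
    mismatch_error = None
except RuntimeError as err:              # raised when tensor shapes in the state dict do not match
    mismatch_error = err
print("mismatched model rejected:", mismatch_error is not None)
```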
Other Resources
Docs on the data utilities, including Dataset and DataLoader, at pytorch.org
A note on the use of pinned memory for GPU training
Documentation on the datasets available in TorchVision, TorchText, and TorchAudio
Documentation on the loss functions available in PyTorch
Documentation on the torch.optim package, which includes optimizers and related tools, such as learning rate scheduling
A detailed tutorial on saving and loading models
The Tutorials section of pytorch.org contains tutorials on a broad variety of training tasks, including classification in different domains, generative adversarial networks, reinforcement learning, and more
Total running time of the script: (3 minutes 5.277 seconds)