Training with PyTorch
Created On: Nov 30, 2021 | Last Updated: May 06, 2026 | Last Verified: Nov 05, 2024
Follow along with the video below or on YouTube.
Introduction
In past videos, we’ve discussed and demonstrated:
Building models with the neural network layers and functions of the torch.nn module
The mechanics of automated gradient computation, which is central to gradient-based model training
Using TensorBoard to visualize training progress and other activities
In this video, we’ll be adding some new tools to your inventory:
We’ll get familiar with the dataset and dataloader abstractions, and how they ease the process of feeding data to your model during a training loop
We’ll discuss specific loss functions and when to use them
We’ll look at PyTorch optimizers, which implement algorithms to adjust model weights based on the outcome of a loss function
Finally, we’ll pull all of these together and see a full PyTorch training loop in action.
Dataset and DataLoader
The Dataset and DataLoader classes encapsulate the process of pulling your data from storage and exposing it to your training loop in batches.
The Dataset is responsible for accessing and processing single instances of data.
The DataLoader pulls instances of data from the Dataset (either automatically or with a sampler that you define), collects them in batches, and returns them for consumption by your training loop. The DataLoader works with all kinds of datasets, regardless of the type of data they contain.
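Here is a minimal sketch of a custom Dataset paired with a DataLoader, to make this division of labor concrete. SquaresDataset is a hypothetical toy class written for illustration and is not part of PyTorch. It implements __len__ and __getitem__ so that it returns one (input, label) pair per call. The DataLoader then groups those pairs into batches using its default collation.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class SquaresDataset(Dataset):
    """Hypothetical toy dataset: item i is (tensor([i]), i**2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of instances the DataLoader can draw from
        return self.n

    def __getitem__(self, idx):
        # The Dataset handles exactly one instance per call
        x = torch.tensor([float(idx)])
        y = idx * idx
        return x, y

dataset = SquaresDataset(10)
# The DataLoader stacks individual instances into batches of 4
loader = DataLoader(dataset, batch_size=4, shuffle=False)

batches = list(loader)
for xb, yb in batches:
    print(xb.shape, yb.tolist())
```

With 10 items and batch_size=4, the last batch holds only the 2 leftover items. Passing drop_last=True to the DataLoader would discard that batch instead.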
For this tutorial, we’ll be using the Fashion-MNIST dataset provided by TorchVision. We use torchvision.transforms.v2 to convert each image to a float32 tensor scaled to [0, 1]. We then apply Normalize() with mean 0.5 and standard deviation 0.5, which zero-centers the pixel values into the range [-1, 1]. We also download both the training and validation data splits.
import torch
import torchvision
from torchvision.transforms import v2
# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
transform = v2.Compose([
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize((0.5,), (0.5,))
])
# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)
# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)
# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')
# Report split sizes
print(f'Training set has {len(training_set)} instances')
print(f'Validation set has {len(validation_set)} instances')
Training set has 60000 instances
Validation set has 10000 instances
As always, let’s visualize the data as a sanity check:
import matplotlib.pyplot as plt
import numpy as np
# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
dataiter = iter(training_loader)
images, labels = next(dataiter)
# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print(' '.join(classes[labels[j]] for j in range(4)))

Dress Coat Dress Ankle Boot
The Model
The model we’ll use in this example is a variant of LeNet-5; it should be familiar if you’ve watched the previous videos in this series.
import torch.nn as nn
import torch.nn.functional as F
# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        # Input: a 1 x 28 x 28 grayscale image
        self.conv1 = nn.Conv2d(1, 6, 5)        # -> 6 x 24 x 24
        self.pool = nn.MaxPool2d(2, 2)         # halves height and width
        self.conv2 = nn.Conv2d(6, 16, 5)       # 6 x 12 x 12 -> 16 x 8 x 8
        self.fc1 = nn.Linear(16 * 4 * 4, 120)  # 16 x 4 x 4 after the second pool
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)           # one output per class
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))   # -> 6 x 12 x 12
        x = self.pool(F.relu(self.conv2(x)))   # -> 16 x 4 x 4
        x = x.view(-1, 16 * 4 * 4)             # flatten to (batch, 256)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)                        # raw class scores (logits)
        return x
model = GarmentClassifier()
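This quick sanity check confirms the 16 * 4 * 4 input size of fc1. The sketch pushes a random batch through the same convolution and pooling layers used in GarmentClassifier and prints the intermediate shapes. The random 4 x 1 x 28 x 28 batch stands in for Fashion-MNIST images.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Same layer configuration as GarmentClassifier's convolutional stack
conv1 = nn.Conv2d(1, 6, 5)
pool = nn.MaxPool2d(2, 2)
conv2 = nn.Conv2d(6, 16, 5)

x = torch.randn(4, 1, 28, 28)   # a batch of 4 single-channel 28x28 images
a = pool(F.relu(conv1(x)))      # 28 - 5 + 1 = 24, pooled to 12
b = pool(F.relu(conv2(a)))      # 12 - 5 + 1 = 8, pooled to 4
print(a.shape, b.shape, b.flatten(1).shape)
```

This should print torch.Size([4, 6, 12, 12]) torch.Size([4, 16, 4, 4]) torch.Size([4, 256]). The last value, 256, equals 16 * 4 * 4.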
Loss Function
For this example, we’ll be using a cross-entropy loss. For demonstration purposes, we’ll create batches of dummy output and label values, run them through the loss function, and examine the result.
loss_fn = torch.nn.CrossEntropyLoss()
# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's raw, unnormalized scores (logits) for each of the 10 classes for a given input
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])
print(dummy_outputs)
print(dummy_labels)
loss = loss_fn(dummy_outputs, dummy_labels)
print(f'Total loss for this batch: {loss.item()}')
tensor([[0.7478, 0.0464, 0.7828, 0.5873, 0.3767, 0.0260, 0.4662, 0.5116, 0.8879,
0.9898],
[0.9356, 0.9325, 0.5493, 0.5131, 0.6096, 0.9769, 0.7108, 0.2812, 0.0449,
0.4855],
[0.6789, 0.0740, 0.0469, 0.7979, 0.0361, 0.0048, 0.5668, 0.8868, 0.9051,
0.2568],
[0.0541, 0.0797, 0.9294, 0.0690, 0.8948, 0.5610, 0.5696, 0.4554, 0.5694,
0.1948]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.285301923751831
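CrossEntropyLoss applies a log-softmax to the raw scores. It then takes the negative log-probability of the correct class and averages that over the batch. The sketch below recomputes the loss by hand with torch.logsumexp and compares it with the built-in version. Its random inputs are stand-ins, like the dummy values above.

```python
import torch

logits = torch.rand(4, 10)             # dummy raw scores, as above
labels = torch.tensor([1, 5, 3, 7])

builtin = torch.nn.CrossEntropyLoss()(logits, labels)

# log_softmax written out: subtract log(sum(exp(scores))) from each score
log_probs = logits - torch.logsumexp(logits, dim=1, keepdim=True)
# Pick the log-probability of the correct class in each row, negate, and average
manual = -log_probs[torch.arange(len(labels)), labels].mean()

print(builtin.item(), manual.item())
```

Because torch.rand scores all fall in [0, 1), the softmax is close to uniform over 10 classes. The loss therefore lands near ln(10) ≈ 2.303, which is consistent with the value printed above.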
Optimizer
For this example, we’ll be using simple stochastic gradient descent with momentum.
It can be instructive to try some variations on this optimization scheme:
Learning rate determines the size of the steps the optimizer takes. What does a different learning rate do to your training results, in terms of accuracy and convergence time?
Momentum accumulates a decaying sum of past gradients, nudging the optimizer to keep moving in directions the gradient has pointed consistently over multiple steps. What does changing this value do to your results?
Try some different optimization algorithms, such as averaged SGD, Adagrad, or Adam. How do your results differ?
# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
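SGD with momentum keeps a running buffer of past gradients. The sketch below assumes the default dampening=0, nesterov=False, and no weight decay. Under those settings the update is buf = momentum * buf + grad, with buf initialized to the first gradient, followed by param = param - lr * buf. The code implements that rule by hand on a hypothetical two-parameter quadratic loss and checks the result against torch.optim.SGD.

```python
import torch

def loss_of(w):
    # Hypothetical quadratic bowl, minimized at w = [1, -2]
    target = torch.tensor([1.0, -2.0])
    return ((w - target) ** 2).sum()

lr, mu = 0.1, 0.9

# Reference: PyTorch's built-in SGD with momentum
w_ref = torch.zeros(2, requires_grad=True)
opt = torch.optim.SGD([w_ref], lr=lr, momentum=mu)

# Hand-rolled version of the same update rule
w_manual = torch.zeros(2)
buf = None

for step in range(5):
    opt.zero_grad()
    loss_of(w_ref).backward()
    opt.step()

    # Gradient of the loss at the current hand-rolled weights
    w = w_manual.clone().requires_grad_(True)
    loss_of(w).backward()
    g = w.grad
    # First step: the buffer is just the gradient; later: decay it and add the new gradient
    buf = g.clone() if buf is None else mu * buf + g
    w_manual = w_manual - lr * buf

print(w_ref.detach(), w_manual)
```

To try one of the alternative algorithms mentioned above, swap in, for example, torch.optim.Adam(model.parameters(), lr=0.001). The rest of the training loop stays the same.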
The Training Loop
Below, we have a function that performs one training epoch. It enumerates data from the DataLoader, and on each pass of the loop does the following:
Gets a batch of training data from the DataLoader
Zeros the optimizer’s gradients
Performs a forward pass: gets predictions from the model for an input batch
Calculates the loss for that set of predictions vs. the labels on the dataset
Calculates the backward gradients over the learning weights
Tells the optimizer to perform one learning step: adjust the model’s learning weights based on the observed gradients for this batch, according to the optimization algorithm we chose
Every 1000 batches, it reports the average per-batch loss over those batches, both printed and logged to TensorBoard.
Finally, it returns the average per-batch loss for the last 1000 batches, for comparison with a validation run.
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.
    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data
        # Zero your gradients for every batch!
        optimizer.zero_grad()
        # Make predictions for this batch
        outputs = model(inputs)
        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()
        # Adjust learning weights
        optimizer.step()
        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print(f' batch {i + 1} loss: {last_loss}')
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
    return last_loss
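The loop calls optimizer.zero_grad() on every batch. This is needed because backward() adds new gradients to whatever is already stored in each parameter's .grad attribute instead of overwriting it. This small sketch uses a single hypothetical weight w to show the accumulation and the reset.

```python
import torch

w = torch.tensor([2.0], requires_grad=True)

# Two backward passes without zeroing: the gradients add up
(3 * w).sum().backward()
first = w.grad.clone()          # d(3w)/dw = 3
(3 * w).sum().backward()
accumulated = w.grad.clone()    # 3 + 3 = 6

# Reset before the next backward pass, as optimizer.zero_grad() does
# (it sets gradients to None by default in recent PyTorch versions)
w.grad = None
(3 * w).sum().backward()
fresh = w.grad.clone()          # 3 again

print(first.item(), accumulated.item(), fresh.item())  # 3.0 6.0 3.0
```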
Per-Epoch Activity
There are a couple of things we’ll want to do once per epoch:
Perform validation by checking our relative loss on a set of data that was not used for training, and report this
Save a copy of the model
Here, we’ll do our reporting in TensorBoard. This requires starting TensorBoard from the command line (for example, tensorboard --logdir=runs) and opening it in another browser tab.
# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter(f'runs/fashion_trainer_{timestamp}')
epoch_number = 0
EPOCHS = 5
best_vloss = 1_000_000.
for epoch in range(EPOCHS):
    print(f'EPOCH {epoch_number + 1}:')
    # Put the model in training mode (this affects layers such as dropout
    # and batch normalization), then do a pass over the training data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)
    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()
    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss.item()
    avg_vloss = running_vloss / (i + 1)
    print(f'LOSS train {avg_loss} valid {avg_vloss}')
    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()
    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = f'model_{timestamp}_{epoch_number}'
        torch.save(model.state_dict(), model_path)
    epoch_number += 1
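The comments in the loop refer to two separate switches. model.eval() changes how layers such as dropout behave, while torch.no_grad() stops autograd from recording operations. GarmentClassifier has no dropout layers, so the sketch below uses a standalone nn.Dropout, an assumption made purely for illustration, to make both effects visible.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
x = torch.ones(1000)

dropout.train()              # what model.train(True) sets on each submodule
train_out = dropout(x)       # about half the entries zeroed, survivors scaled by 1/(1-p) = 2

dropout.eval()               # what model.eval() sets on each submodule
eval_out = dropout(x)        # dropout is the identity in eval mode

w = torch.ones(3, requires_grad=True)
with torch.no_grad():
    y = w * 2                # no autograd graph is recorded here
z = w * 2                    # outside no_grad, the result tracks gradients

print(sorted(set(train_out.tolist())), torch.equal(eval_out, x))
print(y.requires_grad, z.requires_grad)
```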
EPOCH 1:
batch 1000 loss: 1.9201134222000837
batch 2000 loss: 0.8936810662262141
batch 3000 loss: 0.7354398246854543
batch 4000 loss: 0.6658116950076074
batch 5000 loss: 0.5987123535331339
batch 6000 loss: 0.5512049491702347
batch 7000 loss: 0.5042953580608591
batch 8000 loss: 0.5141121485423064
batch 9000 loss: 0.4794655260576401
batch 10000 loss: 0.4888356875733589
batch 11000 loss: 0.4472574892928824
batch 12000 loss: 0.4198870799881988
batch 13000 loss: 0.438963660903275
batch 14000 loss: 0.4083680252630147
batch 15000 loss: 0.40982389317639173
LOSS train 0.40982389317639173 valid 0.45463138818740845
EPOCH 2:
batch 1000 loss: 0.39642382254370023
batch 2000 loss: 0.3667202236213489
batch 3000 loss: 0.37321549149573546
batch 4000 loss: 0.38423057076369876
batch 5000 loss: 0.36775684920628554
batch 6000 loss: 0.3509713848091196
batch 7000 loss: 0.3646314028152265
batch 8000 loss: 0.36193755505944136
batch 9000 loss: 0.3793991212390247
batch 10000 loss: 0.3318117807184099
batch 11000 loss: 0.35165498141641727
batch 12000 loss: 0.3540895422938047
batch 13000 loss: 0.3533427465384593
batch 14000 loss: 0.35021086086906145
batch 15000 loss: 0.34439150294536375
LOSS train 0.34439150294536375 valid 0.3544321060180664
EPOCH 3:
batch 1000 loss: 0.3468644726048951
batch 2000 loss: 0.3283032576391997
batch 3000 loss: 0.32121480108232936
batch 4000 loss: 0.31656404185999415
batch 5000 loss: 0.3039547623791077
batch 6000 loss: 0.3008622578030918
batch 7000 loss: 0.3247767886732763
batch 8000 loss: 0.3157894377858902
batch 9000 loss: 0.2975363698947622
batch 10000 loss: 0.323901835494431
batch 11000 loss: 0.3261104478448105
batch 12000 loss: 0.28065282518885215
batch 13000 loss: 0.3095580764692604
batch 14000 loss: 0.32428999791605745
batch 15000 loss: 0.27786616887440324
LOSS train 0.27786616887440324 valid 0.33936741948127747
EPOCH 4:
batch 1000 loss: 0.2918487318518019
batch 2000 loss: 0.28185603537365206
batch 3000 loss: 0.27675634648789127
batch 4000 loss: 0.29913072428875603
batch 5000 loss: 0.29988044832734884
batch 6000 loss: 0.28877245042166033
batch 7000 loss: 0.28200169986383117
batch 8000 loss: 0.276411639795464
batch 9000 loss: 0.2973476366762334
batch 10000 loss: 0.29068108931634923
batch 11000 loss: 0.2935327341644115
batch 12000 loss: 0.27733401207738645
batch 13000 loss: 0.288186743367849
batch 14000 loss: 0.29263715256195066
batch 15000 loss: 0.2873050424955509
LOSS train 0.2873050424955509 valid 0.3149283826351166
EPOCH 5:
batch 1000 loss: 0.266026705986711
batch 2000 loss: 0.27676754854721364
batch 3000 loss: 0.2730195448407067
batch 4000 loss: 0.2745680915591538
batch 5000 loss: 0.2563295667306675
batch 6000 loss: 0.2687258635236958
batch 7000 loss: 0.2604165060398591
batch 8000 loss: 0.2469878823280683
batch 9000 loss: 0.28684585436138876
batch 10000 loss: 0.26537751772847334
batch 11000 loss: 0.26665223997646537
batch 12000 loss: 0.2774165414618251
batch 13000 loss: 0.2757912719691167
batch 14000 loss: 0.26843816110451735
batch 15000 loss: 0.2665404275908149
LOSS train 0.2665404275908149 valid 0.2919544577598572
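To load a saved version of the model:
# PATH is a file written by torch.save() in the loop above (here, the best checkpoint)
PATH = model_path
saved_model = GarmentClassifier()
saved_model.load_state_dict(torch.load(PATH))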
Once you’ve loaded the model, it’s ready for whatever you need it for: more training, inference, or analysis.
Note that if your model has constructor parameters that affect model structure, you’ll need to provide them and configure the model identically to the state in which it was saved.
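The constructor caveat shows up clearly in a round trip through torch.save() and load_state_dict(). TinyNet is a hypothetical model whose layer sizes depend on a hidden argument. The sketch saves its state_dict to an in-memory io.BytesIO buffer, used here instead of a file. It then reloads the state with matching and mismatched constructor arguments, and the mismatch raises a RuntimeError.

```python
import io
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    # Hypothetical model whose structure depends on a constructor argument
    def __init__(self, hidden=16):
        super().__init__()
        self.fc1 = nn.Linear(8, hidden)
        self.fc2 = nn.Linear(hidden, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

original = TinyNet(hidden=32)
buffer = io.BytesIO()
torch.save(original.state_dict(), buffer)

# Matching constructor argument: loading succeeds and outputs agree
buffer.seek(0)
restored = TinyNet(hidden=32)
restored.load_state_dict(torch.load(buffer))
x = torch.randn(4, 8)
with torch.no_grad():
    same = torch.allclose(original(x), restored(x))

# Mismatched constructor argument: tensor shapes differ, so loading fails
buffer.seek(0)
mismatch_failed = False
try:
    TinyNet(hidden=16).load_state_dict(torch.load(buffer))
except RuntimeError:
    mismatch_failed = True

print(same, mismatch_failed)  # True True
```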
Other Resources
Docs on the data utilities, including Dataset and DataLoader, at pytorch.org
A note on the use of pinned memory for GPU training
Documentation on the datasets available in TorchVision, TorchText, and TorchAudio
Documentation on the loss functions available in PyTorch
Documentation on the torch.optim package, which includes optimizers and related tools, such as learning rate scheduling
A detailed tutorial on saving and loading models
The Tutorials section of pytorch.org contains tutorials on a broad variety of training tasks, including classification in different domains, generative adversarial networks, reinforcement learning, and more
Total running time of the script: (3 minutes 24.110 seconds)