Training with PyTorch

Created On: Nov 30, 2021 | Last Updated: May 31, 2023 | Last Verified: Nov 05, 2024

Follow along with the video below or on YouTube.

Introduction

In past videos, we’ve discussed and demonstrated:

  • Building models with the neural network layers and functions of the torch.nn module

  • The mechanics of automated gradient computation, which is central to gradient-based model training

  • Using TensorBoard to visualize training progress and other activities

In this video, we’ll be adding some new tools to your inventory:

  • We’ll get familiar with the dataset and dataloader abstractions, and how they ease the process of feeding data to your model during a training loop

  • We’ll discuss specific loss functions and when to use them

  • We’ll look at PyTorch optimizers, which implement algorithms to adjust model weights based on the outcome of a loss function

Finally, we’ll pull all of these together and see a full PyTorch training loop in action.

Dataset and DataLoader

The Dataset and DataLoader classes encapsulate the process of pulling your data from storage and exposing it to your training loop in batches.

The Dataset is responsible for accessing and processing single instances of data.

The DataLoader pulls instances of data from the Dataset (either automatically or with a sampler that you define), collects them in batches, and returns them for consumption by your training loop. The DataLoader works with all kinds of datasets, regardless of the type of data they contain.
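The plain-Python sketch below illustrates the split between the two roles. It is not PyTorch's implementation. `ToySquaresDataset` and `simple_loader` are hypothetical names. The loader only mimics what a DataLoader does by default:

- choose an index order (optionally shuffled),
- fetch instances one at a time,
- group them into batches,
- keep a final partial batch.

```python
import random


class ToySquaresDataset:
    """Hypothetical map-style dataset: like a torch.utils.data.Dataset, it
    implements __len__ and __getitem__ and returns ONE (input, label) pair."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        # Load and process a single instance on demand
        return idx, idx * idx


def simple_loader(dataset, batch_size, shuffle=False, seed=None):
    """Plain-Python stand-in for DataLoader behavior (not its real code)."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        # Fetch instances one at a time, then collate them into a batch
        batch = [dataset[i] for i in indices[start:start + batch_size]]
        inputs = [x for x, _ in batch]
        labels = [y for _, y in batch]
        yield inputs, labels


batches = list(simple_loader(ToySquaresDataset(10), batch_size=4,
                             shuffle=True, seed=0))
for inputs, labels in batches:
    print(inputs, labels)
```

A real map-style torch.utils.data.Dataset subclass exposes the same `__len__` and `__getitem__` methods. The real DataLoader goes further: it collates each batch into tensors and can load data in parallel worker processes.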

For this tutorial, we’ll be using the Fashion-MNIST dataset provided by TorchVision. We use torchvision.transforms.ToTensor() to convert each image into a tensor with pixel values in [0, 1]. We then use torchvision.transforms.Normalize() to shift and rescale those values to [-1, 1], so they are centered on zero. We also download both the training and validation data splits.
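You can check the arithmetic behind this transform by hand:

- ToTensor() scales 8-bit grayscale values from 0..255 into [0.0, 1.0].
- Normalize(mean, std) computes (input - mean) / std for each channel.

The plain-Python sketch below uses hypothetical helper names. It shows that with mean = std = 0.5 the result lands in [-1, 1]. It also shows that the `img / 2 + 0.5` step in the `matplotlib_imshow` helper reverses the normalization.

```python
MEAN, STD = 0.5, 0.5  # values passed to transforms.Normalize((0.5,), (0.5,))


def to_unit_range(pixel):
    # What ToTensor() does to an 8-bit grayscale value: 0..255 -> 0.0..1.0
    return pixel / 255.0


def normalize(x, mean=MEAN, std=STD):
    # Per-channel formula applied by transforms.Normalize
    return (x - mean) / std


def unnormalize(y):
    # Inverse used in matplotlib_imshow (img / 2 + 0.5); valid for mean = std = 0.5
    return y / 2 + 0.5


for pixel in (0, 128, 255):
    x = to_unit_range(pixel)
    print(pixel, round(x, 4), round(normalize(x), 4))
```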

import torch
import torchvision
import torchvision.transforms as transforms

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])

# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)

# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)

# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
        'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
Training set has 60000 instances
Validation set has 10000 instances

As always, let’s visualize the data as a sanity check:

import matplotlib.pyplot as plt
import numpy as np

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(training_loader)
images, labels = next(dataiter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print('  '.join(classes[labels[j]] for j in range(4)))
Trouser  Pullover  Ankle Boot  Trouser

The Model

The model we’ll use in this example is a variant of LeNet-5; it should be familiar if you’ve watched the previous videos in this series.

import torch.nn as nn
import torch.nn.functional as F

# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = GarmentClassifier()
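The `16 * 4 * 4` input size of `fc1` comes from the spatial size of a 28x28 Fashion-MNIST image after two convolution-plus-pooling stages. This sketch applies the output-size formula from the nn.Conv2d and nn.MaxPool2d documentation, assuming no padding and dilation 1. `conv_out` is a hypothetical helper, not a PyTorch function.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length along one spatial dimension for Conv2d / MaxPool2d
    with dilation 1: floor((size + 2*padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                           # Fashion-MNIST images are 28x28
size = conv_out(size, 5)            # conv1, 5x5 kernel    -> 24
size = conv_out(size, 2, stride=2)  # pool, 2x2, stride 2  -> 12
size = conv_out(size, 5)            # conv2, 5x5 kernel    -> 8
size = conv_out(size, 2, stride=2)  # pool, 2x2, stride 2  -> 4
flat_features = 16 * size * size    # conv2 has 16 output channels
print(size, flat_features)          # 4 256
```

If you change the input resolution or a kernel size, the flattened size changes too. You would then need to update both the `x.view(-1, 16 * 4 * 4)` call and `fc1` to match.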

Loss Function

For this example, we’ll be using a cross-entropy loss. For demonstration purposes, we’ll create batches of dummy output and label values, run them through the loss function, and examine the result.

loss_fn = torch.nn.CrossEntropyLoss()

# NB: Loss functions expect data in batches, so we're creating batches of 4
# Raw, unnormalized scores (logits) for each of the 10 classes for a given input;
# CrossEntropyLoss applies log-softmax to them internally
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])

print(dummy_outputs)
print(dummy_labels)

loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))
tensor([[0.6756, 0.7980, 0.8317, 0.7603, 0.3541, 0.6252, 0.6101, 0.4059, 0.7276,
         0.9457],
        [0.9579, 0.1918, 0.3763, 0.3049, 0.1708, 0.6332, 0.3032, 0.8479, 0.9219,
         0.8660],
        [0.9111, 0.8363, 0.1696, 0.0574, 0.0835, 0.4695, 0.7398, 0.9403, 0.0584,
         0.9106],
        [0.4170, 0.5662, 0.7512, 0.1372, 0.3623, 0.6905, 0.2663, 0.7750, 0.8752,
         0.3158]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.341409683227539
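nn.CrossEntropyLoss treats its input as raw scores (logits). For each sample it computes -log(softmax(scores)[label]). With the default reduction='mean', it then averages over the batch. The plain-Python sketch below reproduces that math; the function name is hypothetical.

It also explains the size of the result above. When all 10 scores are roughly equal, as with torch.rand() outputs, the loss sits near ln(10), about 2.30.

```python
import math


def cross_entropy(scores_batch, labels):
    """Mean cross-entropy over a batch of raw scores, mirroring the math of
    nn.CrossEntropyLoss with its default reduction='mean'."""
    total = 0.0
    for scores, label in zip(scores_batch, labels):
        # log-sum-exp, shifted by the max score for numerical stability
        m = max(scores)
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        # -log(softmax(scores)[label]) == logsumexp(scores) - scores[label]
        total += log_sum_exp - scores[label]
    return total / len(labels)


uniform = [[0.0] * 10 for _ in range(4)]
print(cross_entropy(uniform, [1, 5, 3, 7]))  # ln(10), about 2.3026

confident = [[10.0 if j == 1 else 0.0 for j in range(10)]]
print(cross_entropy(confident, [1]))  # small: high score on the right class
print(cross_entropy(confident, [0]))  # large: high score on a wrong class
```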

Optimizer

For this example, we’ll be using simple stochastic gradient descent with momentum.

It can be instructive to try some variations on this optimization scheme:

  • Learning rate determines the size of the steps the optimizer takes. What does a different learning rate do to your training results, in terms of accuracy and convergence time?

  • Momentum carries a fraction of previous updates into the current step. As a result, the optimizer keeps moving in directions where the gradient stays consistent across multiple steps. What does changing this value do to your results?

  • Try some different optimization algorithms, such as averaged SGD, Adagrad, or Adam. How do your results differ?

# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
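To see what the `lr` and `momentum` arguments do, here is a plain-Python sketch of the update rule documented for torch.optim.SGD with momentum. It assumes dampening=0, nesterov=False and no weight decay, which are the defaults. It applies the rule to the toy objective f(x) = x**2. The function and objective are illustrative assumptions, not part of this tutorial's model.

```python
def sgd_momentum(grad_fn, x0, lr, momentum, steps):
    """Update rule of torch.optim.SGD as documented for momentum, assuming
    dampening=0, nesterov=False and weight_decay=0 (the defaults):
        buf = momentum * buf + grad   (buf starts as the first gradient)
        x   = x - lr * buf
    """
    x, buf = x0, None
    for _ in range(steps):
        g = grad_fn(x)
        buf = g if buf is None else momentum * buf + g
        x = x - lr * buf
    return x


def grad(x):
    # Gradient of the toy objective f(x) = x**2, minimized at x = 0
    return 2 * x


for lr, mu in [(0.001, 0.9), (0.001, 0.0), (0.01, 0.9)]:
    final_x = sgd_momentum(grad, x0=5.0, lr=lr, momentum=mu, steps=100)
    print(f"lr={lr} momentum={mu}: x after 100 steps = {final_x:.4f}")
```

With the same `lr` of 0.001, the momentum run ends much closer to the minimum than plain SGD (momentum=0.0). The accumulated buffer acts like a larger effective step along a consistent gradient direction. Trying other learning rate and momentum pairs in this loop is a quick way to build intuition for the questions above.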

The Training Loop

Below, we have a function that performs one training epoch. It enumerates data from the DataLoader, and on each pass of the loop does the following:

  • Gets a batch of training data from the DataLoader

  • Zeros the optimizer’s gradients

  • Performs an inference, that is, gets predictions from the model for an input batch

  • Calculates the loss for that set of predictions vs. the labels on the dataset

  • Calculates the gradients of the loss with respect to the learning weights

  • Tells the optimizer to perform one learning step, that is, adjust the model’s learning weights based on the observed gradients for this batch, according to the optimization algorithm we chose

  • Reports the average per-batch loss every 1000 batches

  • Finally, returns the average per-batch loss for the last 1000 batches, for comparison with a validation run

def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss
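The `tb_x` value passed to `add_scalar` is a global batch counter, so the curves from successive epochs line up on one x-axis. A short plain-Python sketch can check the arithmetic for this tutorial's settings. The constants mirror the values above, and `report_steps` is a hypothetical helper.

```python
import math

TRAIN_INSTANCES = 60000  # "Training set has 60000 instances"
BATCH_SIZE = 4           # batch_size of training_loader
REPORT_EVERY = 1000      # the loop reports when i % 1000 == 999

# With DataLoader's default drop_last=False a final partial batch is kept,
# so len(training_loader) is the batch count rounded up.
batches_per_epoch = math.ceil(TRAIN_INSTANCES / BATCH_SIZE)


def report_steps(epoch_index):
    """Hypothetical helper: the tb_x values logged during one epoch."""
    return [epoch_index * batches_per_epoch + i + 1
            for i in range(batches_per_epoch)
            if i % REPORT_EVERY == REPORT_EVERY - 1]


print(batches_per_epoch)    # 15000
print(report_steps(0)[:3])  # [1000, 2000, 3000]
print(report_steps(1)[0])   # 16000
```

This is also why each epoch in the training output prints 15 reports, ending at batch 15000.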

Per-Epoch Activity

There are a couple of things we’ll want to do once per epoch:

  • Perform validation by checking our relative loss on a set of data that was not used for training, and report this

  • Save a copy of the model

Here, we’ll do our reporting in TensorBoard. This will require going to the command line to start TensorBoard, and opening it in another browser tab.

# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Put the model in training mode (this affects layers such as dropout
    # and batch normalization), then do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)


    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss.item()

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1
EPOCH 1:
  batch 1000 loss: 1.7536216277182102
  batch 2000 loss: 0.8190080421678722
  batch 3000 loss: 0.700689250588417
  batch 4000 loss: 0.6367347843456082
  batch 5000 loss: 0.5702252120454796
  batch 6000 loss: 0.5832496561647859
  batch 7000 loss: 0.5451787600663957
  batch 8000 loss: 0.5318371822668705
  batch 9000 loss: 0.4998274987768382
  batch 10000 loss: 0.4861003523846157
  batch 11000 loss: 0.462952514749486
  batch 12000 loss: 0.45482242486847096
  batch 13000 loss: 0.4385815391012002
  batch 14000 loss: 0.44237795609093156
  batch 15000 loss: 0.42834632108511866
LOSS train 0.42834632108511866 valid 0.4279777407646179
EPOCH 2:
  batch 1000 loss: 0.3912717408046883
  batch 2000 loss: 0.3908861642718548
  batch 3000 loss: 0.40499804124957883
  batch 4000 loss: 0.40315820360236104
  batch 5000 loss: 0.4001170339985983
  batch 6000 loss: 0.3879377021771506
  batch 7000 loss: 0.38474963369296167
  batch 8000 loss: 0.3890397538637335
  batch 9000 loss: 0.3820151900724741
  batch 10000 loss: 0.38206415582021874
  batch 11000 loss: 0.3796248573293851
  batch 12000 loss: 0.3752163356099627
  batch 13000 loss: 0.3764746580963256
  batch 14000 loss: 0.34548993636778325
  batch 15000 loss: 0.3631109071790997
LOSS train 0.3631109071790997 valid 0.36702868342399597
EPOCH 3:
  batch 1000 loss: 0.34729626949541853
  batch 2000 loss: 0.333685194198406
  batch 3000 loss: 0.3313898945654946
  batch 4000 loss: 0.32808600346938693
  batch 5000 loss: 0.3326914912162465
  batch 6000 loss: 0.32520363175662353
  batch 7000 loss: 0.32000095026698544
  batch 8000 loss: 0.35664220971794564
  batch 9000 loss: 0.3497494743829302
  batch 10000 loss: 0.3228544230395928
  batch 11000 loss: 0.33513641009817363
  batch 12000 loss: 0.33927128310580157
  batch 13000 loss: 0.3232689249026153
  batch 14000 loss: 0.32395145892064464
  batch 15000 loss: 0.3392171728776739
LOSS train 0.3392171728776739 valid 0.3459877073764801
EPOCH 4:
  batch 1000 loss: 0.30242849304602715
  batch 2000 loss: 0.30626966073081713
  batch 3000 loss: 0.3123660156664555
  batch 4000 loss: 0.320031703075947
  batch 5000 loss: 0.31932821953319945
  batch 6000 loss: 0.31352192115694927
  batch 7000 loss: 0.29349351018325703
  batch 8000 loss: 0.2909870434643235
  batch 9000 loss: 0.2875903835784557
  batch 10000 loss: 0.3160401984210294
  batch 11000 loss: 0.28181703339913655
  batch 12000 loss: 0.311081383400473
  batch 13000 loss: 0.29503489564753543
  batch 14000 loss: 0.31576050062226113
  batch 15000 loss: 0.29751126185648175
LOSS train 0.29751126185648175 valid 0.32522276043891907
EPOCH 5:
  batch 1000 loss: 0.2751962016681282
  batch 2000 loss: 0.2999601041037386
  batch 3000 loss: 0.266331037209884
  batch 4000 loss: 0.29807080660707286
  batch 5000 loss: 0.29425640619049953
  batch 6000 loss: 0.3038723004531712
  batch 7000 loss: 0.2771170318634138
  batch 8000 loss: 0.28474529803660514
  batch 9000 loss: 0.29053163012857475
  batch 10000 loss: 0.2796333142037802
  batch 11000 loss: 0.2834179482117579
  batch 12000 loss: 0.2845251598034156
  batch 13000 loss: 0.28230239558218456
  batch 14000 loss: 0.3021180066906745
  batch 15000 loss: 0.2781171326651893
LOSS train 0.2781171326651893 valid 0.32490113377571106
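The loop writes a checkpoint only when the validation loss beats the best value seen so far. You can replay that decision on the validation losses printed above. In this run the validation loss improved every epoch. Because `model_path` includes the epoch number, a separate checkpoint file was written after each of the five epochs, and the last one is the best. Below is a plain-Python sketch of the same rule; `checkpoint_epochs` is a hypothetical helper.

```python
# Validation losses printed at the end of each epoch in the run above
valid_losses = [0.4279777407646179, 0.36702868342399597, 0.3459877073764801,
                0.32522276043891907, 0.32490113377571106]


def checkpoint_epochs(losses, best=1_000_000.):
    """Hypothetical helper replaying the 'if avg_vloss < best_vloss' rule:
    returns the epoch indices at which a checkpoint would be saved, and the
    best loss seen."""
    saved = []
    for epoch, loss in enumerate(losses):
        if loss < best:
            best = loss
            saved.append(epoch)
    return saved, best


saved, best = checkpoint_epochs(valid_losses)
print(saved)  # [0, 1, 2, 3, 4]
print(best)

# A hypothetical run where validation loss rises once: epoch 2 is skipped
print(checkpoint_epochs([0.5, 0.4, 0.45, 0.39])[0])  # [0, 1, 3]
```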

To load a saved version of the model:

# PATH is the file name of a checkpoint saved above, e.g. the last model_path
PATH = model_path
saved_model = GarmentClassifier()
saved_model.load_state_dict(torch.load(PATH))

Once you’ve loaded the model, it’s ready for whatever you need it for: more training, inference, or analysis.

Note that if your model has constructor parameters that affect model structure, you’ll need to provide them and configure the model identically to the state in which it was saved.

Total running time of the script: (3 minutes 2.744 seconds)
