Training with PyTorch

Created On: Nov 30, 2021 | Last Updated: May 31, 2023 | Last Verified: Nov 05, 2024

Follow along with the video below or on YouTube.

Introduction

In past videos, we’ve discussed and demonstrated:

  • Building models with the neural network layers and functions of the torch.nn module

  • The mechanics of automated gradient computation, which is central to gradient-based model training

  • Using TensorBoard to visualize training progress and other activities

In this video, we’ll be adding some new tools to your inventory:

  • We’ll get familiar with the dataset and dataloader abstractions, and how they ease the process of feeding data to your model during a training loop

  • We’ll discuss specific loss functions and when to use them

  • We’ll look at PyTorch optimizers, which implement algorithms to adjust model weights based on the outcome of a loss function

Finally, we’ll pull all of these together and see a full PyTorch training loop in action.

Dataset and DataLoader

The Dataset and DataLoader classes encapsulate the process of pulling your data from storage and exposing it to your training loop in batches.

The Dataset is responsible for accessing and processing single instances of data.

The DataLoader pulls instances of data from the Dataset (either automatically or with a sampler that you define), collects them in batches, and returns them for consumption by your training loop. The DataLoader works with all kinds of datasets, regardless of the type of data they contain.
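To make this division of labor concrete, here is a minimal sketch using a hypothetical SquaresDataset (it is not part of the tutorial). It implements the two methods a map-style Dataset needs, __len__ and __getitem__, and returns one instance at a time; the DataLoader then handles batching and stacking those instances into tensors.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: instance i is the pair (i, i**2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # The DataLoader's default sampler uses this to know which indices exist
        return self.n

    def __getitem__(self, idx):
        # Access and process ONE instance; batching is the DataLoader's job
        x = torch.tensor([float(idx)])
        y = torch.tensor(float(idx) ** 2)
        return x, y


dataset = SquaresDataset(10)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

# The default collate function stacks per-instance tensors along a new batch dimension
batches = list(loader)
for xb, yb in batches:
    print(xb.shape, yb.shape)
```

With 10 instances and batch_size=4, the last batch holds only the two leftover instances; passing drop_last=True to the DataLoader would discard it instead.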

For this tutorial, we’ll be using the Fashion-MNIST dataset provided by TorchVision. We use torchvision.transforms.ToTensor() to convert each image to a tensor with pixel values in [0, 1], then torchvision.transforms.Normalize() to shift and rescale those values into [-1, 1], centered on zero. We also download both the training and validation data splits.

import torch
import torchvision
import torchvision.transforms as transforms

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])

# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)

# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)

# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
        'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
Training set has 60000 instances
Validation set has 10000 instances
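The normalization arithmetic can be checked by hand. This sketch does not call torchvision; it emulates ToTensor() on a few 8-bit grayscale pixel values (an assumption that matches Fashion-MNIST) and then applies the same formula Normalize((0.5,), (0.5,)) uses, (x - mean) / std. It also confirms that the img / 2 + 0.5 "unnormalize" line in the matplotlib_imshow helper that follows is the exact inverse.

```python
import torch

# ToTensor() scales 8-bit pixel values into [0.0, 1.0]; emulate that for a few pixels
raw_pixels = torch.tensor([0, 64, 128, 255], dtype=torch.uint8)
scaled = raw_pixels.float() / 255.0

# Normalize((0.5,), (0.5,)) computes (x - mean) / std per channel, with mean = std = 0.5
mean, std = 0.5, 0.5
normalized = (scaled - mean) / std  # now in [-1.0, 1.0]

# The "unnormalize" step used for display (img / 2 + 0.5) inverts the transform
restored = normalized / 2 + 0.5

print(normalized)
print(torch.allclose(restored, scaled))
```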

As always, let’s visualize the data as a sanity check:

import matplotlib.pyplot as plt
import numpy as np

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(training_loader)
images, labels = next(dataiter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print('  '.join(classes[labels[j]] for j in range(4)))
Ankle Boot  Bag  Sneaker  Sandal

The Model

The model we’ll use in this example is a variant of LeNet-5; it should be familiar if you’ve watched the previous videos in this series.

import torch.nn as nn
import torch.nn.functional as F

# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = GarmentClassifier()
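The 16 * 4 * 4 input size of fc1 follows from the 28x28 Fashion-MNIST images: each unpadded 5x5 convolution trims 4 pixels from each spatial dimension, and each 2x2 max-pool halves it. This sketch traces a dummy image through layers with the same hyperparameters as GarmentClassifier to confirm the arithmetic (ReLU is left out because it does not change shapes).

```python
import torch
import torch.nn as nn

# Same layer hyperparameters as GarmentClassifier, applied step by step
conv1 = nn.Conv2d(1, 6, 5)
conv2 = nn.Conv2d(6, 16, 5)
pool = nn.MaxPool2d(2, 2)

x = torch.zeros(1, 1, 28, 28)  # one dummy 28x28 grayscale image (batch of 1)
shapes = [tuple(x.shape)]
for layer in (conv1, pool, conv2, pool):
    x = layer(x)
    shapes.append(tuple(x.shape))

# Expected progression: 28 -> 24 (conv) -> 12 (pool) -> 8 (conv) -> 4 (pool)
for s in shapes:
    print(s)
```

Flattening the final (16, 4, 4) feature map per image gives the 256 features that fc1 expects.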

Loss Function

For this example, we’ll be using a cross-entropy loss. For demonstration purposes, we’ll create batches of dummy output and label values, run them through the loss function, and examine the result.

loss_fn = torch.nn.CrossEntropyLoss()

# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's raw (unnormalized) scores for each of the 10 classes for a given input
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])

print(dummy_outputs)
print(dummy_labels)

loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))
tensor([[0.4641, 0.0389, 0.6353, 0.2915, 0.5869, 0.4198, 0.3411, 0.8779, 0.0469,
         0.3595],
        [0.5560, 0.3625, 0.5806, 0.6623, 0.4791, 0.0035, 0.4888, 0.8230, 0.2335,
         0.0926],
        [0.7999, 0.9600, 0.0674, 0.6058, 0.5561, 0.7260, 0.8686, 0.1500, 0.1339,
         0.0071],
        [0.7697, 0.5382, 0.4136, 0.6995, 0.8319, 0.8123, 0.8703, 0.2118, 0.5920,
         0.9654]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.620476484298706
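CrossEntropyLoss takes raw scores, applies a log-softmax over the class dimension, and takes the negative log-probability of each sample's correct class. With its default reduction='mean', the printed value is the average over the batch rather than a sum. The sketch below reproduces that computation manually with different random scores from the ones printed above (it seeds its own generator), and compares it against the built-in loss.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.rand(4, 10)           # dummy raw scores, batch of 4, 10 classes
labels = torch.tensor([1, 5, 3, 7])

# Built-in loss (default reduction='mean' averages over the batch)
builtin = torch.nn.CrossEntropyLoss()(logits, labels)

# Manual version: log-softmax over classes, pick each row's correct-class entry, negate, average
log_probs = F.log_softmax(logits, dim=1)
per_sample = -log_probs[torch.arange(len(labels)), labels]
manual = per_sample.mean()

print(builtin.item(), manual.item())
print(torch.allclose(builtin, manual))
```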

Optimizer

For this example, we’ll be using simple stochastic gradient descent with momentum.

It can be instructive to try some variations on this optimization scheme:

  • Learning rate determines the size of the steps the optimizer takes. What does a different learning rate do to your training results, in terms of accuracy and convergence time?

  • Momentum accumulates a decaying sum of past gradients, nudging the optimizer to keep moving in directions the gradient has consistently pointed over multiple steps. What does changing this value do to your results?

  • Try some different optimization algorithms, such as averaged SGD, Adagrad, or Adam. How do your results differ?

# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
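To see what momentum does mechanically, this sketch minimizes a toy function f(w) = w² with torch.optim.SGD and, in parallel, applies the update rule from the torch.optim.SGD documentation by hand, assuming the defaults used above (no dampening, no Nesterov, no weight decay). In that formulation the velocity accumulates raw gradients and the learning rate is applied afterwards. The toy function and the values of lr and momentum here are illustrative choices, not the tutorial's settings.

```python
import torch

lr, momentum = 0.1, 0.9

# One scalar parameter, minimizing f(w) = w**2 (gradient 2w)
w = torch.tensor([1.0], requires_grad=True)
opt = torch.optim.SGD([w], lr=lr, momentum=momentum)

w_manual = 1.0
velocity = 0.0
for step in range(3):
    opt.zero_grad()
    loss = (w ** 2).sum()
    loss.backward()
    opt.step()

    # Manual update: velocity = momentum * velocity + grad; w -= lr * velocity
    grad = 2 * w_manual
    velocity = momentum * velocity + grad
    w_manual -= lr * velocity

    print(step, w.item(), w_manual)
```

Trying a different algorithm, as suggested above, is usually a one-line change, for example torch.optim.Adam(model.parameters(), lr=0.001), torch.optim.Adagrad, or torch.optim.ASGD; suitable learning rates differ between algorithms.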

The Training Loop

Below, we have a function that performs one training epoch. It enumerates data from the DataLoader, and on each pass of the loop does the following:

  • Gets a batch of training data from the DataLoader

  • Zeros the optimizer’s gradients

  • Performs inference; that is, gets predictions from the model for an input batch

  • Calculates the loss for that set of predictions vs. the labels on the dataset

  • Calculates the backward gradients over the learning weights

  • Tells the optimizer to perform one learning step; that is, adjust the model’s learning weights based on the observed gradients for this batch, according to the optimization algorithm we chose

  • Every 1000 batches, reports the average per-batch loss to the console and to TensorBoard

  • Finally, returns the average per-batch loss for the last 1000 batches, for comparison with a validation run

def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss
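Two details of this loop are easy to miss. First, PyTorch accumulates gradients: each call to backward() adds into a parameter's .grad rather than overwriting it, which is why optimizer.zero_grad() runs at the top of every batch. Second, the TensorBoard x-coordinate tb_x counts batches seen across all epochs, so curves from successive epochs line up on one axis. This sketch demonstrates both on a single toy parameter; global_step is a hypothetical helper that restates the tb_x formula, and the 15,000 batches per epoch comes from 60,000 training instances at batch_size=4.

```python
import torch

w = torch.tensor([3.0], requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)

# Two backward passes with no zeroing in between: .grad sums the results
(2 * w).sum().backward()
(2 * w).sum().backward()
accumulated = w.grad.clone()   # d(2w)/dw = 2, accumulated twice -> 4

# Clearing gradients first, as optimizer.zero_grad() does in the training loop
opt.zero_grad()
(2 * w).sum().backward()
fresh = w.grad.clone()         # just this pass's gradient -> 2

print(accumulated.item(), fresh.item())


def global_step(epoch_index, batch_index, batches_per_epoch):
    """Hypothetical helper: the tb_x value computed in train_one_epoch."""
    return epoch_index * batches_per_epoch + batch_index + 1


# 60,000 training instances / batch_size 4 = 15,000 batches per epoch
tb_x = global_step(epoch_index=1, batch_index=999, batches_per_epoch=15000)
print(tb_x)
```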

Per-Epoch Activity

There are a couple of things we’ll want to do once per epoch:

  • Perform validation by computing our loss on a set of data that was not used for training, and report it

  • Save a copy of the model

Here, we’ll do our reporting in TensorBoard. This requires starting TensorBoard from the command line (for example, tensorboard --logdir=runs, since the writer below logs under the runs/ directory) and opening it in another browser tab.

# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)


    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss.item()  # .item() converts the 0-dim loss tensor to a Python float

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1
EPOCH 1:
  batch 1000 loss: 1.7622666106820106
  batch 2000 loss: 0.8714374821810051
  batch 3000 loss: 0.724350129507482
  batch 4000 loss: 0.6662558071399107
  batch 5000 loss: 0.6119116357411258
  batch 6000 loss: 0.553089812912047
  batch 7000 loss: 0.5326241768042382
  batch 8000 loss: 0.4874585609018104
  batch 9000 loss: 0.48870028759888373
  batch 10000 loss: 0.48472532849921846
  batch 11000 loss: 0.48028051887848416
  batch 12000 loss: 0.45211584290582685
  batch 13000 loss: 0.44801454728818496
  batch 14000 loss: 0.420492051506415
  batch 15000 loss: 0.41632540825766046
LOSS train 0.41632540825766046 valid 0.4063527584075928
EPOCH 2:
  batch 1000 loss: 0.40240994511323513
  batch 2000 loss: 0.3849028598798905
  batch 3000 loss: 0.3710368653210462
  batch 4000 loss: 0.38153884670557453
  batch 5000 loss: 0.3661992139663489
  batch 6000 loss: 0.3961341581819579
  batch 7000 loss: 0.3771197603978217
  batch 8000 loss: 0.3696336418205465
  batch 9000 loss: 0.36208432747535696
  batch 10000 loss: 0.3933071748224902
  batch 11000 loss: 0.33063539220421806
  batch 12000 loss: 0.36671002670731107
  batch 13000 loss: 0.36178261497832137
  batch 14000 loss: 0.3366501306547725
  batch 15000 loss: 0.35598218267451737
LOSS train 0.35598218267451737 valid 0.3860108256340027
EPOCH 3:
  batch 1000 loss: 0.33172630738231235
  batch 2000 loss: 0.3213144894433062
  batch 3000 loss: 0.32050017252846735
  batch 4000 loss: 0.3218007811126663
  batch 5000 loss: 0.33239176081254845
  batch 6000 loss: 0.32783556545491593
  batch 7000 loss: 0.32615404148991367
  batch 8000 loss: 0.32656801293001625
  batch 9000 loss: 0.33881579642941506
  batch 10000 loss: 0.30717250532310575
  batch 11000 loss: 0.3404195503274241
  batch 12000 loss: 0.32937910320051017
  batch 13000 loss: 0.32198590667068494
  batch 14000 loss: 0.29432328128764856
  batch 15000 loss: 0.3184672921527235
LOSS train 0.3184672921527235 valid 0.33521440625190735
EPOCH 4:
  batch 1000 loss: 0.29819257747248046
  batch 2000 loss: 0.297317513337046
  batch 3000 loss: 0.31178078035463114
  batch 4000 loss: 0.28763192417199024
  batch 5000 loss: 0.3175242606568754
  batch 6000 loss: 0.2932005113907653
  batch 7000 loss: 0.2875741342298788
  batch 8000 loss: 0.2941194512790535
  batch 9000 loss: 0.27920438118909807
  batch 10000 loss: 0.3006947393612354
  batch 11000 loss: 0.3008414485103567
  batch 12000 loss: 0.28630020079159296
  batch 13000 loss: 0.3153582940995402
  batch 14000 loss: 0.3090263041169819
  batch 15000 loss: 0.31693573112281226
LOSS train 0.31693573112281226 valid 0.3166554570198059
EPOCH 5:
  batch 1000 loss: 0.29748897626389226
  batch 2000 loss: 0.2943619244422425
  batch 3000 loss: 0.27057846084232007
  batch 4000 loss: 0.28254740730442063
  batch 5000 loss: 0.276848270490309
  batch 6000 loss: 0.28325180596915017
  batch 7000 loss: 0.2843594888228563
  batch 8000 loss: 0.2787017031487121
  batch 9000 loss: 0.26478999828883754
  batch 10000 loss: 0.26038969929493033
  batch 11000 loss: 0.2667142621600724
  batch 12000 loss: 0.2708077770742966
  batch 13000 loss: 0.2794458594180905
  batch 14000 loss: 0.2823985450617911
  batch 15000 loss: 0.28865199625557036
LOSS train 0.28865199625557036 valid 0.3337945342063904
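The validation pass above switches the model to evaluation mode and wraps the forward passes in torch.no_grad(). GarmentClassifier has no dropout or batch normalization, so for this model eval() changes nothing, but the habit matters for models that do. This sketch uses a hypothetical tiny network with a dropout layer to show the difference: in training mode, repeated passes over the same input can differ, while in evaluation mode they are deterministic, and inside torch.no_grad() the outputs carry no autograd history.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical tiny model with dropout (not the tutorial's GarmentClassifier)
net = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
x = torch.ones(1, 8)

net.train()
train_a, train_b = net(x), net(x)   # dropout active: random units zeroed on each pass

net.eval()
eval_a, eval_b = net(x), net(x)     # dropout disabled: identical outputs

with torch.no_grad():
    out = net(x)                    # no autograd graph is recorded

print(torch.equal(eval_a, eval_b), eval_a.requires_grad, out.requires_grad)
```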

To load a saved version of the model:

# PATH is the path of a saved checkpoint, such as a model_path written by the training loop above
saved_model = GarmentClassifier()
saved_model.load_state_dict(torch.load(PATH))

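Once you’ve loaded the model, it’s ready for whatever you need it for: more training, inference, or analysis. Before running inference, call saved_model.eval() so that any dropout or batch normalization layers behave as they did during validation.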

Note that if your model has constructor parameters that affect model structure, you’ll need to provide them and configure the model identically to the state in which it was saved.
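The sketch below illustrates this with a hypothetical TinyClassifier whose hidden-layer width is a constructor argument. It saves a state_dict to an in-memory buffer (standing in for a file path), reloads it into a model built with the same argument, and then shows that loading into a model built with a different width raises a RuntimeError because the parameter shapes no longer match.

```python
import io

import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical model whose structure depends on a constructor argument."""

    def __init__(self, hidden=16):
        super().__init__()
        self.fc1 = nn.Linear(4, hidden)
        self.fc2 = nn.Linear(hidden, 3)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


original = TinyClassifier(hidden=16)
buffer = io.BytesIO()  # in-memory stand-in for a file path
torch.save(original.state_dict(), buffer)

# Matching constructor arguments: loading succeeds and every tensor matches exactly
buffer.seek(0)
restored = TinyClassifier(hidden=16)
restored.load_state_dict(torch.load(buffer))
same = all(
    torch.equal(a, b)
    for a, b in zip(original.state_dict().values(), restored.state_dict().values())
)

# Mismatched constructor arguments: fc1/fc2 shapes differ, so loading raises
buffer.seek(0)
mismatched = TinyClassifier(hidden=32)
try:
    mismatched.load_state_dict(torch.load(buffer))
    load_failed = False
except RuntimeError:
    load_failed = True

print(same, load_failed)
```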

Other Resources

Total running time of the script: (3 minutes 4.517 seconds)