Training with PyTorch

Created On: Nov 30, 2021 | Last Updated: May 31, 2023 | Last Verified: Nov 05, 2024

Follow along with the video below or on YouTube.

Introduction

In past videos, we’ve discussed and demonstrated:

  • Building models with the neural network layers and functions of the torch.nn module

  • The mechanics of automated gradient computation, which is central to gradient-based model training

  • Using TensorBoard to visualize training progress and other activities

In this video, we’ll be adding some new tools to your inventory:

  • We’ll get familiar with the dataset and dataloader abstractions, and how they ease the process of feeding data to your model during a training loop

  • We’ll discuss specific loss functions and when to use them

  • We’ll look at PyTorch optimizers, which implement algorithms to adjust model weights based on the outcome of a loss function

Finally, we’ll pull all of these together and see a full PyTorch training loop in action.

Dataset and DataLoader

The Dataset and DataLoader classes encapsulate the process of pulling your data from storage and exposing it to your training loop in batches.

The Dataset is responsible for accessing and processing single instances of data.

The DataLoader pulls instances of data from the Dataset (either automatically or with a sampler that you define), collects them in batches, and returns them for consumption by your training loop. The DataLoader works with all kinds of datasets, regardless of the type of data they contain.
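Here is a minimal sketch of that division of labour. It assumes a hypothetical toy dataset called SquaresDataset (not part of PyTorch): a Dataset only has to implement __len__ and __getitem__ for single instances, and the DataLoader stacks those instances into batches.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: instance idx has input [idx] and label idx * idx."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # The DataLoader (and its sampler) use this to know how many instances exist
        return self.n

    def __getitem__(self, idx):
        # Return ONE (input, label) instance; batching is the DataLoader's job
        x = torch.tensor([float(idx)])
        y = torch.tensor(idx * idx)
        return x, y


loader = DataLoader(SquaresDataset(10), batch_size=4, shuffle=False)
batches = [(xb, yb) for xb, yb in loader]
for xb, yb in batches:
    print(xb.shape, yb.tolist())
```

With 10 instances and batch_size=4, the loader yields batches of 4, 4 and 2 instances; passing drop_last=True to DataLoader would discard the short final batch.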

For this tutorial, we’ll be using the Fashion-MNIST dataset provided by TorchVision. We use torchvision.transforms.Normalize() to zero-center and normalize the distribution of the image tile content, and download both training and validation data splits.

import torch
import torchvision
import torchvision.transforms as transforms

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])

# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)

# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)

# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
        'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
Training set has 60000 instances
Validation set has 10000 instances

As always, let’s visualize the data as a sanity check:

import matplotlib.pyplot as plt
import numpy as np

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(training_loader)
images, labels = next(dataiter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print('  '.join(classes[labels[j]] for j in range(4)))
Bag  Trouser  Ankle Boot  T-shirt/top

The Model

The model we’ll use in this example is a variant of LeNet-5 - it should be familiar if you’ve watched the previous videos in this series.

import torch.nn as nn
import torch.nn.functional as F

# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = GarmentClassifier()
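The 16 * 4 * 4 input size of fc1 comes from tracing a 28x28 Fashion-MNIST image through the layers: each 5x5 convolution without padding trims 4 pixels, and each 2x2 max-pool halves the size, so 28 → 24 → 12 → 8 → 4. This sketch rebuilds only the convolutional layers with the same arguments as GarmentClassifier and checks the shapes on a dummy batch of zeros.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Same layer arguments as GarmentClassifier's convolutional part
conv1 = nn.Conv2d(1, 6, 5)
pool = nn.MaxPool2d(2, 2)
conv2 = nn.Conv2d(6, 16, 5)

x = torch.zeros(4, 1, 28, 28)          # dummy batch of 4 single-channel images
shapes = [tuple(x.shape)]
x = pool(F.relu(conv1(x)))             # 28 -> 24 (conv) -> 12 (pool)
shapes.append(tuple(x.shape))
x = pool(F.relu(conv2(x)))             # 12 -> 8 (conv) -> 4 (pool)
shapes.append(tuple(x.shape))
flat = x.view(-1, 16 * 4 * 4)          # the flattening done in forward()

print(shapes, tuple(flat.shape))
```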

Loss Function

For this example, we’ll be using a cross-entropy loss. For demonstration purposes, we’ll create batches of dummy output and label values, run them through the loss function, and examine the result.

loss_fn = torch.nn.CrossEntropyLoss()

# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's raw, unnormalized scores (logits) for each of the 10 classes; CrossEntropyLoss applies softmax internally
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])

print(dummy_outputs)
print(dummy_labels)

# CrossEntropyLoss averages the per-sample losses over the batch by default (reduction='mean')
loss = loss_fn(dummy_outputs, dummy_labels)
print('Mean loss for this batch: {}'.format(loss.item()))
tensor([[0.7018, 0.7917, 0.0073, 0.7774, 0.4839, 0.4857, 0.8349, 0.1530, 0.6112,
         0.0658],
        [0.9543, 0.9226, 0.6552, 0.5075, 0.7699, 0.1634, 0.5607, 0.2613, 0.5267,
         0.8919],
        [0.9881, 0.6844, 0.0896, 0.1475, 0.6026, 0.0152, 0.9625, 0.3566, 0.9050,
         0.1005],
        [0.9493, 0.1302, 0.5508, 0.3036, 0.6374, 0.7022, 0.5669, 0.5916, 0.5880,
         0.0936]])
tensor([1, 5, 3, 7])
Mean loss for this batch: 2.4493112564086914
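CrossEntropyLoss treats its input as raw scores: it applies log-softmax across the class dimension, picks out the log-probability of each correct label, negates it, and by default averages over the batch. The sketch below checks that reading against the built-in loss, using small hand-picked logits as an assumption.

```python
import torch
import torch.nn.functional as F

# Hand-picked raw scores (logits) for a batch of 2 samples and 3 classes
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
labels = torch.tensor([0, 2])

builtin = torch.nn.CrossEntropyLoss()(logits, labels)

# Manual version: negative log-softmax of each correct class, averaged over the batch
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(labels)), labels].mean()

# With reduction='sum' the per-sample losses are added instead of averaged
summed = torch.nn.CrossEntropyLoss(reduction='sum')(logits, labels)

print(builtin.item(), manual.item(), summed.item())
```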

Optimizer

For this example, we’ll be using simple stochastic gradient descent with momentum.

It can be instructive to try some variations on this optimization scheme:

  • Learning rate determines the size of the steps the optimizer takes. What does a different learning rate do to your training results, in terms of accuracy and convergence time?

  • Momentum accumulates an exponentially decaying sum of past gradients, so the optimizer keeps moving in directions that have been consistent over multiple steps. What does changing this value do to your results?

  • Try some different optimization algorithms, such as averaged SGD, Adagrad, or Adam. How do your results differ?
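Trying a different algorithm usually means changing only the line that constructs the optimizer, since torch.optim classes share the zero_grad()/step() interface. As a sketch, the hypothetical one_step helper below builds each optimizer on a single parameter w with loss w², takes one step from w = 1, and reports where w ends up; the learning rates are illustrative assumptions, not tuned values.

```python
import torch


def one_step(make_optimizer):
    """Hypothetical helper: one optimization step on loss = w**2, starting at w = 1."""
    w = torch.tensor([1.0], requires_grad=True)
    opt = make_optimizer([w])
    opt.zero_grad()
    (w ** 2).sum().backward()   # gradient is 2 * w = 2
    opt.step()
    return w.item()


results = {
    "SGD+momentum": one_step(lambda p: torch.optim.SGD(p, lr=0.001, momentum=0.9)),
    "ASGD": one_step(lambda p: torch.optim.ASGD(p, lr=0.001)),
    "Adagrad": one_step(lambda p: torch.optim.Adagrad(p, lr=0.01)),
    "Adam": one_step(lambda p: torch.optim.Adam(p, lr=0.001)),
}
for name, w in results.items():
    print(f"{name:>13}: w = {w:.6f}")
```

Every algorithm moves w toward the minimum at 0, but by different amounts, which is exactly the kind of difference the bullet above invites you to explore on the real model.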

# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
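To see what the momentum term does, this sketch runs torch.optim.SGD on a hypothetical one-parameter loss w² alongside a hand-written update. It assumes PyTorch's default SGD settings (no dampening, no Nesterov, no weight decay), under which each step keeps a velocity v = momentum * v + grad and then moves w = w - lr * v. The learning rate of 0.1 is larger than the tutorial's 0.001 only so that the movement is easy to see.

```python
import torch

lr, momentum = 0.1, 0.9

# PyTorch's optimizer on loss = w**2
w = torch.tensor([1.0], requires_grad=True)
opt = torch.optim.SGD([w], lr=lr, momentum=momentum)

# The same update written out by hand
w_manual, velocity = 1.0, 0.0

for _ in range(3):
    opt.zero_grad()
    (w ** 2).sum().backward()
    opt.step()

    grad = 2 * w_manual                          # d(w**2)/dw
    velocity = momentum * velocity + grad        # accumulate past gradients
    w_manual = w_manual - lr * velocity          # step along the velocity

print(w.item(), w_manual)   # both about 0.062 after 3 steps (plain SGD would give 0.512)
```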

The Training Loop

Below, we have a function that performs one training epoch. It enumerates data from the DataLoader, and on each pass of the loop does the following:

  • Gets a batch of training data from the DataLoader

  • Zeros the optimizer’s gradients

  • Performs an inference - that is, gets predictions from the model for an input batch

  • Calculates the loss for that set of predictions vs. the labels on the dataset

  • Calculates the gradients of the loss with respect to the learning weights (the backward pass)

  • Tells the optimizer to perform one learning step - that is, adjust the model’s learning weights based on the observed gradients for this batch, according to the optimization algorithm we chose

  • Every 1000 batches, it reports the average per-batch loss over those batches, both to the console and to TensorBoard

  • Finally, it returns the average per-batch loss for the last 1000 batches, for comparison with a validation run
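Stripped of data loading and reporting, the per-batch steps listed above come down to five calls. The following sketch runs them on a hypothetical toy problem (random 2-D points, three classes, a single nn.Linear layer), using full-batch updates for brevity rather than a DataLoader.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy data: the class is how many of the two coordinates are positive (0, 1 or 2)
inputs = torch.randn(64, 2)
labels = (inputs[:, 0] > 0).long() + (inputs[:, 1] > 0).long()

toy_model = torch.nn.Linear(2, 3)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1, momentum=0.9)

losses = []
for step in range(50):
    optimizer.zero_grad()              # 1. zero the gradients left over from the last step
    outputs = toy_model(inputs)        # 2. forward pass: get predictions
    loss = loss_fn(outputs, labels)    # 3. compare predictions with labels
    loss.backward()                    # 4. backward pass: compute gradients
    optimizer.step()                   # 5. adjust the weights
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```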

def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss
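The reporting arithmetic in train_one_epoch can be checked in isolation. The hypothetical report_windows helper below mirrors it in plain Python: it averages the loss over each window of batches and tags the result with a global step, epoch_index * batches_per_epoch + i + 1, so that TensorBoard points from later epochs land to the right of earlier ones.

```python
def report_windows(batch_losses, window, epoch_index, batches_per_epoch):
    """Hypothetical helper mirroring train_one_epoch's reporting.

    Returns (global_step, average_loss) for every complete window of batches.
    """
    reports = []
    running_loss = 0.0
    for i, loss in enumerate(batch_losses):
        running_loss += loss
        if i % window == window - 1:                      # end of a window
            global_step = epoch_index * batches_per_epoch + i + 1
            reports.append((global_step, running_loss / window))
            running_loss = 0.0                            # start the next window
    return reports


print(report_windows([1.0, 3.0, 2.0, 4.0, 5.0], window=2, epoch_index=1, batches_per_epoch=5))
# [(7, 2.0), (9, 3.0)]  (the trailing 5.0 never completes a window)
```

With this tutorial's numbers (60,000 training images at batch_size=4, so 15,000 batches per epoch, reported every 1000), each epoch produces 15 points, and the first point of the second epoch (epoch_index=1) lands at global step 16,000. Batches left over after the last full window are not reported.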

Per-Epoch Activity

There are a couple of things we’ll want to do once per epoch:

  • Perform validation by checking our relative loss on a set of data that was not used for training, and report this

  • Save a copy of the model

Here, we’ll do our reporting in TensorBoard. This requires starting TensorBoard from the command line (for example, tensorboard --logdir=runs, since the writer below logs under runs/) and opening the URL it prints in another browser tab.
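Under the hood, SummaryWriter appends records to an event file in its log directory, and TensorBoard reads those files. As a minimal sketch, assuming the tensorboard package is installed, this writes a few made-up loss values with add_scalar(tag, value, global_step) into a temporary directory and confirms that an event file appears there.

```python
import os
import tempfile

from torch.utils.tensorboard import SummaryWriter

with tempfile.TemporaryDirectory() as log_dir:
    writer = SummaryWriter(log_dir)
    # Made-up loss values, one scalar point per global step
    for step, loss in enumerate([0.9, 0.6, 0.4], start=1):
        writer.add_scalar('Loss/train', loss, step)
    writer.flush()    # push buffered records to disk
    writer.close()
    event_files = [f for f in os.listdir(log_dir) if f.startswith('events.out.tfevents')]

print(event_files)
```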

# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Put the model in training mode (enables dropout and batch-norm statistic updates), then do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)


    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1
EPOCH 1:
  batch 1000 loss: 2.0236759397089483
  batch 2000 loss: 0.8911030735690146
  batch 3000 loss: 0.7536301856394857
  batch 4000 loss: 0.658030877770856
  batch 5000 loss: 0.6164135524192825
  batch 6000 loss: 0.5807270030793734
  batch 7000 loss: 0.5723351529641076
  batch 8000 loss: 0.5452388749844395
  batch 9000 loss: 0.5099193582141306
  batch 10000 loss: 0.4765922690718435
  batch 11000 loss: 0.4720146135836258
  batch 12000 loss: 0.4549544306595344
  batch 13000 loss: 0.4466040487524879
  batch 14000 loss: 0.4516491451340262
  batch 15000 loss: 0.4557598184698727
LOSS train 0.4557598184698727 valid 0.4141642451286316
EPOCH 2:
  batch 1000 loss: 0.4006899055049871
  batch 2000 loss: 0.4085228235855466
  batch 3000 loss: 0.40199646192937505
  batch 4000 loss: 0.4148171956757142
  batch 5000 loss: 0.4124251375843305
  batch 6000 loss: 0.38114031595137204
  batch 7000 loss: 0.39647990153954016
  batch 8000 loss: 0.3841982051011291
  batch 9000 loss: 0.36708792457613165
  batch 10000 loss: 0.3764721789613832
  batch 11000 loss: 0.36533259131899104
  batch 12000 loss: 0.3562970738278236
  batch 13000 loss: 0.3741651373065251
  batch 14000 loss: 0.3381829035620322
  batch 15000 loss: 0.34864214242622255
LOSS train 0.34864214242622255 valid 0.36854201555252075
EPOCH 3:
  batch 1000 loss: 0.3369137547020946
  batch 2000 loss: 0.3345266726673581
  batch 3000 loss: 0.3261624272647314
  batch 4000 loss: 0.3466461898198904
  batch 5000 loss: 0.3287164332018183
  batch 6000 loss: 0.35361249687563395
  batch 7000 loss: 0.32343769609354783
  batch 8000 loss: 0.3510915147009364
  batch 9000 loss: 0.35113862523276473
  batch 10000 loss: 0.312277754293129
  batch 11000 loss: 0.2983704081867011
  batch 12000 loss: 0.3264556283282218
  batch 13000 loss: 0.3255236314142239
  batch 14000 loss: 0.3210075806141358
  batch 15000 loss: 0.34260799641016637
LOSS train 0.34260799641016637 valid 0.3608584403991699
EPOCH 4:
  batch 1000 loss: 0.32086861188316834
  batch 2000 loss: 0.30020638122471427
  batch 3000 loss: 0.30590874500290377
  batch 4000 loss: 0.3247147919106355
  batch 5000 loss: 0.3115404506128689
  batch 6000 loss: 0.3012597737033102
  batch 7000 loss: 0.3041017816969488
  batch 8000 loss: 0.29158685235647136
  batch 9000 loss: 0.3117278872688985
  batch 10000 loss: 0.30090872185258194
  batch 11000 loss: 0.29472640040320036
  batch 12000 loss: 0.2986454657123868
  batch 13000 loss: 0.30281234057439727
  batch 14000 loss: 0.3002894707776577
  batch 15000 loss: 0.2882499357431425
LOSS train 0.2882499357431425 valid 0.3285525143146515
EPOCH 5:
  batch 1000 loss: 0.2832996222781221
  batch 2000 loss: 0.2707584700643056
  batch 3000 loss: 0.2740952929779523
  batch 4000 loss: 0.29007086411034105
  batch 5000 loss: 0.28693814228046177
  batch 6000 loss: 0.289208768123608
  batch 7000 loss: 0.2871940038591456
  batch 8000 loss: 0.30641151567161434
  batch 9000 loss: 0.26597099372593946
  batch 10000 loss: 0.28006415582423505
  batch 11000 loss: 0.27144853257268825
  batch 12000 loss: 0.2726468669723745
  batch 13000 loss: 0.28662368021105794
  batch 14000 loss: 0.3073895778544247
  batch 15000 loss: 0.2798848907670472
LOSS train 0.2798848907670472 valid 0.31915977597236633
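The validation pass above switches to model.eval() and wraps the forward passes in torch.no_grad(). GarmentClassifier itself has no dropout or batch-norm layers, so eval() does not change its outputs, but the distinction matters for models that do. A sketch with a hypothetical small network containing nn.Dropout shows both effects:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical network: a linear layer followed by dropout
net = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
x = torch.ones(1, 8)

net.train()                  # training mode: dropout randomly zeroes activations
train_out = net(x)           # autograd records a graph for this output

net.eval()                   # evaluation mode: dropout passes activations through unchanged
with torch.no_grad():        # no graph is recorded inside this block
    eval_a = net(x)
    eval_b = net(x)

print(torch.equal(eval_a, eval_b), train_out.requires_grad, eval_a.requires_grad)
```

In evaluation mode repeated passes give identical outputs, and under no_grad the outputs carry no autograd history, which is what saves memory during validation.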

To load a saved version of the model:

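# model_path holds the path of the best checkpoint saved by the loop above
saved_model = GarmentClassifier()
saved_model.load_state_dict(torch.load(model_path, weights_only=True))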

Once you’ve loaded the model, it’s ready for whatever you need it for - more training, inference, or analysis.

Note that if your model has constructor parameters that affect model structure, you’ll need to provide them and configure the model identically to the state in which it was saved.
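The save and load round trip, and the need to rebuild an identical architecture, can be checked on a small scale. This sketch uses a hypothetical make_model factory and a temporary file; it shows that a restored copy reproduces the original's outputs, and that a model built with different layer sizes rejects the same state dict with a RuntimeError.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model(hidden=3):
    # Hypothetical model; `hidden` plays the role of a structure-changing constructor parameter
    return nn.Sequential(nn.Linear(4, hidden), nn.ReLU(), nn.Linear(hidden, 2))


torch.manual_seed(0)
original = make_model()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pt")
    torch.save(original.state_dict(), path)          # save only the learned parameters
    restored = make_model()                           # rebuild the same architecture
    restored.load_state_dict(torch.load(path, weights_only=True))

restored.eval()
x = torch.randn(5, 4)
with torch.no_grad():
    same = torch.equal(original(x), restored(x))

# A model configured differently cannot accept the saved parameters
try:
    make_model(hidden=5).load_state_dict(original.state_dict())
    mismatch_error = False
except RuntimeError:
    mismatch_error = True

print(same, mismatch_error)
```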

Total running time of the script: (2 minutes 58.082 seconds)