

Training with PyTorch#

Created On: Nov 30, 2021 | Last Updated: May 31, 2023 | Last Verified: Nov 05, 2024

Follow along with the video below or on YouTube.

Introduction#

In past videos, we’ve discussed and demonstrated:

  • Building models with the neural network layers and functions of the torch.nn module

  • The mechanics of automated gradient computation, which is central to gradient-based model training

  • Using TensorBoard to visualize training progress and other activities

In this video, we’ll be adding some new tools to your inventory:

  • We’ll get familiar with the Dataset and DataLoader abstractions, and with how they ease the process of feeding data to your model during a training loop

  • We’ll discuss specific loss functions and when to use them

  • We’ll look at PyTorch optimizers, which implement algorithms to adjust model weights based on the outcome of a loss function

Finally, we’ll pull all of these together and see a full PyTorch training loop in action.

Dataset and DataLoader#

The Dataset and DataLoader classes encapsulate the process of pulling your data from storage and exposing it to your training loop in batches.

The Dataset is responsible for accessing and processing single instances of data.

The DataLoader pulls instances of data from the Dataset (either automatically or with a sampler that you define), collects them in batches, and returns them for consumption by your training loop. The DataLoader works with all kinds of datasets, regardless of the type of data they contain.
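To make that division of labor concrete, here is a minimal sketch of a custom Dataset. The class name ToyPairs and its synthetic data are hypothetical and exist only for illustration: __getitem__ returns one instance, and the DataLoader does the batching by stacking instances together.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ToyPairs(Dataset):
    """Hypothetical in-memory dataset of (feature, label) pairs."""

    def __init__(self, num_items=10):
        # One scalar feature per item, stored as shape (1,)
        self.features = torch.arange(num_items, dtype=torch.float32).unsqueeze(1)
        # Alternating 0/1 labels
        self.labels = torch.arange(num_items) % 2

    def __len__(self):
        # The DataLoader uses this to know how many indices exist
        return len(self.features)

    def __getitem__(self, idx):
        # Return exactly one instance; batching is the DataLoader's job
        return self.features[idx], self.labels[idx]


dataset = ToyPairs(10)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

# 10 items in batches of 4 gives batches of 4, 4 and 2
batch_shapes = [(x.shape, y.shape) for x, y in loader]
for shape in batch_shapes:
    print(shape)
```

The last batch is smaller because 10 is not a multiple of 4; passing drop_last=True to the DataLoader would discard it instead.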

For this tutorial, we’ll be using the Fashion-MNIST dataset provided by TorchVision. We use torchvision.transforms.Normalize() to rescale pixel values from the [0, 1] range produced by ToTensor() to [-1, 1], roughly centering them on zero, and we download both the training and validation splits.

import torch
import torchvision
import torchvision.transforms as transforms

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])

# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)

# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)

# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
        'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
Training set has 60000 instances
Validation set has 10000 instances

As always, let’s visualize the data as a sanity check:

import matplotlib.pyplot as plt
import numpy as np

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(training_loader)
images, labels = next(dataiter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print('  '.join(classes[labels[j]] for j in range(4)))
Bag  Coat  Coat  Coat

The Model#

The model we’ll use in this example is a variant of LeNet-5; it should be familiar if you’ve watched the previous videos in this series.

import torch.nn as nn
import torch.nn.functional as F

# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = GarmentClassifier()
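The 16 * 4 * 4 input size of fc1 follows from a 28x28 input: each 5x5 convolution trims 4 pixels from each spatial dimension and each 2x2 max-pool halves it. This sketch traces the shapes with the same layer settings as GarmentClassifier; the ReLUs are omitted because they do not change shapes.

```python
import torch
import torch.nn as nn

# Same convolution/pooling settings as GarmentClassifier
conv1 = nn.Conv2d(1, 6, 5)
conv2 = nn.Conv2d(6, 16, 5)
pool = nn.MaxPool2d(2, 2)

x = torch.rand(1, 1, 28, 28)  # one Fashion-MNIST-sized image
shapes = []
for layer in (conv1, pool, conv2, pool):
    x = layer(x)
    shapes.append(tuple(x.shape))
    print(type(layer).__name__, tuple(x.shape))
# Conv2d (1, 6, 24, 24)
# MaxPool2d (1, 6, 12, 12)
# Conv2d (1, 16, 8, 8)
# MaxPool2d (1, 16, 4, 4)

# Flattening gives the 256 features that fc1 expects
flat = x.view(-1, 16 * 4 * 4)
print(flat.shape)  # torch.Size([1, 256])
```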

Loss Function#

For this example, we’ll be using a cross-entropy loss. For demonstration purposes, we’ll create batches of dummy output and label values, run them through the loss function, and examine the result.

loss_fn = torch.nn.CrossEntropyLoss()

# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's raw, unnormalized scores (logits) for each of the 10 classes
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])

print(dummy_outputs)
print(dummy_labels)

loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))
tensor([[7.6755e-01, 2.1209e-01, 3.9685e-04, 1.1784e-01, 5.8920e-01, 2.1658e-01,
         5.7869e-02, 6.1950e-01, 9.1382e-01, 5.5316e-01],
        [3.4975e-01, 1.6163e-01, 7.3083e-01, 9.7484e-01, 1.1170e-01, 1.5072e-01,
         9.7744e-01, 9.6523e-01, 5.7972e-01, 9.9445e-01],
        [2.5699e-01, 2.0019e-01, 5.9594e-01, 9.2452e-01, 2.7221e-01, 9.6277e-01,
         4.1827e-03, 8.2591e-01, 8.8440e-01, 6.4541e-01],
        [4.2114e-01, 3.6832e-01, 9.9400e-01, 8.0273e-01, 4.0445e-01, 4.3448e-01,
         4.5120e-01, 4.9203e-01, 3.5417e-01, 1.0004e-01]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.416367530822754
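CrossEntropyLoss applies a log-softmax to the raw scores and then takes the negative log-probability of the correct class, averaged over the batch by default. This sketch recomputes that by hand with random dummy scores (not the values printed above) and checks that both agree.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.rand(4, 10)           # raw, unnormalized scores for 4 samples
labels = torch.tensor([1, 5, 3, 7])  # index of the correct class per sample

builtin = torch.nn.CrossEntropyLoss()(logits, labels)

# By hand: log-softmax over the classes, pick each row's log-probability
# for its correct class, negate, and average over the batch
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(4), labels].mean()

print(builtin.item(), manual.item())
print(torch.allclose(builtin, manual))  # True
```

Because torch.rand scores all lie in [0, 1), the softmax is close to uniform and the loss stays near ln(10) ≈ 2.30, which is consistent with the dummy-batch loss reported above.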

Optimizer#

For this example, we’ll be using simple stochastic gradient descent with momentum.

It can be instructive to try some variations on this optimization scheme:

  • Learning rate determines the size of the steps the optimizer takes. What does a different learning rate do to your training results, in terms of accuracy and convergence time?

  • Momentum accumulates a decaying running sum of past gradients, so the optimizer keeps moving in directions that the gradients have pointed consistently over multiple steps. What does changing this value do to your results?

  • Try some different optimization algorithms, such as averaged SGD, Adagrad, or Adam. How do your results differ?

# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
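To see what momentum adds, this sketch runs torch.optim.SGD with momentum on a single toy parameter (the quadratic loss w ** 2 is an assumption chosen only because its gradient, 2 * w, is easy to compute by hand) alongside a manual update using the rule PyTorch documents for the default settings (no dampening, no Nesterov): v = momentum * v + grad, then w = w - lr * v.

```python
import torch

lr, momentum = 0.1, 0.9
w = torch.tensor([1.0], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=lr, momentum=momentum)

# Reference copy updated by hand with the same rule:
#   v <- momentum * v + grad ;  w <- w - lr * v
w_manual = 1.0
v = 0.0
for step in range(3):
    optimizer.zero_grad()
    loss = (w ** 2).sum()  # toy loss; its gradient is 2 * w
    loss.backward()
    optimizer.step()

    grad = 2 * w_manual
    v = momentum * v + grad  # first step: v = grad, like PyTorch's buffer init
    w_manual = w_manual - lr * v
    print(step, w.item(), w_manual)
```

On the second step plain SGD would move w by lr * grad = 0.16, while the momentum version moves it by 0.34, because the previous step's velocity is carried forward.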

The Training Loop#

Below, we have a function that performs one training epoch. It enumerates data from the DataLoader, and on each pass of the loop does the following:

  • Gets a batch of training data from the DataLoader

  • Zeros the optimizer’s gradients

  • Performs inference; that is, gets predictions from the model for an input batch

  • Calculates the loss for that set of predictions vs. the labels on the dataset

  • Calculates the backward gradients over the learning weights

  • Tells the optimizer to perform one learning step - that is, adjust the model’s learning weights based on the observed gradients for this batch, according to the optimization algorithm we chose

  • Reports the average loss every 1000 batches

  • Finally, returns the average per-batch loss for the last 1000 batches, for comparison with a validation run

def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss
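The "zero your gradients" step matters because backward() adds into each parameter's .grad rather than overwriting it. This sketch uses a single toy parameter and the quadratic loss w ** 2 (both assumptions for illustration) to show gradients from two backward passes piling up, and optimizer.zero_grad() clearing them.

```python
import torch

w = torch.tensor([3.0], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

# Two backward passes without clearing: the gradients are summed into w.grad
for _ in range(2):
    loss = (w ** 2).sum()  # d(loss)/dw = 2 * w = 6
    loss.backward()
accumulated = w.grad.clone()
print(accumulated)  # tensor([12.])

# Clearing first leaves only the current pass's contribution
optimizer.zero_grad()
loss = (w ** 2).sum()
loss.backward()
print(w.grad)  # tensor([6.])
```

Skipping zero_grad() in the training loop would therefore make each optimizer.step() use the sum of every batch's gradients so far.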

Per-Epoch Activity#

There are a couple of things we’ll want to do once per epoch:

  • Perform validation by computing the loss on a set of data that was not used for training, and report it

  • Save a copy of the model

Here, we’ll do our reporting in TensorBoard. This requires starting TensorBoard from the command line (for example, tensorboard --logdir=runs) and opening it in another browser tab.

# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Put the model in training mode (affects layers such as dropout and
    # batch normalization), then do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)


    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss.item()  # accumulate a Python float, not a tensor

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1
EPOCH 1:
  batch 1000 loss: 1.7385720927417279
  batch 2000 loss: 0.832697173718363
  batch 3000 loss: 0.6864131525661796
  batch 4000 loss: 0.6105514444704168
  batch 5000 loss: 0.5850542451408691
  batch 6000 loss: 0.5750407866821624
  batch 7000 loss: 0.5272775211334229
  batch 8000 loss: 0.5274288718159078
  batch 9000 loss: 0.48876330685405994
  batch 10000 loss: 0.475955779616721
  batch 11000 loss: 0.4559065450442431
  batch 12000 loss: 0.4134288992721122
  batch 13000 loss: 0.4540224991024006
  batch 14000 loss: 0.4361313150327769
  batch 15000 loss: 0.4199084949973039
LOSS train 0.4199084949973039 valid 0.41430190205574036
EPOCH 2:
  batch 1000 loss: 0.42189482399937694
  batch 2000 loss: 0.3999643763921631
  batch 3000 loss: 0.40902178106020437
  batch 4000 loss: 0.3967328935181722
  batch 5000 loss: 0.3664240125938086
  batch 6000 loss: 0.40228067599228234
  batch 7000 loss: 0.3820848577066208
  batch 8000 loss: 0.3789055755652953
  batch 9000 loss: 0.37016932864103
  batch 10000 loss: 0.368606689903012
  batch 11000 loss: 0.37986740898923016
  batch 12000 loss: 0.3556470774288173
  batch 13000 loss: 0.33975730662039133
  batch 14000 loss: 0.35321106253226753
  batch 15000 loss: 0.33601367170797314
LOSS train 0.33601367170797314 valid 0.3791325092315674
EPOCH 3:
  batch 1000 loss: 0.33192680720053613
  batch 2000 loss: 0.34202571085238015
  batch 3000 loss: 0.335417492271401
  batch 4000 loss: 0.32720715750049567
  batch 5000 loss: 0.33529670788586374
  batch 6000 loss: 0.3365356587840506
  batch 7000 loss: 0.3293946066937642
  batch 8000 loss: 0.3312329010723479
  batch 9000 loss: 0.31285152013326295
  batch 10000 loss: 0.33454375235643236
  batch 11000 loss: 0.31987216585301215
  batch 12000 loss: 0.31072847432065465
  batch 13000 loss: 0.320197053734606
  batch 14000 loss: 0.3249074105130276
  batch 15000 loss: 0.3234520097374625
LOSS train 0.3234520097374625 valid 0.34419718384742737
EPOCH 4:
  batch 1000 loss: 0.29515260711359226
  batch 2000 loss: 0.2910699634320117
  batch 3000 loss: 0.3052371727841637
  batch 4000 loss: 0.31080969222952265
  batch 5000 loss: 0.32177027653048573
  batch 6000 loss: 0.2946780097327137
  batch 7000 loss: 0.3093485294390557
  batch 8000 loss: 0.3024247188994923
  batch 9000 loss: 0.28978041738236787
  batch 10000 loss: 0.2951672013048219
  batch 11000 loss: 0.297327028882828
  batch 12000 loss: 0.3054323522786508
  batch 13000 loss: 0.29910335543467953
  batch 14000 loss: 0.296728741799634
  batch 15000 loss: 0.30891430737902553
LOSS train 0.30891430737902553 valid 0.3197164535522461
EPOCH 5:
  batch 1000 loss: 0.2831958816282277
  batch 2000 loss: 0.2890166122386254
  batch 3000 loss: 0.2812165444063212
  batch 4000 loss: 0.27710397858938085
  batch 5000 loss: 0.2815308124340081
  batch 6000 loss: 0.2725695745303456
  batch 7000 loss: 0.2726585641002894
  batch 8000 loss: 0.3035050464626984
  batch 9000 loss: 0.2777940203845428
  batch 10000 loss: 0.2878590142988396
  batch 11000 loss: 0.28217975836583353
  batch 12000 loss: 0.28829225361906174
  batch 13000 loss: 0.26233880659297937
  batch 14000 loss: 0.2656326262994826
  batch 15000 loss: 0.2850159434989364
LOSS train 0.2850159434989364 valid 0.2982073724269867
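The validation pass above relies on two separate switches: model.eval() changes how layers such as dropout behave, while torch.no_grad() stops autograd from recording a graph. GarmentClassifier has no dropout, so this sketch uses a small hypothetical model with a dropout layer to show both effects.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical model with dropout, used only to show what eval() changes
net = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
x = torch.ones(2, 8)

net.train()
with torch.no_grad():
    a, b = net(x), net(x)
print(torch.equal(a, b))  # different random dropout masks, so almost certainly False

net.eval()
with torch.no_grad():
    c, d = net(x), net(x)
    out = net(x)
print(torch.equal(c, d))  # True: dropout does nothing in eval mode
print(out.requires_grad)  # False: no graph was recorded under no_grad
```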

To load a saved version of the model:

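# model_path holds the best checkpoint saved by the loop above; any saved path works
PATH = model_path
saved_model = GarmentClassifier()
saved_model.load_state_dict(torch.load(PATH))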

Once you’ve loaded the model, it’s ready for whatever you need it for: more training, inference, or analysis. Before running inference, call saved_model.eval() to put it in evaluation mode.

Note that if your model has constructor parameters that affect model structure, you’ll need to provide them and configure the model identically to the state in which it was saved.
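This sketch illustrates that requirement with a hypothetical Classifier whose hidden_size constructor argument determines layer shapes. It saves a state_dict to a temporary file, reloads it into an identically configured instance, and shows that loading into a differently configured instance raises a RuntimeError because the parameter shapes no longer match.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Classifier(nn.Module):
    """Hypothetical model whose structure depends on a constructor argument."""

    def __init__(self, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(256, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 10)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


original = Classifier(hidden_size=120)
path = os.path.join(tempfile.mkdtemp(), "classifier.pt")
torch.save(original.state_dict(), path)

# Same constructor argument: the saved tensors fit the new instance
restored = Classifier(hidden_size=120)
restored.load_state_dict(torch.load(path))
restored.eval()  # switch to evaluation mode before inference

# Different constructor argument: parameter shapes no longer match
mismatched = Classifier(hidden_size=64)
try:
    mismatched.load_state_dict(torch.load(path))
    loaded_mismatch = True
except RuntimeError as err:
    loaded_mismatch = False
    print("load failed:", str(err).splitlines()[0])
```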

Other Resources#

Total running time of the script: (2 minutes 59.639 seconds)