

Training with PyTorch

Created On: Nov 30, 2021 | Last Updated: May 31, 2023 | Last Verified: Nov 05, 2024

Follow along with the video below or on YouTube.

Introduction

In past videos, we’ve discussed and demonstrated:

  • Building models with the neural network layers and functions of the torch.nn module

  • The mechanics of automated gradient computation, which is central to gradient-based model training

  • Using TensorBoard to visualize training progress and other activities

In this video, we’ll be adding some new tools to your inventory:

  • We’ll get familiar with the dataset and dataloader abstractions, and how they ease the process of feeding data to your model during a training loop

  • We’ll discuss specific loss functions and when to use them

  • We’ll look at PyTorch optimizers, which implement algorithms to adjust model weights based on the outcome of a loss function

Finally, we’ll pull all of these together and see a full PyTorch training loop in action.

Dataset and DataLoader

The Dataset and DataLoader classes encapsulate the process of pulling your data from storage and exposing it to your training loop in batches.

The Dataset is responsible for accessing and processing single instances of data.

The DataLoader pulls instances of data from the Dataset (either automatically or with a sampler that you define), collects them in batches, and returns them for consumption by your training loop. The DataLoader works with all kinds of datasets, regardless of the type of data they contain.
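To make this division of labor concrete, here is a minimal sketch of a custom map-style Dataset, assuming only that it implements __len__ and __getitem__. The SquaresDataset class and its toy data are hypothetical and exist purely for illustration. The Dataset returns one (input, label) pair at a time; the DataLoader's default collate function stacks those pairs into batches, and the last batch is smaller because 10 is not a multiple of 4.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: item i is the pair (tensor([i]), i * i)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # The DataLoader's default sampler uses this to know which indices exist
        return self.n

    def __getitem__(self, idx):
        # Return a single (input, label) instance; batching is the DataLoader's job
        x = torch.tensor([float(idx)])
        y = idx * idx
        return x, y


dataset = SquaresDataset(10)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

batches = list(loader)
for inputs, labels in batches:
    # inputs are stacked to shape (batch, 1); integer labels become an int64 tensor
    print(inputs.shape, labels.tolist())
```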

For this tutorial, we’ll be using the Fashion-MNIST dataset provided by TorchVision. We use torchvision.transforms.Normalize() to zero-center and normalize the distribution of the image tile content, and download both training and validation data splits.

import torch
import torchvision
import torchvision.transforms as transforms

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])

# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)

# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)

# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
        'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
Training set has 60000 instances
Validation set has 10000 instances
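Normalize((0.5,), (0.5,)) computes (x - mean) / std for the single channel. Because ToTensor() has already scaled 8-bit pixel values into [0.0, 1.0], the result lies in [-1.0, 1.0], and the img / 2 + 0.5 line in the display helper below inverts it. The plain-Python sketch below reproduces that arithmetic; normalize and unnormalize are hypothetical helper names, not torchvision functions.

```python
def normalize(x, mean=0.5, std=0.5):
    # What transforms.Normalize((0.5,), (0.5,)) does to each single-channel value
    return (x - mean) / std


def unnormalize(y, mean=0.5, std=0.5):
    # Inverse mapping; with mean = std = 0.5 this equals y / 2 + 0.5
    return y * std + mean


pixels = [0, 128, 255]                  # raw 8-bit pixel values
scaled = [p / 255 for p in pixels]      # ToTensor() scaling for uint8 images
normed = [normalize(s) for s in scaled]
print(normed)
```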

As always, let’s visualize the data as a sanity check:

import matplotlib.pyplot as plt
import numpy as np

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(training_loader)
images, labels = next(dataiter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print('  '.join(classes[labels[j]] for j in range(4)))
Bag  Coat  Bag  Pullover
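The two branches of matplotlib_imshow reflect a layout difference: PyTorch stores images as (channels, height, width), while plt.imshow expects (height, width, channels) or a 2-D array. make_grid also replicates a single-channel batch into three channels, which is why one_channel=True averages them back down. The shape check below uses a random tensor sized the way make_grid should lay out four 28 x 28 images with its default padding of 2; that grid size is an assumption, since the tutorial does not print it.

```python
import numpy as np
import torch

# Hypothetical 3-channel grid image in PyTorch's (channels, height, width) layout
img = torch.rand(3, 32, 122)

# matplotlib wants (height, width, channels), hence the transpose in the helper
hwc = np.transpose(img.numpy(), (1, 2, 0))

# one_channel=True collapses the channels by averaging, leaving (height, width)
gray = img.mean(dim=0).numpy()

print(hwc.shape, gray.shape)  # (32, 122, 3) (32, 122)
```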

The Model

The model we’ll use in this example is a variant of LeNet-5; it should be familiar if you’ve watched the previous videos in this series.

import torch.nn as nn
import torch.nn.functional as F

# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = GarmentClassifier()
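The 16 * 4 * 4 input size of fc1 follows from how the 28 x 28 Fashion-MNIST images shrink through each layer. This sketch traces the spatial size with the standard output-size formula for convolution and pooling layers, assuming dilation 1 and the default floor rounding; conv2d_out is a hypothetical helper name.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    # Output-size formula for nn.Conv2d / nn.MaxPool2d with dilation=1
    return (size + 2 * padding - kernel) // stride + 1


size = 28                       # Fashion-MNIST images are 28 x 28
size = conv2d_out(size, 5)      # conv1, 5x5 kernel -> 24
size = conv2d_out(size, 2, 2)   # 2x2 max pool, stride 2 -> 12
size = conv2d_out(size, 5)      # conv2, 5x5 kernel -> 8
size = conv2d_out(size, 2, 2)   # 2x2 max pool, stride 2 -> 4
flat_features = 16 * size * size  # conv2 produces 16 channels

print(size, flat_features)  # 4 256
```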

Loss Function

For this example, we’ll be using a cross-entropy loss. For demonstration purposes, we’ll create batches of dummy output and label values, run them through the loss function, and examine the result.

loss_fn = torch.nn.CrossEntropyLoss()

# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's raw (unnormalized) scores for each of the 10 classes for a given input
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])

print(dummy_outputs)
print(dummy_labels)

loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))
tensor([[0.1458, 0.6516, 0.4155, 0.6573, 0.1644, 0.0487, 0.4766, 0.2807, 0.2976,
         0.4981],
        [0.6846, 0.6619, 0.6084, 0.9605, 0.8064, 0.8921, 0.8410, 0.8679, 0.1564,
         0.9402],
        [0.6935, 0.3799, 0.3758, 0.5474, 0.0831, 0.4646, 0.8744, 0.4167, 0.6153,
         0.8313],
        [0.2894, 0.6151, 0.3496, 0.7387, 0.6422, 0.7366, 0.3435, 0.2882, 0.9476,
         0.2568]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.270169973373413
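CrossEntropyLoss applies a log-softmax to each row of raw scores, then averages the negative log-probability of the correct class over the batch. The plain-Python sketch below mirrors that computation, assuming the default reduction='mean' with no class weights or label smoothing; cross_entropy is a hypothetical name. Equal scores across 10 classes give ln(10), about 2.30, and the random dummy outputs above are close to uniform, which is consistent with the printed loss of roughly 2.27.

```python
import math


def cross_entropy(logits_batch, labels):
    # Mean over the batch of -log(softmax(logits)[label])
    total = 0.0
    for logits, label in zip(logits_batch, labels):
        m = max(logits)  # subtract the max before exp() for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_sum_exp - logits[label]
    return total / len(labels)


uniform = [[0.0] * 10 for _ in range(4)]
print(cross_entropy(uniform, [1, 5, 3, 7]))  # ln(10), about 2.3026
```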

Optimizer

For this example, we’ll be using simple stochastic gradient descent with momentum.

It can be instructive to try some variations on this optimization scheme:

  • Learning rate determines the size of the steps the optimizer takes. What does a different learning rate do to your training results, in terms of accuracy and convergence time?

  • Momentum nudges the optimizer in the direction of strongest gradient over multiple steps. What does changing this value do to your results?

  • Try some different optimization algorithms, such as averaged SGD, Adagrad, or Adam. How do your results differ?

# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
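Momentum keeps a running buffer of past gradients and steps along that buffer instead of the raw gradient. The sketch below applies the update rule documented for torch.optim.SGD, assuming dampening, Nesterov momentum, and weight decay are left at their defaults, to the toy function f(p) = p**2. The helper sgd_momentum and the toy problem are hypothetical, chosen only to make the effect visible.

```python
def sgd_momentum(grad_fn, p, lr, momentum, steps):
    # torch.optim.SGD rule with dampening=0, nesterov=False, weight_decay=0:
    #   buf = momentum * buf + grad;  p = p - lr * buf
    buf = 0.0
    for _ in range(steps):
        g = grad_fn(p)
        buf = momentum * buf + g
        p = p - lr * buf
    return p


def grad(p):
    # Gradient of f(p) = p**2, whose minimum is at p = 0
    return 2 * p


plain = sgd_momentum(grad, 5.0, lr=0.001, momentum=0.0, steps=100)
with_mom = sgd_momentum(grad, 5.0, lr=0.001, momentum=0.9, steps=100)
print(plain, with_mom)
```

With the same learning rate and step count, the momentum run ends much closer to the minimum. To try the other algorithms mentioned above, you could swap in, for example, torch.optim.Adam(model.parameters(), lr=0.001); torch.optim.Adagrad and torch.optim.ASGD also take the parameters and a learning rate in the same way.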

The Training Loop

Below, we have a function that performs one training epoch. It enumerates data from the DataLoader, and on each pass of the loop does the following:

  • Gets a batch of training data from the DataLoader

  • Zeros the optimizer’s gradients

  • Performs an inference; that is, gets predictions from the model for an input batch

  • Calculates the loss for that set of predictions vs. the labels on the dataset

  • Calculates the backward gradients over the learning weights

  • Tells the optimizer to perform one learning step; that is, adjusts the model’s learning weights based on the observed gradients for this batch, according to the optimization algorithm we chose

  • Reports the average per-batch loss every 1000 batches, both to the console and to TensorBoard

  • Finally, returns the average per-batch loss for the last 1000 batches, for comparison with a validation run

def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss
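The call to optimizer.zero_grad() at the top of each iteration matters because backward() adds new gradients into each parameter’s .grad attribute rather than overwriting it. This minimal sketch uses a single hypothetical scalar weight w to show the accumulation and the reset.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

# First backward pass: d(3w)/dw = 3
(3 * w).backward()
first = w.grad.item()

# Second backward pass without zeroing: the new gradient is ADDED to w.grad
(3 * w).backward()
accumulated = w.grad.item()

# Reset by hand, which is what optimizer.zero_grad() does for every parameter
# (recent PyTorch versions set .grad to None instead of filling it with zeros)
w.grad.zero_()
(3 * w).backward()
after_reset = w.grad.item()

print(first, accumulated, after_reset)  # 3.0 6.0 3.0
```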

Per-Epoch Activity

There are a couple of things we’ll want to do once per epoch:

  • Perform validation by checking our relative loss on a set of data that was not used for training, and report this

  • Save a copy of the model

Here, we’ll do our reporting in TensorBoard. This requires starting TensorBoard from the command line (for example, with tensorboard --logdir=runs, since the writer below logs under runs/) and opening it in another browser tab.

# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Put the model in training mode (dropout and batch-norm updates active), then do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)


    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1
EPOCH 1:
  batch 1000 loss: 2.1626133951246738
  batch 2000 loss: 0.9745002372637391
  batch 3000 loss: 0.7352404489740729
  batch 4000 loss: 0.6727571301721036
  batch 5000 loss: 0.6020912300148048
  batch 6000 loss: 0.5561846439801157
  batch 7000 loss: 0.5287796507515014
  batch 8000 loss: 0.501840026800055
  batch 9000 loss: 0.5147234166720882
  batch 10000 loss: 0.4789069440057501
  batch 11000 loss: 0.4642180380858481
  batch 12000 loss: 0.4668008829996688
  batch 13000 loss: 0.44677565061626956
  batch 14000 loss: 0.42972591010754696
  batch 15000 loss: 0.42956455090083184
LOSS train 0.42956455090083184 valid 0.4397105872631073
EPOCH 2:
  batch 1000 loss: 0.415719639013405
  batch 2000 loss: 0.40858568280225155
  batch 3000 loss: 0.3870634801845299
  batch 4000 loss: 0.399076723834849
  batch 5000 loss: 0.38554254713084085
  batch 6000 loss: 0.3914544439036981
  batch 7000 loss: 0.37822988635860383
  batch 8000 loss: 0.3661753419930174
  batch 9000 loss: 0.3562473828648217
  batch 10000 loss: 0.36509081625379625
  batch 11000 loss: 0.36181802575226174
  batch 12000 loss: 0.3447543090864783
  batch 13000 loss: 0.3322899040741322
  batch 14000 loss: 0.3669428357169963
  batch 15000 loss: 0.34384529662012936
LOSS train 0.34384529662012936 valid 0.3930756449699402
EPOCH 3:
  batch 1000 loss: 0.33305908631574127
  batch 2000 loss: 0.34648251156628246
  batch 3000 loss: 0.3326352587278234
  batch 4000 loss: 0.33988424329501865
  batch 5000 loss: 0.31139411728776756
  batch 6000 loss: 0.33292682037170745
  batch 7000 loss: 0.33184790199639974
  batch 8000 loss: 0.3425309301953821
  batch 9000 loss: 0.31668488782917303
  batch 10000 loss: 0.31518867740477435
  batch 11000 loss: 0.31544839551368203
  batch 12000 loss: 0.30939525527498335
  batch 13000 loss: 0.3009382589287925
  batch 14000 loss: 0.3184769766079844
  batch 15000 loss: 0.32007758843247214
LOSS train 0.32007758843247214 valid 0.3472983241081238
EPOCH 4:
  batch 1000 loss: 0.3029580612979262
  batch 2000 loss: 0.2896642429686326
  batch 3000 loss: 0.31907536742377124
  batch 4000 loss: 0.29765302756043094
  batch 5000 loss: 0.2973705656676611
  batch 6000 loss: 0.29400796061562867
  batch 7000 loss: 0.28913413904867774
  batch 8000 loss: 0.29533820189339166
  batch 9000 loss: 0.2959047243149398
  batch 10000 loss: 0.3048672123772267
  batch 11000 loss: 0.28920165327872016
  batch 12000 loss: 0.2964483002234265
  batch 13000 loss: 0.3016540655086428
  batch 14000 loss: 0.2911894151992674
  batch 15000 loss: 0.2685066594746313
LOSS train 0.2685066594746313 valid 0.3344765901565552
EPOCH 5:
  batch 1000 loss: 0.2549045204511967
  batch 2000 loss: 0.2778276502187814
  batch 3000 loss: 0.27435155522457355
  batch 4000 loss: 0.2801725401299882
  batch 5000 loss: 0.26628558637754757
  batch 6000 loss: 0.2661942468724228
  batch 7000 loss: 0.2863066789020231
  batch 8000 loss: 0.2792999133357134
  batch 9000 loss: 0.2961141931802413
  batch 10000 loss: 0.2848576120319158
  batch 11000 loss: 0.28563465925953824
  batch 12000 loss: 0.269254339265437
  batch 13000 loss: 0.2877368157287201
  batch 14000 loss: 0.29883485360251505
  batch 15000 loss: 0.296440352970385
LOSS train 0.296440352970385 valid 0.3300156593322754
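model.eval() and torch.no_grad() do different jobs in the validation pass: the first switches layers such as dropout and batch normalization to their inference behavior, while the second stops autograd from recording operations. GarmentClassifier contains neither dropout nor batch normalization, so this sketch uses a hypothetical small network with dropout to make the mode switch observable.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny model containing dropout, so the mode switch has an effect
net = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
x = torch.ones(1, 8)

net.train()   # training mode: dropout zeroes a random subset of units
net.eval()    # evaluation mode: dropout becomes an identity operation

with torch.no_grad():  # no autograd graph is recorded inside this block
    out1 = net(x)
    out2 = net(x)

print(net.training, torch.equal(out1, out2), out1.requires_grad)  # False True False
```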

To load a saved version of the model:

# model_path holds the file name of the best checkpoint saved in the loop above
saved_model = GarmentClassifier()
saved_model.load_state_dict(torch.load(model_path))

Once you’ve loaded the model, it’s ready for whatever you need it for: more training, inference, or analysis.

Note that if your model has constructor parameters that affect model structure, you’ll need to provide them and configure the model identically to the state in which it was saved.
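The sketch below illustrates that constraint with a hypothetical Classifier whose hidden constructor argument sets the layer sizes, and an in-memory io.BytesIO buffer standing in for a file path. Reloading into an identically configured model reproduces its outputs, while a differently sized model rejects the state dict with a RuntimeError.

```python
import io

import torch
import torch.nn as nn


class Classifier(nn.Module):
    # Hypothetical model whose constructor parameter changes the layer shapes
    def __init__(self, hidden=32):
        super().__init__()
        self.fc1 = nn.Linear(784, hidden)
        self.fc2 = nn.Linear(hidden, 10)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


original = Classifier(hidden=64)
buffer = io.BytesIO()  # stands in for a checkpoint file on disk
torch.save(original.state_dict(), buffer)

# Same constructor arguments: the weights load cleanly
buffer.seek(0)
restored = Classifier(hidden=64)
restored.load_state_dict(torch.load(buffer))

# Different constructor arguments: parameter shapes no longer match
buffer.seek(0)
mismatched = Classifier(hidden=32)
try:
    mismatched.load_state_dict(torch.load(buffer))
    loaded_ok = True
except RuntimeError:
    loaded_ok = False

x = torch.randn(2, 784)
same = torch.equal(original(x), restored(x))
print(same, loaded_ok)  # True False
```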

Other Resources

Total running time of the script: (2 minutes 59.547 seconds)