

Training with PyTorch

Created On: Nov 30, 2021 | Last Updated: May 31, 2023 | Last Verified: Nov 05, 2024

Follow along with the video below or on YouTube.

Introduction

In past videos, we’ve discussed and demonstrated:

  • Building models with the neural network layers and functions of the torch.nn module

  • The mechanics of automated gradient computation, which is central to gradient-based model training

  • Using TensorBoard to visualize training progress and other activities

In this video, we’ll be adding some new tools to your inventory:

  • We’ll get familiar with the Dataset and DataLoader abstractions, and how they ease the process of feeding data to your model during a training loop

  • We’ll discuss specific loss functions and when to use them

  • We’ll look at PyTorch optimizers, which implement algorithms to adjust model weights based on the outcome of a loss function

Finally, we’ll pull all of these together and see a full PyTorch training loop in action.

Dataset and DataLoader

The Dataset and DataLoader classes encapsulate the process of pulling your data from storage and exposing it to your training loop in batches.

The Dataset is responsible for accessing and processing single instances of data.

The DataLoader pulls instances of data from the Dataset (either automatically or with a sampler that you define), collects them in batches, and returns them for consumption by your training loop. The DataLoader works with all kinds of datasets, regardless of the type of data they contain.
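To make the division of labor concrete, here is a minimal sketch of a custom Dataset and the DataLoader that batches it. SquaresDataset is a hypothetical toy dataset invented for illustration (instance k is the pair (k, k²)); it only needs __len__ and __getitem__, and the DataLoader handles ordering and batching:

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: instance k is the pair (k, k**2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # The DataLoader's sampler uses this to know how many instances exist
        return self.n

    def __getitem__(self, idx):
        # Return one instance; batching is the DataLoader's job, not the Dataset's
        x = torch.tensor([float(idx)])
        y = torch.tensor(float(idx ** 2))
        return x, y


dataset = SquaresDataset(10)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

for inputs, targets in loader:
    # The default collate function stacks instances along a new batch dimension
    print(inputs.shape, targets.shape)
```

With 10 instances and a batch size of 4, the loader yields two full batches and one final batch of 2, because drop_last defaults to False.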

For this tutorial, we’ll be using the Fashion-MNIST dataset provided by TorchVision. We use torchvision.transforms.ToTensor() to convert each image to a tensor with pixel values in [0, 1], then torchvision.transforms.Normalize() to shift and scale those values into [-1, 1], centering them on zero. We also download both the training and validation data splits.

import torch
import torchvision
import torchvision.transforms as transforms

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])

# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)

# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)

# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
        'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
Training set has 60000 instances
Validation set has 10000 instances
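The arithmetic behind Normalize((0.5,), (0.5,)) is simple enough to check by hand. Normalize computes (x - mean) / std per channel, so with mean = std = 0.5 it maps ToTensor's [0, 1] range onto [-1, 1]. The matplotlib_imshow helper used for display applies the inverse, img / 2 + 0.5. A plain-Python sketch of both directions (the function names are hypothetical):

```python
def normalize(x, mean=0.5, std=0.5):
    # The per-channel formula Normalize applies: (x - mean) / std
    return (x - mean) / std


def unnormalize(y, mean=0.5, std=0.5):
    # Inverse mapping; with mean = std = 0.5 this equals y / 2 + 0.5
    return y * std + mean


for pixel in (0.0, 0.5, 1.0):
    print(pixel, '->', normalize(pixel), '->', unnormalize(normalize(pixel)))
# 0.0 -> -1.0 -> 0.0
# 0.5 -> 0.0 -> 0.5
# 1.0 -> 1.0 -> 1.0
```

Note that these constants rescale the pixel range; they are not the dataset's measured mean and standard deviation.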

As always, let’s visualize the data as a sanity check:

import matplotlib.pyplot as plt
import numpy as np

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(training_loader)
images, labels = next(dataiter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print('  '.join(classes[labels[j]] for j in range(4)))
Pullover  Bag  Ankle Boot  Dress

The Model

The model we’ll use in this example is a variant of LeNet-5; it should be familiar if you’ve watched the previous videos in this series.

import torch.nn as nn
import torch.nn.functional as F

# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = GarmentClassifier()
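The input size of fc1, 16 * 4 * 4, comes from tracing a 28x28 Fashion-MNIST image through the convolution and pooling layers. Without padding, each 5x5 convolution shrinks the side length by 4, and each 2x2 max-pool with stride 2 halves it. A small sketch of that arithmetic (conv_out is a hypothetical helper that uses the standard output-size formula with dilation 1):

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Output side length of a Conv2d or MaxPool2d layer with dilation 1
    return (size + 2 * padding - kernel) // stride + 1


size = 28                       # Fashion-MNIST images are 28x28
size = conv_out(size, 5)        # conv1, 5x5 kernel -> 24
size = conv_out(size, 2, 2)     # pool, 2x2 window, stride 2 -> 12
size = conv_out(size, 5)        # conv2, 5x5 kernel -> 8
size = conv_out(size, 2, 2)     # pool -> 4
flat = 16 * size * size         # conv2 produces 16 channels
print(size, flat)               # 4 256
```

This is why forward() flattens with x.view(-1, 16 * 4 * 4). If you change the input resolution or the kernel sizes, you need to recompute this value.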

Loss Function

For this example, we’ll be using a cross-entropy loss. For demonstration purposes, we’ll create batches of dummy output and label values, run them through the loss function, and examine the result.

loss_fn = torch.nn.CrossEntropyLoss()

# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's raw, unnormalized scores (logits) for each of the 10 classes;
# CrossEntropyLoss applies log-softmax internally, so these need not be probabilities
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])

print(dummy_outputs)
print(dummy_labels)

loss = loss_fn(dummy_outputs, dummy_labels)
# With the default reduction='mean', this is the average loss over the 4 samples
print('Total loss for this batch: {}'.format(loss.item()))
tensor([[0.7528, 0.1049, 0.2693, 0.7357, 0.7790, 0.0190, 0.9424, 0.4964, 0.0902,
         0.3291],
        [0.4673, 0.3194, 0.3383, 0.9915, 0.1708, 0.6587, 0.0029, 0.9962, 0.8315,
         0.2306],
        [0.7995, 0.7187, 0.2195, 0.0611, 0.3484, 0.3580, 0.1665, 0.6033, 0.4798,
         0.8635],
        [0.9398, 0.9129, 0.0150, 0.3506, 0.0445, 0.8106, 0.9968, 0.9955, 0.8752,
         0.8417]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.4208834171295166
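For each sample, cross-entropy loss is the negative log of the softmax probability assigned to the correct class. With the default reduction, CrossEntropyLoss averages that value over the batch. The plain-Python sketch below reproduces this computation, assuming no class weights and no label smoothing (cross_entropy is a hypothetical helper, not a PyTorch function):

```python
import math


def cross_entropy(logits_batch, labels):
    # Mean over the batch of -log(softmax(logits)[label]),
    # i.e. log-sum-exp of the scores minus the score of the correct class
    total = 0.0
    for logits, label in zip(logits_batch, labels):
        m = max(logits)  # subtract the max before exponentiating, for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
        total += log_sum_exp - logits[label]
    return total / len(labels)


# Equal scores for all 10 classes: the loss is ln(10)
print(cross_entropy([[0.0] * 10], [3]))  # 2.302585092994046
```

This gives a baseline for reading the result above. The torch.rand scores all lie in [0, 1), so the predictions are close to uniform and the loss lands near ln(10) ≈ 2.30. A trained model should push the loss well below that value.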

Optimizer

For this example, we’ll be using simple stochastic gradient descent with momentum.

It can be instructive to try some variations on this optimization scheme:

  • Learning rate determines the size of the steps the optimizer takes. What does a different learning rate do to your training results, in terms of accuracy and convergence time?

  • Momentum accumulates an exponentially decaying sum of past gradients, so the optimizer keeps moving in directions where the gradient has been consistent over multiple steps. What does changing this value do to your results?

  • Try some different optimization algorithms, such as averaged SGD, Adagrad, or Adam. How do your results differ?
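To see how the learning rate and momentum interact, here is a plain-Python sketch of the momentum update described in the torch.optim.SGD documentation (with dampening=0 and nesterov=False), applied to a toy objective f(x) = x² whose minimum is at x = 0. The sgd_momentum helper and the quadratic objective are assumptions for illustration only:

```python
def sgd_momentum(grad_fn, x0, lr, momentum, steps):
    # buf = momentum * buf + grad, then x = x - lr * buf
    # (starting buf at 0 makes the first buffer equal the first gradient)
    x, buf = x0, 0.0
    for _ in range(steps):
        g = grad_fn(x)
        buf = momentum * buf + g
        x = x - lr * buf
    return x


def grad(x):
    # Gradient of f(x) = x**2
    return 2.0 * x


plain = sgd_momentum(grad, x0=5.0, lr=0.001, momentum=0.0, steps=100)
with_momentum = sgd_momentum(grad, x0=5.0, lr=0.001, momentum=0.9, steps=100)
print(plain, with_momentum)
```

With the tutorial's learning rate of 0.001, plain SGD only moves x from 5.0 to about 4.09 in 100 steps. With momentum 0.9, the accumulated buffer lets it get much closer to the minimum in the same number of steps. On a real network, too much momentum or too large a learning rate can overshoot instead.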

# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
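Every optimizer in torch.optim is constructed the same way, from the parameters to update plus its own hyperparameters, so swapping algorithms is a one-line change. The sketch below assumes a stand-in nn.Linear model, random data, and a hypothetical make_optimizer helper. It takes one step with each optimizer. The learning rates are illustrative, not tuned for Fashion-MNIST:

```python
import torch
import torch.nn as nn


def make_optimizer(name, params):
    # Hypothetical helper; hyperparameter values are illustrative only
    if name == 'sgd':
        return torch.optim.SGD(params, lr=0.001, momentum=0.9)
    if name == 'asgd':
        return torch.optim.ASGD(params, lr=0.01)      # averaged SGD
    if name == 'adagrad':
        return torch.optim.Adagrad(params, lr=0.01)
    if name == 'adam':
        return torch.optim.Adam(params, lr=0.001)
    raise ValueError('unknown optimizer: ' + name)


torch.manual_seed(0)
inputs, targets = torch.randn(8, 4), torch.randn(8, 2)

results = {}
for name in ('sgd', 'asgd', 'adagrad', 'adam'):
    model = nn.Linear(4, 2)  # fresh stand-in model for each optimizer
    optimizer = make_optimizer(name, model.parameters())
    before = model.weight.detach().clone()

    # One training step: zero gradients, forward, loss, backward, update
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()

    results[name] = not torch.equal(before, model.weight)
    print(name, 'weights updated:', results[name])
```

The rest of the training loop does not change. Only the line that constructs the optimizer differs.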

The Training Loop

Below, we have a function that performs one training epoch. It enumerates data from the DataLoader, and on each pass of the loop does the following:

  • Gets a batch of training data from the DataLoader

  • Zeros the optimizer’s gradients

  • Performs a forward pass, that is, gets predictions from the model for an input batch

  • Calculates the loss for that set of predictions vs. the labels on the dataset

  • Calculates the backward gradients over the learning weights

  • Tells the optimizer to perform one learning step - that is, adjust the model’s learning weights based on the observed gradients for this batch, according to the optimization algorithm we chose

  • Every 1000 batches, it reports the average per-batch loss over those 1000 batches

  • Finally, it returns the average per-batch loss for the last 1000 batches, for comparison with a validation run
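Zeroing the gradients matters because backward() adds new gradients into each parameter's .grad rather than overwriting it. Without optimizer.zero_grad(), every step would use the sum of gradients from all previous batches. A minimal sketch with a single scalar parameter (the parameter w and the function 2 * w are arbitrary choices for illustration):

```python
import torch

w = torch.tensor(3.0, requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

# d(2*w)/dw = 2; backward() accumulates this into w.grad instead of replacing it
(2 * w).backward()
(2 * w).backward()
accumulated = w.grad.item()
print(accumulated)    # 4.0: gradients from two "batches" were summed

# Clearing gradients before the next backward pass restores the per-batch value
optimizer.zero_grad()
(2 * w).backward()
print(w.grad.item())  # 2.0
```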

def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss
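The tb_x value gives TensorBoard a global step that keeps counting across epochs, so the training curve does not restart at zero each epoch. With 60,000 training instances and batch_size=4, len(training_loader) is 15,000. That is why the log below shows 15 reports per epoch. A sketch of the arithmetic (num_batches and global_step are hypothetical helpers that mirror the expressions used above):

```python
import math


def num_batches(num_instances, batch_size):
    # len(DataLoader) when drop_last=False: a final partial batch still counts
    return math.ceil(num_instances / batch_size)


def global_step(epoch_index, batch_index, batches_per_epoch):
    # Same formula as tb_x: number of batches processed since training began
    return epoch_index * batches_per_epoch + batch_index + 1


per_epoch = num_batches(60000, 4)
print(per_epoch)                         # 15000
print(global_step(0, 999, per_epoch))    # 1000
print(global_step(1, 999, per_epoch))    # 16000
```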

Per-Epoch Activity

There are a couple of things we’ll want to do once per epoch:

  • Perform validation by checking our relative loss on a set of data that was not used for training, and report this

  • Save a copy of the model

Here, we’ll do our reporting in TensorBoard. This requires starting TensorBoard from the command line (for example, tensorboard --logdir=runs, since the SummaryWriter below writes under runs/) and opening it in another browser tab.

# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Put the model in training mode (enables layers such as dropout), then do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)


    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1
EPOCH 1:
  batch 1000 loss: 1.8180003824532032
  batch 2000 loss: 0.8957754331827164
  batch 3000 loss: 0.741791602266021
  batch 4000 loss: 0.6449173951987177
  batch 5000 loss: 0.6043300562414806
  batch 6000 loss: 0.5600329208387993
  batch 7000 loss: 0.5610174404948484
  batch 8000 loss: 0.5391112232778688
  batch 9000 loss: 0.49356041224929503
  batch 10000 loss: 0.4750006023424212
  batch 11000 loss: 0.47520611739280866
  batch 12000 loss: 0.45327228994260077
  batch 13000 loss: 0.4309118041661568
  batch 14000 loss: 0.43663130749110135
  batch 15000 loss: 0.45419479672331364
LOSS train 0.45419479672331364 valid 0.42852601408958435
EPOCH 2:
  batch 1000 loss: 0.4146793999471702
  batch 2000 loss: 0.3971446214191965
  batch 3000 loss: 0.39344882845913526
  batch 4000 loss: 0.40533752440224635
  batch 5000 loss: 0.38281099223109777
  batch 6000 loss: 0.40937822492184933
  batch 7000 loss: 0.34685169195458365
  batch 8000 loss: 0.380262188786146
  batch 9000 loss: 0.36511920924199515
  batch 10000 loss: 0.36148983901017345
  batch 11000 loss: 0.36713312672358006
  batch 12000 loss: 0.36210009656846526
  batch 13000 loss: 0.34765562966381547
  batch 14000 loss: 0.3762075893210713
  batch 15000 loss: 0.35051395519357176
LOSS train 0.35051395519357176 valid 0.3601304590702057
EPOCH 3:
  batch 1000 loss: 0.3406964710156899
  batch 2000 loss: 0.346449927648835
  batch 3000 loss: 0.31156337086665736
  batch 4000 loss: 0.3395563610608224
  batch 5000 loss: 0.3243687777138548
  batch 6000 loss: 0.33435277505280103
  batch 7000 loss: 0.3491077534397482
  batch 8000 loss: 0.321636503800386
  batch 9000 loss: 0.29387344657287756
  batch 10000 loss: 0.32132642339486484
  batch 11000 loss: 0.3446933602282952
  batch 12000 loss: 0.3091468107325199
  batch 13000 loss: 0.32136326609969273
  batch 14000 loss: 0.31983308911712083
  batch 15000 loss: 0.3228459794683731
LOSS train 0.3228459794683731 valid 0.34866663813591003
EPOCH 4:
  batch 1000 loss: 0.3093333835767262
  batch 2000 loss: 0.29871428542421197
  batch 3000 loss: 0.33486319116083907
  batch 4000 loss: 0.30480184700866814
  batch 5000 loss: 0.3016270888010331
  batch 6000 loss: 0.30586647188165805
  batch 7000 loss: 0.2702475314392941
  batch 8000 loss: 0.28462565018634634
  batch 9000 loss: 0.318295512925848
  batch 10000 loss: 0.300293825531262
  batch 11000 loss: 0.2823062967541264
  batch 12000 loss: 0.29897188573934547
  batch 13000 loss: 0.276616980892104
  batch 14000 loss: 0.2911908489668567
  batch 15000 loss: 0.2921275068801697
LOSS train 0.2921275068801697 valid 0.30819186568260193
EPOCH 5:
  batch 1000 loss: 0.27717354378083836
  batch 2000 loss: 0.2679578961185471
  batch 3000 loss: 0.27298799382541256
  batch 4000 loss: 0.2856498993916393
  batch 5000 loss: 0.2585386563991997
  batch 6000 loss: 0.27737813394927796
  batch 7000 loss: 0.2762562871880291
  batch 8000 loss: 0.27440893485613926
  batch 9000 loss: 0.2854745021615454
  batch 10000 loss: 0.29342940711366283
  batch 11000 loss: 0.2676458901484548
  batch 12000 loss: 0.29326289563301544
  batch 13000 loss: 0.2927326946757194
  batch 14000 loss: 0.2713185495002508
  batch 15000 loss: 0.2670032520847344
LOSS train 0.2670032520847344 valid 0.3067055940628052
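The validation pass above uses both model.eval() and torch.no_grad(), and the two do different jobs. eval() changes how layers such as dropout and batch normalization behave, while no_grad() stops autograd from recording operations. GarmentClassifier has neither dropout nor batch normalization, so eval() does not change its outputs, but calling it is a good habit. The sketch below uses a hypothetical stand-in network with dropout to show the difference:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-in model with a dropout layer (GarmentClassifier itself has none)
net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

net.train()
train_same = torch.equal(net(x), net(x))  # dropout masks are random, so this is often False

net.eval()
eval_same = torch.equal(net(x), net(x))   # True: dropout is a no-op in eval mode
tracked = net(x).requires_grad            # True: eval() alone does not disable autograd

with torch.no_grad():
    out = net(x)                          # no graph is recorded inside no_grad()
print(train_same, eval_same, tracked, out.requires_grad)
```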

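To load a saved version of the model:

saved_model = GarmentClassifier()
# PATH is the checkpoint file to load; model_path from the loop above holds the best one saved
PATH = model_path
# weights_only=True restricts unpickling to tensors and plain containers
saved_model.load_state_dict(torch.load(PATH, weights_only=True))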

Once you’ve loaded the model, it’s ready for whatever you need it for: more training, inference, or analysis. Before running inference, call saved_model.eval() to switch it to evaluation mode.

Note that if your model has constructor parameters that affect model structure, you’ll need to provide them and configure the model identically to the state in which it was saved.
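The constructor requirement shows up directly when you load. A state_dict stores only tensors keyed by parameter name, so the receiving model must already have matching shapes. The round trip below uses a hypothetical TinyClassifier whose hidden-layer width is a constructor argument, and a temporary file path chosen for illustration:

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    # Hypothetical model whose structure depends on a constructor argument
    def __init__(self, hidden=8):
        super().__init__()
        self.fc1 = nn.Linear(4, hidden)
        self.fc2 = nn.Linear(hidden, 3)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


model = TinyClassifier(hidden=16)
path = os.path.join(tempfile.mkdtemp(), 'tiny_classifier.pt')
torch.save(model.state_dict(), path)

# Rebuild with the same constructor arguments, then load the weights
restored = TinyClassifier(hidden=16)
restored.load_state_dict(torch.load(path, weights_only=True))
restored.eval()  # evaluation mode before inference

with torch.no_grad():
    predictions = restored(torch.randn(2, 4)).argmax(dim=1)
print(predictions)

# Loading into a differently configured model raises a size-mismatch error
mismatch_raised = False
try:
    TinyClassifier(hidden=8).load_state_dict(torch.load(path, weights_only=True))
except RuntimeError:
    mismatch_raised = True
print('mismatch raised:', mismatch_raised)
```

Saving the constructor arguments next to the state_dict, for example in a small dict saved with the same torch.save call, is one common way to make this rebuild step reliable.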


Total running time of the script: (3 minutes 1.679 seconds)