

Training with PyTorch

Created On: Nov 30, 2021 | Last Updated: May 31, 2023 | Last Verified: Nov 05, 2024

Follow along with the video below or on YouTube.

Introduction

In past videos, we’ve discussed and demonstrated:

  • Building models with the neural network layers and functions of the torch.nn module

  • The mechanics of automated gradient computation, which is central to gradient-based model training

  • Using TensorBoard to visualize training progress and other activities

In this video, we’ll be adding some new tools to your inventory:

  • We’ll get familiar with the dataset and dataloader abstractions, and how they ease the process of feeding data to your model during a training loop

  • We’ll discuss specific loss functions and when to use them

  • We’ll look at PyTorch optimizers, which implement algorithms to adjust model weights based on the outcome of a loss function

Finally, we’ll pull all of these together and see a full PyTorch training loop in action.

Dataset and DataLoader

The Dataset and DataLoader classes encapsulate the process of pulling your data from storage and exposing it to your training loop in batches.

The Dataset is responsible for accessing and processing single instances of data.

The DataLoader pulls instances of data from the Dataset (either automatically or with a sampler that you define), collects them in batches, and returns them for consumption by your training loop. The DataLoader works with all kinds of datasets, regardless of the type of data they contain.
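To make this division of labour concrete, here is a minimal sketch of a custom map-style Dataset wrapped in a DataLoader. ToyDataset and its synthetic contents are hypothetical; Dataset, DataLoader, batch_size and shuffle are the real torch.utils.data API. The Dataset only implements __len__ and __getitem__ for a single index, while the DataLoader shuffles the indices and stacks the returned (input, label) pairs into batch tensors.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ToyDataset(Dataset):
    """Hypothetical dataset: item k is a 1x28x28 tensor filled with k, labelled k % 10."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        # The DataLoader uses this to know which indices it may request
        return self.size

    def __getitem__(self, idx):
        # Return ONE instance as an (input, label) pair
        image = torch.full((1, 28, 28), float(idx))
        label = idx % 10
        return image, label


dataset = ToyDataset(10)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

for images, labels in loader:
    # The default collate function stacks instances into batch tensors
    print(images.shape, labels)
```

With 10 instances and batch_size=4, the final batch holds only 2 instances; passing drop_last=True to the DataLoader discards such a partial batch if your training loop needs full batches.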

For this tutorial, we’ll be using the Fashion-MNIST dataset provided by TorchVision. We use torchvision.transforms.Normalize() to zero-center and normalize the distribution of the image tile content, and download both training and validation data splits.

import torch
import torchvision
import torchvision.transforms as transforms

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])

# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)

# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)

# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
        'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
Training set has 60000 instances
Validation set has 10000 instances

As always, let’s visualize the data as a sanity check:

import matplotlib.pyplot as plt
import numpy as np

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(training_loader)
images, labels = next(dataiter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print('  '.join(classes[labels[j]] for j in range(4)))
Sandal  Dress  Bag  Pullover
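The Normalize((0.5,), (0.5,)) transform used when loading the data and the img / 2 + 0.5 line in matplotlib_imshow() are inverses of each other. Normalize computes (x - mean) / std per channel, so with mean = std = 0.5 the [0, 1] pixel values produced by ToTensor() end up in [-1, 1]. The plain-Python sketch below checks the round trip on a few pixel values; normalize() and unnormalize() are illustrative names, not torchvision functions.

```python
MEAN, STD = 0.5, 0.5  # the values passed to transforms.Normalize above


def normalize(x):
    # What Normalize does to each pixel of a single-channel image
    return (x - MEAN) / STD


def unnormalize(y):
    # Equivalent to `img / 2 + 0.5` in matplotlib_imshow() when MEAN = STD = 0.5
    return y * STD + MEAN


for pixel in (0.0, 0.25, 0.5, 1.0):
    print(pixel, "->", normalize(pixel), "->", unnormalize(normalize(pixel)))
```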

The Model

The model we’ll use in this example is a variant of LeNet-5; it should be familiar if you’ve watched the previous videos in this series.

import torch.nn as nn
import torch.nn.functional as F

# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = GarmentClassifier()
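The 16 * 4 * 4 input size of fc1 follows from how each layer shrinks the 28x28 Fashion-MNIST images. Here is a sketch of that arithmetic using the standard output-size formula for Conv2d and MaxPool2d with no padding and dilation 1; conv_output_size() is a hypothetical helper, not part of torch.nn.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    # Output side length for Conv2d / MaxPool2d with dilation 1:
    # floor((size + 2 * padding - kernel_size) / stride) + 1
    return (size + 2 * padding - kernel_size) // stride + 1


size = 28                            # Fashion-MNIST images are 28x28
size = conv_output_size(size, 5)     # conv1, 5x5 kernel         -> 24
size = conv_output_size(size, 2, 2)  # pool, 2x2 with stride 2   -> 12
size = conv_output_size(size, 5)     # conv2, 5x5 kernel         -> 8
size = conv_output_size(size, 2, 2)  # pool, 2x2 with stride 2   -> 4
print(size, 16 * size * size)        # 16 channels of 4x4 -> 256 features
```

If you change the input resolution or a kernel size, rerunning this arithmetic tells you the new in_features value that fc1 and the x.view() call in forward() both need.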

Loss Function

For this example, we’ll be using a cross-entropy loss. For demonstration purposes, we’ll create batches of dummy output and label values, run them through the loss function, and examine the result.

loss_fn = torch.nn.CrossEntropyLoss()

# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's raw, unnormalized scores (logits) for each of the 10 classes for a given input
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])

print(dummy_outputs)
print(dummy_labels)

loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))
tensor([[0.2269, 0.5681, 0.9649, 0.0939, 0.0386, 0.6477, 0.9346, 0.8101, 0.8605,
         0.1474],
        [0.7243, 0.6807, 0.0226, 0.7000, 0.4934, 0.9603, 0.7426, 0.1386, 0.7102,
         0.2794],
        [0.5429, 0.9261, 0.3934, 0.5579, 0.7541, 0.9335, 0.2773, 0.7657, 0.2423,
         0.4768],
        [0.3698, 0.1863, 0.4421, 0.5420, 0.8643, 0.6174, 0.2797, 0.9907, 0.3766,
         0.1664]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.11007022857666
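CrossEntropyLoss expects raw, unnormalized scores: it applies a log-softmax across the classes, takes the negative log-probability of the correct class, and by default averages the result over the batch. The hand-rolled version below shows that computation in plain Python; cross_entropy() is an illustrative helper, not a PyTorch function. Uniform scores over 10 classes give a loss of log(10), about 2.30, and random scores in [0, 1) are roughly uniform after the softmax, which is why the dummy batch above lands in the same neighbourhood.

```python
import math


def cross_entropy(logits, targets):
    """Batch mean of -log(softmax(row)[target]), computed by hand."""
    total = 0.0
    for row, target in zip(logits, targets):
        # log-sum-exp, shifted by the row maximum for numerical stability
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        # -log softmax(row)[target] = log_sum_exp - row[target]
        total += log_sum_exp - row[target]
    return total / len(targets)


# Uniform scores: the loss equals log(10)
print(cross_entropy([[0.0] * 10], [3]))
# A confident, correct prediction drives the loss toward 0
print(cross_entropy([[10.0 if k == 3 else 0.0 for k in range(10)]], [3]))
```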

Optimizer

For this example, we’ll be using simple stochastic gradient descent with momentum.

It can be instructive to try some variations on this optimization scheme:

  • Learning rate determines the size of the steps the optimizer takes. What does a different learning rate do to your training results, in terms of accuracy and convergence time?

  • Momentum keeps an exponentially decaying sum of past gradients, so the optimizer keeps moving in directions the gradient has pointed consistently over multiple steps. What does changing this value do to your results?

  • Try some different optimization algorithms, such as averaged SGD, Adagrad, or Adam. How do your results differ?

# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
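To build intuition for the learning-rate and momentum experiments suggested above, here is a plain-Python sketch of the momentum update rule documented for torch.optim.SGD, assuming dampening, Nesterov momentum and weight decay are all left at their defaults (zero or off). It is applied to the toy loss f(x) = x**2; sgd_momentum() and grad_of_square() are hypothetical helpers, so swap in your own gradient function to experiment.

```python
def sgd_momentum(grad_fn, x0, lr, momentum, steps):
    """Run `steps` SGD-with-momentum updates on a 1-D problem; return the final x.

    Sketch of the torch.optim.SGD rule with dampening=0, nesterov=False and
    weight_decay=0.
    """
    x, velocity = x0, None
    for _ in range(steps):
        g = grad_fn(x)
        # First step: velocity = g; afterwards: velocity = momentum * velocity + g
        velocity = g if velocity is None else momentum * velocity + g
        x = x - lr * velocity
    return x


def grad_of_square(x):
    # Gradient of the toy loss f(x) = x**2, whose minimum is at x = 0
    return 2 * x


for lr, momentum in [(0.001, 0.0), (0.001, 0.9), (0.1, 0.0)]:
    final_x = sgd_momentum(grad_of_square, 5.0, lr, momentum, steps=100)
    print(f"lr={lr}, momentum={momentum}: x after 100 steps = {final_x:.4f}")
```

With the tutorial's settings (lr=0.001, momentum=0.9), x gets much closer to the minimum in 100 steps than plain SGD at the same learning rate, because the velocity term effectively enlarges the step while the gradient keeps pointing the same way.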

The Training Loop

Below, we have a function that performs one training epoch. It enumerates data from the DataLoader, and on each pass of the loop does the following:

  • Gets a batch of training data from the DataLoader

  • Zeros the optimizer’s gradients

  • Performs an inference - that is, gets predictions from the model for an input batch

  • Calculates the loss for that set of predictions vs. the labels on the dataset

  • Calculates the backward gradients over the learning weights

  • Tells the optimizer to perform one learning step - that is, adjust the model’s learning weights based on the observed gradients for this batch, according to the optimization algorithm we chose

  • Reports the average per-batch loss every 1000 batches, both to the console and to TensorBoard

  • Finally, returns the average per-batch loss for the last 1000 batches, for comparison with a validation run

def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss
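The "Zero your gradients for every batch!" comment matters because backward() adds new gradients into each parameter's .grad instead of overwriting it. A minimal sketch with a single weight shows the accumulation and one optimizer step; demo_w, demo_x, demo_y, demo_optimizer and demo_loss() are illustrative names, chosen so they do not clobber the tutorial's own optimizer.

```python
import torch

# One learnable weight and a toy squared-error loss: (w * x - y) ** 2
demo_w = torch.tensor(1.0, requires_grad=True)
demo_x, demo_y = torch.tensor(2.0), torch.tensor(6.0)
demo_optimizer = torch.optim.SGD([demo_w], lr=0.01)


def demo_loss():
    return (demo_w * demo_x - demo_y) ** 2


demo_loss().backward()
print(demo_w.grad)          # 2 * (w*x - y) * x = 2 * (2 - 6) * 2 = -16
demo_loss().backward()
print(demo_w.grad)          # accumulated, not replaced: -16 + -16 = -32

demo_optimizer.zero_grad()  # clear the stale gradient before the next batch
demo_loss().backward()      # .grad is -16 again
demo_optimizer.step()       # w <- w - lr * grad = 1.0 - 0.01 * (-16), about 1.16
print(demo_w.item())
```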

Per-Epoch Activity

There are a couple of things we’ll want to do once per epoch:

  • Perform validation by checking our relative loss on a set of data that was not used for training, and report this

  • Save a copy of the model

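Here, we’ll do our reporting in TensorBoard. This will require going to the command line to start TensorBoard (for example, with tensorboard --logdir=runs, since the SummaryWriter below writes its logs under ./runs) and opening it in another browser tab.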

# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Put the model in training mode (enables dropout and batch-norm updates), then do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)


    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1
EPOCH 1:
  batch 1000 loss: 1.7667355205565691
  batch 2000 loss: 0.8388617958407849
  batch 3000 loss: 0.6968825845010579
  batch 4000 loss: 0.6204483299427666
  batch 5000 loss: 0.5817996844442096
  batch 6000 loss: 0.5788261932777241
  batch 7000 loss: 0.5229650802640244
  batch 8000 loss: 0.5200583340118173
  batch 9000 loss: 0.502049155228764
  batch 10000 loss: 0.4928166435174935
  batch 11000 loss: 0.45396994171477856
  batch 12000 loss: 0.45425144100922626
  batch 13000 loss: 0.4254072148703272
  batch 14000 loss: 0.4189532391532557
  batch 15000 loss: 0.4366461548706866
LOSS train 0.4366461548706866 valid 0.43579769134521484
EPOCH 2:
  batch 1000 loss: 0.42614812783067463
  batch 2000 loss: 0.3911532442855532
  batch 3000 loss: 0.4026775339449523
  batch 4000 loss: 0.4103510719244368
  batch 5000 loss: 0.3840216699426528
  batch 6000 loss: 0.38295095966989173
  batch 7000 loss: 0.36448468580067855
  batch 8000 loss: 0.39291608055040705
  batch 9000 loss: 0.3580735184801306
  batch 10000 loss: 0.3762704997290275
  batch 11000 loss: 0.3728049092126894
  batch 12000 loss: 0.3673193751628569
  batch 13000 loss: 0.35563136918772215
  batch 14000 loss: 0.36428551416122357
  batch 15000 loss: 0.3699195384653867
LOSS train 0.3699195384653867 valid 0.4358103573322296
EPOCH 3:
  batch 1000 loss: 0.3383971535017481
  batch 2000 loss: 0.3306581295079959
  batch 3000 loss: 0.33101004558702696
  batch 4000 loss: 0.33032484355336783
  batch 5000 loss: 0.3601224446147389
  batch 6000 loss: 0.3339655427073012
  batch 7000 loss: 0.34906126738211607
  batch 8000 loss: 0.3442564363214478
  batch 9000 loss: 0.3363055289789918
  batch 10000 loss: 0.31391799155683836
  batch 11000 loss: 0.34703895669064516
  batch 12000 loss: 0.33964774653968016
  batch 13000 loss: 0.2981129790644336
  batch 14000 loss: 0.316912189343615
  batch 15000 loss: 0.34591224429305295
LOSS train 0.34591224429305295 valid 0.3513065278530121
EPOCH 4:
  batch 1000 loss: 0.30626894964195706
  batch 2000 loss: 0.30347760828072023
  batch 3000 loss: 0.3099514487994893
  batch 4000 loss: 0.3248301664618775
  batch 5000 loss: 0.32771538996655725
  batch 6000 loss: 0.30232768331893023
  batch 7000 loss: 0.2970043121685194
  batch 8000 loss: 0.3140870202671422
  batch 9000 loss: 0.29683693973610936
  batch 10000 loss: 0.3023065453028976
  batch 11000 loss: 0.3014863682966534
  batch 12000 loss: 0.3018096925262289
  batch 13000 loss: 0.3195195010416719
  batch 14000 loss: 0.30384252646974347
  batch 15000 loss: 0.3014895583855396
LOSS train 0.3014895583855396 valid 0.3247852921485901
EPOCH 5:
  batch 1000 loss: 0.2849305837958673
  batch 2000 loss: 0.28834275505393997
  batch 3000 loss: 0.28477707908125377
  batch 4000 loss: 0.3016129176657341
  batch 5000 loss: 0.2862960026530309
  batch 6000 loss: 0.2899487989548343
  batch 7000 loss: 0.31688403315129837
  batch 8000 loss: 0.2856321340752111
  batch 9000 loss: 0.27428792149353104
  batch 10000 loss: 0.29562672955554264
  batch 11000 loss: 0.2844663936333454
  batch 12000 loss: 0.3013684300243913
  batch 13000 loss: 0.28001616359367015
  batch 14000 loss: 0.297360300047716
  batch 15000 loss: 0.2720410674651648
LOSS train 0.2720410674651648 valid 0.3180714249610901
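model.train() and model.eval() only switch layers such as dropout and batch normalization between their training and inference behaviour; turning off gradient tracking is the separate job of torch.no_grad(). GarmentClassifier contains neither dropout nor batch norm, so the sketch below uses a standalone nn.Dropout layer (an illustrative choice, not part of the tutorial's model) to make the difference visible.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
demo_dropout = nn.Dropout(p=0.5)
ones = torch.ones(1, 8)

demo_dropout.train()             # training mode: each entry is zeroed with probability p,
train_out = demo_dropout(ones)   # and the survivors are scaled by 1 / (1 - p) = 2
demo_dropout.eval()              # evaluation mode: inputs pass through unchanged
eval_out = demo_dropout(ones)
print(train_out)
print(eval_out)

# Gradient tracking is governed by torch.no_grad(), not by train()/eval()
layer = nn.Linear(2, 1)
with torch.no_grad():
    no_grad_out = layer(torch.ones(1, 2))
print(no_grad_out.requires_grad)  # False
```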

To load a saved version of the model:

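saved_model = GarmentClassifier()
# model_path names the most recent checkpoint written by torch.save() in the loop above
saved_model.load_state_dict(torch.load(model_path, weights_only=True))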

Once you’ve loaded the model, it’s ready for whatever you need it for - more training, inference, or analysis.

Note that if your model has constructor parameters that affect model structure, you’ll need to provide them and configure the model identically to the state in which it was saved.
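As an illustration of this caveat, the sketch below uses a hypothetical ConfigurableClassifier whose hidden-layer width is a constructor argument. It saves a state_dict to an in-memory io.BytesIO buffer (standing in for a file on disk), restores it into an identically configured instance, and shows that a differently configured instance rejects the same weights with a RuntimeError.

```python
import io

import torch
import torch.nn as nn


class ConfigurableClassifier(nn.Module):
    """Hypothetical model whose layer shapes depend on a constructor argument."""

    def __init__(self, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 10)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x.flatten(1))))


original = ConfigurableClassifier(hidden_size=128)
buffer = io.BytesIO()  # stands in for a checkpoint file
torch.save(original.state_dict(), buffer)

# Loading succeeds when the new instance uses the same hidden_size
buffer.seek(0)
restored = ConfigurableClassifier(hidden_size=128)
restored.load_state_dict(torch.load(buffer, weights_only=True))

# A mismatched configuration fails with a size-mismatch RuntimeError
mismatch_error = None
buffer.seek(0)
try:
    ConfigurableClassifier(hidden_size=64).load_state_dict(torch.load(buffer, weights_only=True))
except RuntimeError as err:
    mismatch_error = err
    print("load failed:", type(err).__name__)
```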

Other Resources

Total running time of the script: (3 minutes 0.799 seconds)