Training with PyTorch#

Created On: Nov 30, 2021 | Last Updated: May 31, 2023 | Last Verified: Nov 05, 2024

Follow along with the video below or on YouTube.

Introduction#

In past videos, we’ve discussed and demonstrated:

  • Building models with the neural network layers and functions of the torch.nn module

  • The mechanics of automated gradient computation, which is central to gradient-based model training

  • Using TensorBoard to visualize training progress and other activities

In this video, we’ll be adding some new tools to your inventory:

  • We’ll get familiar with the dataset and dataloader abstractions, and how they ease the process of feeding data to your model during a training loop

  • We’ll discuss specific loss functions and when to use them

  • We’ll look at PyTorch optimizers, which implement algorithms to adjust model weights based on the outcome of a loss function

Finally, we’ll pull all of these together and see a full PyTorch training loop in action.

Dataset and DataLoader#

The Dataset and DataLoader classes encapsulate the process of pulling your data from storage and exposing it to your training loop in batches.

The Dataset is responsible for accessing and processing single instances of data.

The DataLoader pulls instances of data from the Dataset (either automatically or with a sampler that you define), collects them in batches, and returns them for consumption by your training loop. The DataLoader works with all kinds of datasets, regardless of the type of data they contain.
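
To make the division of labor concrete, here is a minimal sketch of a custom dataset and a loader that batches it. The `SquaresDataset` class and its contents are hypothetical, invented only for illustration; the pattern of implementing `__len__` and `__getitem__` is what a map-style `Dataset` needs.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical dataset: item i is the pair (i, i squared)."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        # Number of individual instances in the dataset
        return self.size

    def __getitem__(self, index):
        # Return a single (input, target) instance
        x = torch.tensor([float(index)])
        y = torch.tensor(float(index) ** 2)
        return x, y


dataset = SquaresDataset(10)

# The DataLoader stacks single instances into batches of up to 4
loader = DataLoader(dataset, batch_size=4, shuffle=False)

batches = list(loader)
for inputs, targets in batches:
    print(inputs.shape, targets.shape)
```

With 10 instances and a batch size of 4, the loader yields two full batches and one final batch of 2, since `drop_last` is left at its default of `False`.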

For this tutorial, we’ll be using the Fashion-MNIST dataset provided by TorchVision. We use torchvision.transforms.Normalize() to zero-center and normalize the distribution of the image tile content, and download both training and validation data splits.

import torch
import torchvision
import torchvision.transforms as transforms

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])

# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)

# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)

# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
        'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
Training set has 60000 instances
Validation set has 10000 instances
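
As a small worked example of the normalization step: `ToTensor()` scales pixel values to the range [0, 1], and `Normalize((0.5,), (0.5,))` then computes `(x - mean) / std` per channel, which maps that range to [-1, 1]. The sketch below reproduces the arithmetic with plain tensor operations on a few made-up pixel values, and then inverts it the same way the `img / 2 + 0.5` step in this tutorial's display helper does.

```python
import torch

# A few made-up pixel intensities, already scaled to [0, 1]
pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])

mean, std = 0.5, 0.5

# Same arithmetic that Normalize applies to each channel
normalized = (pixels - mean) / std
print(normalized)

# Inverse mapping, used before displaying the image
restored = normalized / 2 + 0.5
print(restored)
```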

As always, let’s visualize the data as a sanity check:

import matplotlib.pyplot as plt
import numpy as np

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(training_loader)
images, labels = next(dataiter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print('  '.join(classes[labels[j]] for j in range(4)))
Coat  Trouser  Ankle Boot  Coat

The Model#

The model we’ll use in this example is a variant of LeNet-5 - it should be familiar if you’ve watched the previous videos in this series.

import torch.nn as nn
import torch.nn.functional as F

# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = GarmentClassifier()
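
If the `16 * 4 * 4` input size of the first linear layer looks arbitrary, it can help to trace a batch through the convolution and pooling layers. The following sketch rebuilds those layers on their own (with the same arguments as in the class above) and records the tensor shape after each stage for a batch of four 28x28 single-channel images.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Same layer configuration as the convolutional part of the model
conv1 = nn.Conv2d(1, 6, 5)
pool = nn.MaxPool2d(2, 2)
conv2 = nn.Conv2d(6, 16, 5)

# A dummy batch: 4 images, 1 channel, 28x28 pixels
x = torch.zeros(4, 1, 28, 28)
shapes = [tuple(x.shape)]

x = F.relu(conv1(x))        # 5x5 kernel, no padding: 28 -> 24
shapes.append(tuple(x.shape))
x = pool(x)                 # 2x2 pooling: 24 -> 12
shapes.append(tuple(x.shape))
x = F.relu(conv2(x))        # 5x5 kernel, no padding: 12 -> 8
shapes.append(tuple(x.shape))
x = pool(x)                 # 2x2 pooling: 8 -> 4
shapes.append(tuple(x.shape))
x = x.view(-1, 16 * 4 * 4)  # flatten 16 channels of 4x4 into 256 features
shapes.append(tuple(x.shape))

for shape in shapes:
    print(shape)
```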

Loss Function#

For this example, we’ll be using a cross-entropy loss. For demonstration purposes, we’ll create batches of dummy output and label values, run them through the loss function, and examine the result.

loss_fn = torch.nn.CrossEntropyLoss()

# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's confidence in each of the 10 classes for a given input
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])

print(dummy_outputs)
print(dummy_labels)

loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))
tensor([[0.2563, 0.2905, 0.7120, 0.2953, 0.4989, 0.2859, 0.9022, 0.5648, 0.8256,
         0.4616],
        [0.8982, 0.9235, 0.3355, 0.0948, 0.1403, 0.8536, 0.7452, 0.0116, 0.8484,
         0.1498],
        [0.5123, 0.6791, 0.4778, 0.8477, 0.6976, 0.3900, 0.1011, 0.5865, 0.9094,
         0.7380],
        [0.6704, 0.8565, 0.8854, 0.3825, 0.7305, 0.3692, 0.7638, 0.2193, 0.9502,
         0.5369]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.345141887664795
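
To see what the loss function is doing with these values, the sketch below recomputes the cross-entropy by hand. `CrossEntropyLoss` treats its input as raw, unnormalized scores (logits): it applies a log-softmax across the classes, picks out the log-probability of the correct class for each item, and (with the default `reduction='mean'`) averages the negated values over the batch. The seed and variable names here are arbitrary choices for the illustration.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.rand(4, 10)
labels = torch.tensor([1, 5, 3, 7])

# Built-in loss
builtin = torch.nn.CrossEntropyLoss()(logits, labels)

# Manual version: log-softmax, select the true class, average the negatives
log_probs = F.log_softmax(logits, dim=1)
picked = log_probs[torch.arange(4), labels]
manual = -picked.mean()

print(builtin.item(), manual.item())

# With identical scores for all 10 classes, the loss is exactly ln(10)
uniform_loss = torch.nn.CrossEntropyLoss()(torch.zeros(4, 10), labels)
ln_10 = math.log(10)
print(uniform_loss.item(), ln_10)
```

The last two lines suggest why random scores in [0, 1] give a loss close to 2.3: the predictions are nearly uniform over 10 classes, and ln(10) is about 2.303.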

Optimizer#

For this example, we’ll be using simple stochastic gradient descent with momentum.

It can be instructive to try some variations on this optimization scheme:

  • Learning rate determines the size of the steps the optimizer takes. What does a different learning rate do to your training results, in terms of accuracy and convergence time?

  • Momentum carries over a fraction of the previous update into the current one, nudging the optimizer along directions where the gradient stays consistent over multiple steps. What does changing this value do to your results?

  • Try some different optimization algorithms, such as averaged SGD, Adagrad, or Adam. How do your results differ?

# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
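
As a rough illustration of what this optimizer does on each step, the sketch below runs `torch.optim.SGD` with the same `lr` and `momentum` on a toy parameter and compares it against a hand-written update. It assumes the default settings for everything else (no dampening, no weight decay, no Nesterov momentum), in which case the update is `buf = momentum * buf + grad` followed by `w = w - lr * buf`. The toy loss and variable names are made up for the example.

```python
import torch

lr, momentum = 0.001, 0.9

# Toy parameter with a simple quadratic loss: sum(w^2), gradient 2w
w = torch.tensor([1.0, -2.0], requires_grad=True)
opt = torch.optim.SGD([w], lr=lr, momentum=momentum)

# Hand-rolled copy of the same parameter and its momentum buffer
w_manual = w.detach().clone()
buf = torch.zeros_like(w_manual)

for step in range(3):
    # Library update
    opt.zero_grad()
    loss = (w ** 2).sum()
    loss.backward()
    opt.step()

    # Manual update under the assumptions stated above
    grad = 2 * w_manual
    buf = momentum * buf + grad
    w_manual = w_manual - lr * buf

max_diff = (w.detach() - w_manual).abs().max().item()
print(w.detach(), w_manual, max_diff)
```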

The Training Loop#

Below, we have a function that performs one training epoch. It enumerates data from the DataLoader, and on each pass of the loop does the following:

  • Gets a batch of training data from the DataLoader

  • Zeros the optimizer’s gradients

  • Performs an inference - that is, gets predictions from the model for an input batch

  • Calculates the loss for that set of predictions vs. the labels on the dataset

  • Calculates the backward gradients over the learning weights

  • Tells the optimizer to perform one learning step - that is, adjust the model’s learning weights based on the observed gradients for this batch, according to the optimization algorithm we chose

  • Reports the average per-batch loss every 1000 batches

  • Finally, returns the average per-batch loss for the last 1000 batches, for comparison with a validation run

def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss
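
The reason for zeroing gradients on every batch is that `backward()` adds to whatever is already stored in each parameter's `.grad` rather than overwriting it. A minimal sketch with a single made-up scalar parameter shows the effect:

```python
import torch

w = torch.tensor(3.0, requires_grad=True)

# First backward pass: d(2w)/dw = 2
(w * 2).backward()
first = w.grad.item()

# Second backward pass without zeroing: the new gradient is added on top
(w * 2).backward()
second = w.grad.item()

# After zeroing, a backward pass gives the plain gradient again
w.grad.zero_()
(w * 2).backward()
third = w.grad.item()

print(first, second, third)
```

In the training loop, `optimizer.zero_grad()` plays the role of the manual reset here, clearing the gradients of every parameter the optimizer manages.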

Per-Epoch Activity#

There are a couple of things we’ll want to do once per epoch:

  • Perform validation by checking our relative loss on a set of data that was not used for training, and report this

  • Save a copy of the model

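Here, we’ll do our reporting in TensorBoard. This will require going to the command line to start TensorBoard (for example, with tensorboard --logdir=runs, which matches the runs/ directory passed to SummaryWriter), and opening it in another browser tab.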

# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Put the model in training mode, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)


    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1
EPOCH 1:
  batch 1000 loss: 1.8323931244909764
  batch 2000 loss: 0.8784747361708433
  batch 3000 loss: 0.7262709785904735
  batch 4000 loss: 0.6738309391736984
  batch 5000 loss: 0.6461850374676287
  batch 6000 loss: 0.5776571166808717
  batch 7000 loss: 0.5584996137892595
  batch 8000 loss: 0.5492216537357308
  batch 9000 loss: 0.5117455316308187
  batch 10000 loss: 0.5246429578544339
  batch 11000 loss: 0.4964053811201593
  batch 12000 loss: 0.5001326295695034
  batch 13000 loss: 0.4669392162488075
  batch 14000 loss: 0.46964427373884243
  batch 15000 loss: 0.43630503810482335
LOSS train 0.43630503810482335 valid 0.47318997979164124
EPOCH 2:
  batch 1000 loss: 0.4435618306293036
  batch 2000 loss: 0.4290441724356933
  batch 3000 loss: 0.4071630984430085
  batch 4000 loss: 0.3956383166331216
  batch 5000 loss: 0.4035718485097168
  batch 6000 loss: 0.3835479527103598
  batch 7000 loss: 0.38598460171557963
  batch 8000 loss: 0.3859042554834159
  batch 9000 loss: 0.3874600419027556
  batch 10000 loss: 0.39659956550243075
  batch 11000 loss: 0.3810253945522127
  batch 12000 loss: 0.3628438891758851
  batch 13000 loss: 0.3503799811955687
  batch 14000 loss: 0.362846230824216
  batch 15000 loss: 0.3755143761200598
LOSS train 0.3755143761200598 valid 0.40298745036125183
EPOCH 3:
  batch 1000 loss: 0.3319629271739977
  batch 2000 loss: 0.350677063354
  batch 3000 loss: 0.35145215366657795
  batch 4000 loss: 0.3394931914190529
  batch 5000 loss: 0.3329608398898563
  batch 6000 loss: 0.34832883945610954
  batch 7000 loss: 0.322982593969311
  batch 8000 loss: 0.3265897408129531
  batch 9000 loss: 0.31376401647313107
  batch 10000 loss: 0.3289889779783116
  batch 11000 loss: 0.33349869607709115
  batch 12000 loss: 0.33560474293126025
  batch 13000 loss: 0.33189047793726784
  batch 14000 loss: 0.3270530575479061
  batch 15000 loss: 0.3293719403151117
LOSS train 0.3293719403151117 valid 0.34732159972190857
EPOCH 4:
  batch 1000 loss: 0.31064335546820077
  batch 2000 loss: 0.29299817238622017
  batch 3000 loss: 0.30615555192270405
  batch 4000 loss: 0.3068019174702495
  batch 5000 loss: 0.2935927447076392
  batch 6000 loss: 0.2904864349851996
  batch 7000 loss: 0.2914984790359886
  batch 8000 loss: 0.28802999942483076
  batch 9000 loss: 0.3160819323722535
  batch 10000 loss: 0.32195692529110237
  batch 11000 loss: 0.2989145797609963
  batch 12000 loss: 0.3091380840048096
  batch 13000 loss: 0.3153695368430781
  batch 14000 loss: 0.3140922244616449
  batch 15000 loss: 0.29219094637429227
LOSS train 0.29219094637429227 valid 0.32749319076538086
EPOCH 5:
  batch 1000 loss: 0.28640846343052545
  batch 2000 loss: 0.27627345865432107
  batch 3000 loss: 0.27327266449249143
  batch 4000 loss: 0.2758074446654464
  batch 5000 loss: 0.2918136065105209
  batch 6000 loss: 0.269755738459342
  batch 7000 loss: 0.28493136673758274
  batch 8000 loss: 0.27973395387151323
  batch 9000 loss: 0.26953814072352544
  batch 10000 loss: 0.2755987258453333
  batch 11000 loss: 0.28118035244659584
  batch 12000 loss: 0.3087851319604761
  batch 13000 loss: 0.288280862213378
  batch 14000 loss: 0.2984693810679091
  batch 15000 loss: 0.2996471112240906
LOSS train 0.2996471112240906 valid 0.307800829410553
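
The validation pass relies on two separate switches: `model.eval()` changes how certain layers behave, and `torch.no_grad()` turns off gradient tracking. The sketch below illustrates each in isolation, using a standalone `nn.Dropout` and `nn.Linear` layer chosen only for the demonstration. Note that the classifier in this tutorial contains only convolution, pooling, and linear layers, so for this particular model `eval()` is a good habit rather than something that changes its outputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

dropout = nn.Dropout(p=0.5)
x = torch.ones(1000)

# Training mode: roughly half of the elements are zeroed
dropout.train()
train_out = dropout(x)
zeroed_in_train = int((train_out == 0).sum())

# Evaluation mode: dropout passes its input through unchanged
dropout.eval()
eval_out = dropout(x)
unchanged_in_eval = torch.equal(eval_out, x)

# Gradient tracking is a separate matter, controlled by no_grad()
linear = nn.Linear(4, 2)
tracked = linear(torch.ones(1, 4)).requires_grad
with torch.no_grad():
    untracked = linear(torch.ones(1, 4)).requires_grad

print(zeroed_in_train, unchanged_in_eval, tracked, untracked)
```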

To load a saved version of the model:

# PATH is a placeholder: substitute the path of a file written by torch.save() above
saved_model = GarmentClassifier()
saved_model.load_state_dict(torch.load(PATH))

Once you’ve loaded the model, it’s ready for whatever you need it for - more training, inference, or analysis.

Note that if your model has constructor parameters that affect model structure, you’ll need to provide them and configure the model identically to the state in which it was saved.
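
As an illustration of that last point, the sketch below saves and restores a small model whose structure depends on a constructor argument. `TinyNet` and its `hidden_size` parameter are hypothetical, and an in-memory buffer stands in for a file on disk. Rebuilding the model with the same argument restores it exactly, while a different argument produces layers of the wrong shape and `load_state_dict()` raises an error.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model whose layer sizes depend on a constructor argument."""

    def __init__(self, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(4, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


original = TinyNet(hidden_size=8)

# Save the state dict to an in-memory buffer instead of a file
buffer = io.BytesIO()
torch.save(original.state_dict(), buffer)

# Same constructor argument: the weights load and the outputs match
buffer.seek(0)
restored = TinyNet(hidden_size=8)
restored.load_state_dict(torch.load(buffer))
x = torch.ones(1, 4)
same_output = torch.equal(original(x), restored(x))

# Different constructor argument: the tensor shapes no longer match
buffer.seek(0)
try:
    TinyNet(hidden_size=16).load_state_dict(torch.load(buffer))
    mismatch_detected = False
except RuntimeError:
    mismatch_detected = True

print(same_output, mismatch_detected)
```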
