

Training with PyTorch

Created On: Nov 30, 2021 | Last Updated: May 31, 2023 | Last Verified: Nov 05, 2024

Follow along with the video below or on YouTube.

Introduction

In past videos, we’ve discussed and demonstrated:

  • Building models with the neural network layers and functions of the torch.nn module

  • The mechanics of automated gradient computation, which is central to gradient-based model training

  • Using TensorBoard to visualize training progress and other activities

In this video, we’ll be adding some new tools to your inventory:

  • We’ll get familiar with the Dataset and DataLoader abstractions, and how they ease the process of feeding data to your model during a training loop

  • We’ll discuss specific loss functions and when to use them

  • We’ll look at PyTorch optimizers, which implement algorithms to adjust model weights based on the outcome of a loss function

Finally, we’ll pull all of these together and see a full PyTorch training loop in action.

Dataset and DataLoader

The Dataset and DataLoader classes encapsulate the process of pulling your data from storage and exposing it to your training loop in batches.

The Dataset is responsible for accessing and processing single instances of data.

The DataLoader pulls instances of data from the Dataset (either automatically or with a sampler that you define), collects them in batches, and returns them for consumption by your training loop. The DataLoader works with all kinds of datasets, regardless of the type of data they contain.
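This division of labor can be sketched with a small in-memory dataset. `SquaresDataset` is a hypothetical class made up for illustration. `Dataset`, `DataLoader`, `__len__` and `__getitem__` are the standard `torch.utils.data` interface: the dataset hands back one instance per index, and the DataLoader groups those instances into batches.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: instance i is the pair (i, i**2) as float tensors."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Tells the DataLoader (through its default sampler) which indices exist
        return self.n

    def __getitem__(self, idx):
        # Access and process ONE instance; batching is left to the DataLoader
        x = torch.tensor([float(idx)])
        y = torch.tensor([float(idx) ** 2])
        return x, y


squares = SquaresDataset(10)
squares_loader = DataLoader(squares, batch_size=4, shuffle=False)

batches = list(squares_loader)
for batch_x, batch_y in batches:
    # The default collate function stacks [1]-shaped tensors into a [batch, 1] tensor
    print(batch_x.shape, batch_y.shape)
```

With ten instances and `batch_size=4`, the last batch holds only two instances. Passing `drop_last=True` to `DataLoader` would discard that partial batch instead.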

For this tutorial, we’ll be using the Fashion-MNIST dataset provided by TorchVision. We use transforms.ToTensor() to convert each image into a tensor of pixel values in [0, 1]. We then use transforms.Normalize((0.5,), (0.5,)) to zero-center those values and rescale them to [-1, 1]. Both the training and validation data splits are downloaded.
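The snippet below writes out by hand the arithmetic that `Normalize` applies to each channel, `(x - mean) / std`. It does not call torchvision. The three-pixel "image" is a made-up example. The last lines show that the `img / 2 + 0.5` step in the `matplotlib_imshow` helper used later is the exact inverse of this mapping.

```python
import torch

# Pixel values as ToTensor() would produce them: floats in [0.0, 1.0]
# (a hypothetical 1-channel, 1x3 image: black, mid-grey, white)
pixels = torch.tensor([[[0.0, 0.5, 1.0]]])

# Normalize((0.5,), (0.5,)) computes (x - mean) / std for each channel
mean, std = 0.5, 0.5
normalized = (pixels - mean) / std
print(normalized)  # tensor([[[-1., 0., 1.]]])

# The display helper's "img / 2 + 0.5" undoes the normalization
unnormalized = normalized / 2 + 0.5
print(torch.equal(unnormalized, pixels))  # True
```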

import torch
import torchvision
import torchvision.transforms as transforms

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])

# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)

# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)

# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
        'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
Training set has 60000 instances
Validation set has 10000 instances

As always, let’s visualize the data as a sanity check:

import matplotlib.pyplot as plt
import numpy as np

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(training_loader)
images, labels = next(dataiter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print('  '.join(classes[labels[j]] for j in range(4)))
Sneaker  Coat  Bag  Shirt

The Model

The model we’ll use in this example is a variant of LeNet-5. It should be familiar if you’ve watched the previous videos in this series.

import torch.nn as nn
import torch.nn.functional as F

# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = GarmentClassifier()
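The `16 * 4 * 4` input size of `fc1` follows from the 28x28 grayscale Fashion-MNIST images. The sketch below rebuilds only the convolution and pooling layers of `GarmentClassifier`, so it runs on its own, and traces a dummy batch through them.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Same convolution and pooling layers as GarmentClassifier, rebuilt standalone
conv1 = nn.Conv2d(1, 6, 5)
pool = nn.MaxPool2d(2, 2)
conv2 = nn.Conv2d(6, 16, 5)

x = torch.randn(4, 1, 28, 28)   # a dummy batch of four 28x28 grayscale images
x = pool(F.relu(conv1(x)))      # 5x5 conv: 28 -> 24, then 2x2 pool: 24 -> 12
print(x.shape)                  # torch.Size([4, 6, 12, 12])
x = pool(F.relu(conv2(x)))      # 5x5 conv: 12 -> 8, then 2x2 pool: 8 -> 4
print(x.shape)                  # torch.Size([4, 16, 4, 4])
flat = x.view(-1, 16 * 4 * 4)   # flatten to 256 features per image for fc1
print(flat.shape)               # torch.Size([4, 256])
```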

Loss Function

For this example, we’ll be using a cross-entropy loss. For demonstration purposes, we’ll create batches of dummy output and label values, run them through the loss function, and examine the result.

loss_fn = torch.nn.CrossEntropyLoss()

# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's raw, unnormalized scores (logits) for each of the 10 classes, per input
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])

print(dummy_outputs)
print(dummy_labels)

loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))
tensor([[0.8143, 0.3530, 0.8622, 0.4809, 0.6944, 0.4542, 0.1473, 0.8203, 0.6816,
         0.9211],
        [0.1781, 0.1462, 0.0238, 0.9874, 0.0519, 0.6630, 0.0294, 0.8929, 0.5338,
         0.2547],
        [0.4890, 0.7959, 0.7637, 0.4780, 0.4929, 0.9779, 0.4672, 0.6508, 0.4998,
         0.3285],
        [0.1643, 0.1706, 0.1881, 0.9518, 0.6545, 0.8956, 0.1230, 0.0336, 0.9063,
         0.6121]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.479302406311035
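To make the cross-entropy computation concrete, the sketch below recomputes `CrossEntropyLoss` by hand. It applies a log-softmax over the class scores, takes the negative log-probability of the correct class, and averages over the batch (the default `reduction='mean'`). The variable names are illustrative. Scores drawn from `torch.rand` all lie in [0, 1), so the softmax is close to uniform and the loss lands roughly near ln(10) ≈ 2.30. That is consistent with the value printed above.

```python
import torch
import torch.nn.functional as F

scores = torch.rand(4, 10)                   # raw scores, like dummy_outputs above
target_classes = torch.tensor([1, 5, 3, 7])  # correct class index for each input

builtin = torch.nn.CrossEntropyLoss()(scores, target_classes)

# By hand: log-softmax turns scores into log-probabilities; the loss is the
# negative log-probability of each correct class, averaged over the batch
log_probs = F.log_softmax(scores, dim=1)
manual = -log_probs[torch.arange(len(target_classes)), target_classes].mean()

print(builtin.item(), manual.item())  # the two values agree
```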

Optimizer

For this example, we’ll be using simple stochastic gradient descent with momentum.

It can be instructive to try some variations on this optimization scheme:

  • Learning rate determines the size of the steps the optimizer takes. What does a different learning rate do to your training results, in terms of accuracy and convergence time?

  • Momentum accumulates an exponentially decaying sum of past gradients, so the optimizer keeps moving in directions where the gradient has been consistent over multiple steps. What does changing this value do to your results?

  • Try some different optimization algorithms, such as averaged SGD, Adagrad, or Adam. How do your results differ?
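To see what momentum does mechanically, the sketch below runs three steps of `torch.optim.SGD` with momentum on the toy objective f(p) = p², whose gradient is 2p. It then repeats the same steps by hand using the update rule from the PyTorch `SGD` documentation (with the default `dampening=0` and `nesterov=False`): v ← momentum · v + grad, then p ← p − lr · v. The objective and hyperparameters are arbitrary choices for illustration.

```python
import torch

# Three SGD-with-momentum steps on f(p) = p**2, using PyTorch's optimizer
p = torch.tensor(1.0, requires_grad=True)
opt = torch.optim.SGD([p], lr=0.1, momentum=0.9)
for _ in range(3):
    opt.zero_grad()
    (p ** 2).backward()
    opt.step()

# The same three steps by hand
q, v = 1.0, 0.0
for _ in range(3):
    grad = 2 * q            # gradient of q**2
    v = 0.9 * v + grad      # velocity: decaying memory of past gradients
    q = q - 0.1 * v         # step along the velocity, not just the current gradient

print(p.item(), q)          # the two results agree up to float32 rounding
```

For comparison, with `momentum=0` the same three steps only shrink p by a factor of 0.8 each time, leaving it at 0.512. The accumulated velocity moves it much closer to the minimum at 0.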

# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
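Trying another algorithm from the list above only changes this constructor call. The sketch below runs the same short loop with several `torch.optim` optimizers on a tiny regression problem. `toy_model` is a hypothetical stand-in for `GarmentClassifier()`. The learning rates are illustrative, not tuned recommendations. Every optimizer starts from the same initial weights so the final losses are comparable. The names avoid reusing `optimizer` so that running this cell does not replace the tutorial's optimizer.

```python
import torch
import torch.nn as nn

toy_inputs, toy_targets = torch.randn(8, 3), torch.randn(8, 1)
initial_state = nn.Linear(3, 1).state_dict()  # shared starting weights

# Only the constructor differs between algorithms; learning rates are illustrative
optimizer_factories = {
    'SGD + momentum': lambda params: torch.optim.SGD(params, lr=0.01, momentum=0.9),
    'averaged SGD': lambda params: torch.optim.ASGD(params, lr=0.01),
    'Adagrad': lambda params: torch.optim.Adagrad(params, lr=0.1),
    'Adam': lambda params: torch.optim.Adam(params, lr=0.01),
}

final_losses = {}
for name, make_optimizer in optimizer_factories.items():
    toy_model = nn.Linear(3, 1)            # hypothetical stand-in for GarmentClassifier()
    toy_model.load_state_dict(initial_state)
    trial_opt = make_optimizer(toy_model.parameters())
    for _ in range(100):                   # the usual zero_grad / forward / backward / step
        trial_opt.zero_grad()
        trial_loss = nn.functional.mse_loss(toy_model(toy_inputs), toy_targets)
        trial_loss.backward()
        trial_opt.step()
    final_losses[name] = trial_loss.item()

print(final_losses)
```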

The Training Loop

Below, we have a function that performs one training epoch. It enumerates data from the DataLoader, and on each pass of the loop does the following:

  • Gets a batch of training data from the DataLoader

  • Zeros the optimizer’s gradients

  • Runs a forward pass; that is, gets predictions from the model for the input batch

  • Calculates the loss for that set of predictions vs. the labels on the dataset

  • Calculates the gradients of the loss with respect to the learning weights (the backward pass)

  • Tells the optimizer to perform one learning step. That is, it adjusts the model’s learning weights based on the observed gradients for this batch, according to the optimization algorithm we chose

  • Every 1000 batches, reports the average per-batch loss over those batches, both to the console and to TensorBoard

  • Finally, returns the average per-batch loss for the last 1000 batches, for comparison with a validation run

def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss
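The "Zero your gradients for every batch!" comment matters because `.backward()` adds into each parameter's `.grad` instead of overwriting it. The minimal sketch below uses a single hypothetical scalar weight `w` to show the accumulation and then the reset that `zero_grad()` performs.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

(w * 3).backward()                 # d(3w)/dw = 3
after_one = w.grad.item()
print(after_one)                   # 3.0

(w * 3).backward()                 # no zeroing in between, so the new 3 is added
after_two = w.grad.item()
print(after_two)                   # 6.0

opt = torch.optim.SGD([w], lr=0.1)
opt.zero_grad(set_to_none=False)   # reset to zeros (set_to_none=True resets to None)
print(w.grad.item())               # 0.0
```

Without the reset, every batch's update would be driven by the sum of all previous batches' gradients rather than by the current batch.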

Per-Epoch Activity

There are a couple of things we’ll want to do once per epoch:

  • Perform validation by computing the loss on a set of data that was not used for training, and report it

  • Save a copy of the model

Here, we’ll do our reporting in TensorBoard. This requires starting TensorBoard from the command line and opening it in another browser tab. For example, run tensorboard --logdir=runs, since the writer below logs under the runs/ directory.

# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Put the model in training mode (this affects layers such as dropout and batch norm), then do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)


    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1
EPOCH 1:
  batch 1000 loss: 1.948841470271349
  batch 2000 loss: 0.8225232843509875
  batch 3000 loss: 0.6810694508682936
  batch 4000 loss: 0.6374944493332878
  batch 5000 loss: 0.593948904428631
  batch 6000 loss: 0.541161258992739
  batch 7000 loss: 0.5084875992755405
  batch 8000 loss: 0.4883952557892771
  batch 9000 loss: 0.5099732027912978
  batch 10000 loss: 0.45154825277545024
  batch 11000 loss: 0.45277360260137356
  batch 12000 loss: 0.4114964400313329
  batch 13000 loss: 0.43476405190851075
  batch 14000 loss: 0.4162119897347875
  batch 15000 loss: 0.4168772909468971
LOSS train 0.4168772909468971 valid 0.4692341685295105
EPOCH 2:
  batch 1000 loss: 0.39596002114668954
  batch 2000 loss: 0.38436802294070366
  batch 3000 loss: 0.38804610444721765
  batch 4000 loss: 0.37987252774683294
  batch 5000 loss: 0.37978977441355527
  batch 6000 loss: 0.36946967538970055
  batch 7000 loss: 0.37562350912688997
  batch 8000 loss: 0.36733832088003693
  batch 9000 loss: 0.3422593215914676
  batch 10000 loss: 0.33464168290619273
  batch 11000 loss: 0.3632593385106011
  batch 12000 loss: 0.34478370949701637
  batch 13000 loss: 0.34918551561175265
  batch 14000 loss: 0.34863743619757587
  batch 15000 loss: 0.36068782084216944
LOSS train 0.36068782084216944 valid 0.3662540912628174
EPOCH 3:
  batch 1000 loss: 0.34055076682037905
  batch 2000 loss: 0.32038137082473256
  batch 3000 loss: 0.31732052425210716
  batch 4000 loss: 0.3390816284599132
  batch 5000 loss: 0.3119717678526067
  batch 6000 loss: 0.315720810723491
  batch 7000 loss: 0.33046770557735
  batch 8000 loss: 0.2952491888772056
  batch 9000 loss: 0.319138653222908
  batch 10000 loss: 0.3436234647281162
  batch 11000 loss: 0.2984337019044615
  batch 12000 loss: 0.32183488384692466
  batch 13000 loss: 0.2911853743920219
  batch 14000 loss: 0.32470593122900754
  batch 15000 loss: 0.3259980706569913
LOSS train 0.3259980706569913 valid 0.3089425563812256
EPOCH 4:
  batch 1000 loss: 0.2980323342512638
  batch 2000 loss: 0.2936071463074841
  batch 3000 loss: 0.2917044203496189
  batch 4000 loss: 0.27742919136660205
  batch 5000 loss: 0.2807973468512937
  batch 6000 loss: 0.3000487627111579
  batch 7000 loss: 0.2938062994651791
  batch 8000 loss: 0.27784532450180266
  batch 9000 loss: 0.2922244194418963
  batch 10000 loss: 0.2948183331193068
  batch 11000 loss: 0.30225862563654665
  batch 12000 loss: 0.3003590454256337
  batch 13000 loss: 0.2967724909290264
  batch 14000 loss: 0.2687929000802469
  batch 15000 loss: 0.2961642937959477
LOSS train 0.2961642937959477 valid 0.3145768642425537
EPOCH 5:
  batch 1000 loss: 0.27838293520484875
  batch 2000 loss: 0.2612732032280655
  batch 3000 loss: 0.2734130588883854
  batch 4000 loss: 0.26696810725982506
  batch 5000 loss: 0.2683306515677668
  batch 6000 loss: 0.2705223727729499
  batch 7000 loss: 0.2726025758356009
  batch 8000 loss: 0.27087046978988655
  batch 9000 loss: 0.26028087261202015
  batch 10000 loss: 0.30320248483981777
  batch 11000 loss: 0.2667968332106984
  batch 12000 loss: 0.28255444055971746
  batch 13000 loss: 0.25882052225403823
  batch 14000 loss: 0.27356255225187487
  batch 15000 loss: 0.27037248696101596
LOSS train 0.27037248696101596 valid 0.31505483388900757
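`model.train(True)` and `model.eval()` only switch the behavior of layers such as dropout and batch normalization. Gradient tracking is controlled separately, here by `torch.no_grad()`. `GarmentClassifier` has neither dropout nor batch norm, so `eval()` does not change its outputs, but calling it is still good practice. The sketch below uses a hypothetical tiny network with a dropout layer to show both effects.

```python
import torch
import torch.nn as nn

# Hypothetical tiny network with a dropout layer (GarmentClassifier has none)
net = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
sample = torch.ones(1, 8)

net.train()                        # same as net.train(True): dropout is active
print(net.training)                # True

net.eval()                         # same as net.train(False): dropout becomes a no-op
with torch.no_grad():              # separately, stop recording the autograd graph
    first = net(sample)
    second = net(sample)

print(net.training)                # False
print(torch.equal(first, second))  # True: evaluation is deterministic
print(first.requires_grad)         # False: no graph was recorded
```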

To load a saved version of the model:

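# model_path holds the path of the best checkpoint saved by the training loop above
saved_model = GarmentClassifier()
saved_model.load_state_dict(torch.load(model_path, weights_only=True))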

Once you’ve loaded the model, it’s ready for whatever you need it for: more training, inference, or analysis.

Note that if your model has constructor parameters that affect model structure, you’ll need to provide them and configure the model identically to the state in which it was saved.
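The sketch below illustrates this point with a hypothetical `SizedClassifier` whose `hidden_size` constructor argument determines its layer shapes. It uses an in-memory `io.BytesIO` buffer in place of a file path. A state dict loads cleanly into a model built with the same argument. Loading it into a model built with a different argument fails with a size-mismatch `RuntimeError`.

```python
import io

import torch
import torch.nn as nn


class SizedClassifier(nn.Module):
    """Hypothetical model whose hidden_size argument changes its structure."""

    def __init__(self, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(16, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 10)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


original = SizedClassifier(hidden_size=32)
buffer = io.BytesIO()                  # in-memory stand-in for a file path
torch.save(original.state_dict(), buffer)

# Rebuild with the SAME constructor arguments, then load the saved weights
buffer.seek(0)
restored = SizedClassifier(hidden_size=32)
restored.load_state_dict(torch.load(buffer, weights_only=True))
x = torch.randn(2, 16)
print(torch.equal(original(x), restored(x)))  # True

# A different hidden_size gives tensors of the wrong shape, so loading fails
buffer.seek(0)
mismatch_raised = False
try:
    SizedClassifier(hidden_size=64).load_state_dict(torch.load(buffer, weights_only=True))
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)  # True
```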

Other Resources

Total running time of the script: (3 minutes 2.065 seconds)