Training with PyTorch

Created On: Nov 30, 2021 | Last Updated: May 31, 2023 | Last Verified: Nov 05, 2024

Follow along with the accompanying video on YouTube.

Introduction

In past videos, we’ve discussed and demonstrated:

  • Building models with the neural network layers and functions of the torch.nn module

  • The mechanics of automated gradient computation, which is central to gradient-based model training

  • Using TensorBoard to visualize training progress and other activities

In this video, we’ll be adding some new tools to your inventory:

  • We’ll get familiar with the Dataset and DataLoader abstractions, and how they ease the process of feeding data to your model during a training loop

  • We’ll discuss specific loss functions and when to use them

  • We’ll look at PyTorch optimizers, which implement algorithms to adjust model weights based on the outcome of a loss function

Finally, we’ll pull all of these together and see a full PyTorch training loop in action.

Dataset and DataLoader

The Dataset and DataLoader classes encapsulate the process of pulling your data from storage and exposing it to your training loop in batches.

The Dataset is responsible for accessing and processing single instances of data.

The DataLoader pulls instances of data from the Dataset (either automatically or with a sampler that you define), collects them in batches, and returns them for consumption by your training loop. The DataLoader works with all kinds of datasets, regardless of the type of data they contain.
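Here is a minimal sketch of that division of labor. The `SquaresDataset` class below is a hypothetical toy, not part of PyTorch. It implements the two methods a map-style Dataset needs, `__len__` and `__getitem__`, and returns one `(input, label)` pair at a time. The DataLoader does the batching: with `batch_size=4`, ten instances become batches of 4, 4, and 2.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: instance i is the pair (i, i**2) as float tensors."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # The DataLoader's default sampler uses this to know which indices exist
        return self.n

    def __getitem__(self, idx):
        # Return a single (input, label) instance; batching is the DataLoader's job
        x = torch.tensor([float(idx)])
        y = torch.tensor(float(idx) ** 2)
        return x, y


dataset = SquaresDataset(10)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

for inputs, labels in loader:
    print(inputs.shape, labels.shape)
```

Each label is a 0-dimensional tensor, so the DataLoader's default collation stacks four of them into a label tensor of shape `[4]`. The Fashion-MNIST loaders below produce the same layout.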

For this tutorial, we’ll be using the Fashion-MNIST dataset provided by TorchVision. We use torchvision.transforms.Normalize() to rescale pixel values from the [0, 1] range produced by ToTensor() to [-1, 1], which centers that range on zero. We also download both the training and validation splits.

import torch
import torchvision
import torchvision.transforms as transforms

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])

# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)

# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)

# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
        'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
Training set has 60000 instances
Validation set has 10000 instances

As always, let’s visualize the data as a sanity check:

import matplotlib.pyplot as plt
import numpy as np

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(training_loader)
images, labels = next(dataiter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print('  '.join(classes[labels[j]] for j in range(4)))
T-shirt/top  Coat  Dress  Dress

The Model

The model we’ll use in this example is a variant of LeNet-5. It should be familiar if you’ve watched the previous videos in this series.

import torch.nn as nn
import torch.nn.functional as F

# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten each image's 16 feature maps of 4x4 into one 256-element vector
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = GarmentClassifier()
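The `16 * 4 * 4` input size of `fc1` follows from how each layer shrinks a 28x28 Fashion-MNIST image. The sketch below traces that arithmetic with the standard output-size formula `(size + 2*padding - kernel) // stride + 1` for convolution and pooling layers, assuming dilation 1. The helper name `conv_out` is chosen for illustration. The last lines confirm the result by running a dummy input through the same convolutional stack.

```python
import torch
import torch.nn as nn


def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d or MaxPool2d layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                                   # Fashion-MNIST images are 28x28
size = conv_out(size, kernel=5)             # conv1: 28 -> 24
size = conv_out(size, kernel=2, stride=2)   # pool:  24 -> 12
size = conv_out(size, kernel=5)             # conv2: 12 -> 8
size = conv_out(size, kernel=2, stride=2)   # pool:   8 -> 4

flat_features = 16 * size * size            # conv2 has 16 output channels
print(size, flat_features)                  # 4 256

# Cross-check against the real layers with the same hyperparameters as GarmentClassifier
conv_stack = nn.Sequential(
    nn.Conv2d(1, 6, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
)
print(conv_stack(torch.zeros(1, 1, 28, 28)).shape)  # torch.Size([1, 16, 4, 4])
```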

Loss Function

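For this example, we’ll be using a cross-entropy loss. For demonstration purposes, we’ll create batches of dummy output and label values, run them through the loss function, and examine the result. By default, CrossEntropyLoss averages the per-sample losses, so the value printed below is the mean loss over the batch rather than a sum.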

loss_fn = torch.nn.CrossEntropyLoss()

# NB: Loss functions expect data in batches, so we're creating batches of 4
# Raw, unnormalized scores (logits) for each of the 10 classes, one row per input;
# CrossEntropyLoss applies log-softmax to these internally
dummy_outputs = torch.rand(4, 10)
# The index of the correct class for each of the 4 inputs
dummy_labels = torch.tensor([1, 5, 3, 7])

print(dummy_outputs)
print(dummy_labels)

loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))
tensor([[0.7148, 0.8898, 0.1172, 0.2066, 0.9920, 0.8607, 0.7535, 0.8473, 0.1322,
         0.1747],
        [0.1036, 0.4382, 0.2070, 0.2487, 0.0364, 0.4221, 0.6340, 0.3945, 0.9411,
         0.8449],
        [0.6413, 0.3273, 0.4183, 0.8485, 0.0809, 0.1837, 0.7375, 0.5821, 0.1590,
         0.3847],
        [0.8578, 0.4064, 0.0648, 0.4313, 0.9652, 0.3780, 0.8934, 0.2501, 0.6263,
         0.8594]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.2442619800567627
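The cross-entropy loss for one sample is the negative log of the softmax probability assigned to the correct class. `CrossEntropyLoss` (default `reduction='mean'`) averages that value over the batch. The sketch below recomputes the loss by hand with `F.log_softmax`, using a seeded batch of random scores as a stand-in for `dummy_outputs`, and checks the result against the built-in.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.rand(4, 10)            # random raw scores, like dummy_outputs above
labels = torch.tensor([1, 5, 3, 7])

# Per-sample cross-entropy: -log(softmax(logits)[correct class])
log_probs = F.log_softmax(logits, dim=1)
per_sample = -log_probs[torch.arange(4), labels]

# reduction='mean' (the default) averages the per-sample losses over the batch
manual = per_sample.mean()
builtin = torch.nn.CrossEntropyLoss()(logits, labels)

print(manual.item(), builtin.item(), math.log(10))
```

When scores are drawn from [0, 1), the softmax over 10 classes is close to uniform. Each per-sample term then lies within 1 of ln(10) ≈ 2.30, which is consistent with the 2.24 printed above.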

Optimizer

For this example, we’ll be using simple stochastic gradient descent with momentum.

It can be instructive to try some variations on this optimization scheme:

  • Learning rate determines the size of the steps the optimizer takes. What does a different learning rate do to your training results, in terms of accuracy and convergence time?

  • Momentum accumulates a decaying running sum of past gradients, so the optimizer keeps moving in directions that have been consistently downhill over multiple steps. What does changing this value do to your results?

  • Try some different optimization algorithms, such as averaged SGD, Adagrad, or Adam. How do your results differ?

# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
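For SGD with momentum (no dampening or Nesterov, which matches the defaults here), PyTorch keeps a per-parameter buffer `b`. On each step it updates `b = momentum * b + grad`, starting from `b = grad`, and then sets `w = w - lr * b`. The sketch below checks that rule by hand against `torch.optim.SGD`. The toy one-parameter loss `f(w) = w**2` and the learning rate of 0.1 are assumptions chosen purely for illustration.

```python
import torch

lr, momentum = 0.1, 0.9

# Toy problem: one scalar parameter, loss f(w) = w**2, so the gradient is 2*w
w = torch.tensor([1.0], requires_grad=True)
opt = torch.optim.SGD([w], lr=lr, momentum=momentum)

w_manual = 1.0
buf = None  # momentum buffer: a decaying running sum of past gradients

for step in range(3):
    # PyTorch's update
    opt.zero_grad()
    loss = (w ** 2).sum()
    loss.backward()
    opt.step()

    # The same update written out by hand
    g = 2 * w_manual
    buf = g if buf is None else momentum * buf + g
    w_manual = w_manual - lr * buf

    print(step, round(w.item(), 5), round(w_manual, 5), round(buf, 5))
```

The buffer keeps growing (2.0, 3.4, 3.98) while the gradient shrinks (2.0, 1.6, 0.92), so momentum carries the optimizer along its recent direction. To try the other algorithms listed above, replace the `torch.optim.SGD(...)` call with, for example, `torch.optim.Adam([w], lr=lr)`. `torch.optim.ASGD` and `torch.optim.Adagrad` also take the parameters first.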

The Training Loop

Below, we have a function that performs one training epoch. It enumerates data from the DataLoader, and on each pass of the loop does the following:

  • Gets a batch of training data from the DataLoader

  • Zeros the optimizer’s gradients

  • Performs inference, getting the model’s predictions for the input batch

  • Calculates the loss for that set of predictions vs. the labels on the dataset

  • Calculates the backward gradients over the learning weights

  • Tells the optimizer to perform one learning step, adjusting the model’s learning weights based on the observed gradients for this batch, according to the optimization algorithm we chose

  • Reports the loss every 1000 batches

  • Finally, returns the average per-batch loss over the last 1000 batches, for comparison with a validation run

def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss
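The "Zero your gradients for every batch!" comment matters because `backward()` adds new gradients into each parameter's `.grad` rather than replacing them. The minimal sketch below shows the accumulation and the reset on a single toy tensor, an assumption standing in for the model's weights.

```python
import torch

w = torch.tensor([3.0], requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)

# d(2w)/dw = 2, so one backward pass leaves w.grad equal to 2
(2 * w).sum().backward()
first = w.grad.item()

# A second backward pass without zeroing ADDS to the existing gradient
(2 * w).sum().backward()
accumulated = w.grad.item()

# optimizer.zero_grad() resets .grad for every parameter the optimizer manages
# (recent PyTorch versions set it to None by default instead of filling it with zeros)
opt.zero_grad()
(2 * w).sum().backward()
after_reset = w.grad.item()

print(first, accumulated, after_reset)  # 2.0 4.0 2.0
```

Without `optimizer.zero_grad()` in `train_one_epoch`, each step would use the sum of every gradient computed so far instead of the current batch's gradient.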

Per-Epoch Activity

There are a couple of things we’ll want to do once per epoch:

  • Perform validation by computing the loss on a set of data that was not used for training, and report it

  • Save a copy of the model

Here, we’ll do our reporting in TensorBoard. This will require going to the command line to start TensorBoard, and opening it in another browser tab.

# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Put the model in training mode (this matters for layers such as dropout
    # and batch norm), then do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)


    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1
EPOCH 1:
  batch 1000 loss: 1.7953810837864876
  batch 2000 loss: 0.8244951499085873
  batch 3000 loss: 0.6795841237778076
  batch 4000 loss: 0.6237406253132504
  batch 5000 loss: 0.5924521059104009
  batch 6000 loss: 0.5446553561121691
  batch 7000 loss: 0.5328865849536378
  batch 8000 loss: 0.5121419241574476
  batch 9000 loss: 0.4841414849520079
  batch 10000 loss: 0.4666176200763439
  batch 11000 loss: 0.46205400900312815
  batch 12000 loss: 0.4447007388505153
  batch 13000 loss: 0.4142924808491953
  batch 14000 loss: 0.40956779026566076
  batch 15000 loss: 0.4177117152803985
LOSS train 0.4177117152803985 valid 0.44553622603416443
EPOCH 2:
  batch 1000 loss: 0.40128410009422805
  batch 2000 loss: 0.40305546496622263
  batch 3000 loss: 0.40702980543574085
  batch 4000 loss: 0.39407327451440505
  batch 5000 loss: 0.36352190978778526
  batch 6000 loss: 0.37312547214997177
  batch 7000 loss: 0.3758366793676978
  batch 8000 loss: 0.3405639385653776
  batch 9000 loss: 0.36997920134241574
  batch 10000 loss: 0.359413162092882
  batch 11000 loss: 0.37312915771725236
  batch 12000 loss: 0.33811824521305606
  batch 13000 loss: 0.3572969840469887
  batch 14000 loss: 0.36620103280519833
  batch 15000 loss: 0.3213143042664742
LOSS train 0.3213143042664742 valid 0.3517323434352875
EPOCH 3:
  batch 1000 loss: 0.3337490362423123
  batch 2000 loss: 0.3311599798273237
  batch 3000 loss: 0.33811903376392727
  batch 4000 loss: 0.32846899344894337
  batch 5000 loss: 0.31784062373687627
  batch 6000 loss: 0.3259422236883838
  batch 7000 loss: 0.3189901710792528
  batch 8000 loss: 0.32540348729457763
  batch 9000 loss: 0.32341072909034846
  batch 10000 loss: 0.3318939990024701
  batch 11000 loss: 0.33045886283699655
  batch 12000 loss: 0.32612536531797376
  batch 13000 loss: 0.3335226553216344
  batch 14000 loss: 0.31242431101092005
  batch 15000 loss: 0.3000959627712
LOSS train 0.3000959627712 valid 0.3497053384780884
EPOCH 4:
  batch 1000 loss: 0.29982622844899015
  batch 2000 loss: 0.28241306077288025
  batch 3000 loss: 0.3224020686237054
  batch 4000 loss: 0.2933463651681086
  batch 5000 loss: 0.3013071577974042
  batch 6000 loss: 0.2969221141233502
  batch 7000 loss: 0.2895930815332831
  batch 8000 loss: 0.31016799209580587
  batch 9000 loss: 0.3052861173714773
  batch 10000 loss: 0.30825613315279043
  batch 11000 loss: 0.29595888312389435
  batch 12000 loss: 0.29418928824269097
  batch 13000 loss: 0.30175805846647563
  batch 14000 loss: 0.29366780563646533
  batch 15000 loss: 0.3100094365816913
LOSS train 0.3100094365816913 valid 0.3187921345233917
EPOCH 5:
  batch 1000 loss: 0.27892607753528864
  batch 2000 loss: 0.26567523849471764
  batch 3000 loss: 0.2855535473689015
  batch 4000 loss: 0.29075567252633483
  batch 5000 loss: 0.274320784878033
  batch 6000 loss: 0.2866456297454788
  batch 7000 loss: 0.2875749778726167
  batch 8000 loss: 0.28006844768504197
  batch 9000 loss: 0.27494829583771935
  batch 10000 loss: 0.2901083213442616
  batch 11000 loss: 0.2757640323975102
  batch 12000 loss: 0.2684062098810973
  batch 13000 loss: 0.2902034229126348
  batch 14000 loss: 0.2818407698779483
  batch 15000 loss: 0.27735993386270275
LOSS train 0.27735993386270275 valid 0.3038559556007385
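The loop above switches between `model.train(True)` and `model.eval()`. These calls change how certain layers behave. They do not control whether gradients are tracked; that is the job of `torch.no_grad()`. `GarmentClassifier` has no dropout or batch-normalization layers, so the switch does not change its outputs. The sketch below uses a standalone `nn.Dropout` layer to show the difference the mode makes when such a layer is present.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
x = torch.ones(1000)

# Training mode: each element is zeroed with probability p,
# and survivors are scaled by 1 / (1 - p) = 2.0
dropout.train()
y_train = dropout(x)

# Evaluation mode: dropout passes its input through unchanged
dropout.eval()
y_eval = dropout(x)

num_zeros = (y_train == 0).sum().item()
print(num_zeros)                # roughly half of 1000; the exact count depends on the seed
print(torch.equal(y_eval, x))   # True
```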

To load a saved version of the model:

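# PATH is the file written by torch.save() above; model_path holds the last (best) checkpoint
PATH = model_path
saved_model = GarmentClassifier()
saved_model.load_state_dict(torch.load(PATH))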

Once you’ve loaded the model, it’s ready for whatever you need it for: more training, inference, or analysis.

Note that if your model has constructor parameters that affect model structure, you’ll need to provide them and configure the model identically to the state in which it was saved.
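Here is a self-contained sketch of the save/load round trip and of the constructor-parameter caveat. `TinyNet` and its `hidden` argument are hypothetical. They stand in for any model whose layer sizes depend on constructor arguments. The checkpoint is written to a temporary directory.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model whose layer sizes depend on a constructor argument."""

    def __init__(self, hidden=8):
        super().__init__()
        self.fc1 = nn.Linear(4, hidden)
        self.fc2 = nn.Linear(hidden, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


original = TinyNet(hidden=8)
path = os.path.join(tempfile.mkdtemp(), "tiny_net.pt")
torch.save(original.state_dict(), path)

# Same constructor arguments: the state dict loads and the outputs match
restored = TinyNet(hidden=8)
restored.load_state_dict(torch.load(path))
restored.eval()  # switch to evaluation mode before inference
x = torch.randn(3, 4)
print(torch.equal(original(x), restored(x)))  # True

# Different constructor arguments: parameter shapes disagree, so loading raises
mismatch_failed = False
try:
    TinyNet(hidden=16).load_state_dict(torch.load(path))
except RuntimeError as err:
    mismatch_failed = True
    print("Load failed:", str(err).splitlines()[0])
```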


Total running time of the script: (2 minutes 59.346 seconds)