

Training with PyTorch

Created On: Nov 30, 2021 | Last Updated: May 06, 2026 | Last Verified: Nov 05, 2024

Follow along with the video below or on YouTube.

Introduction

In past videos, we’ve discussed and demonstrated:

  • Building models with the neural network layers and functions of the torch.nn module

  • The mechanics of automated gradient computation, which is central to gradient-based model training

  • Using TensorBoard to visualize training progress and other activities

In this video, we’ll be adding some new tools to your inventory:

  • We’ll get familiar with the dataset and dataloader abstractions, and how they ease the process of feeding data to your model during a training loop

  • We’ll discuss specific loss functions and when to use them

  • We’ll look at PyTorch optimizers, which implement algorithms to adjust model weights based on the outcome of a loss function

Finally, we’ll pull all of these together and see a full PyTorch training loop in action.

Dataset and DataLoader

The Dataset and DataLoader classes encapsulate the process of pulling your data from storage and exposing it to your training loop in batches.

The Dataset is responsible for accessing and processing single instances of data.

The DataLoader pulls instances of data from the Dataset (either automatically or with a sampler that you define), collects them in batches, and returns them for consumption by your training loop. The DataLoader works with all kinds of datasets, regardless of the type of data they contain.
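To make the division of labor concrete, here is a minimal sketch built around SquaresDataset, a hypothetical toy class that is not part of this tutorial. A map-style Dataset only needs __len__ and __getitem__ and returns one instance at a time; the DataLoader does the batching.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: instance i is the input [i] with label i * i."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # The DataLoader uses this to know how many indices exist
        return self.n

    def __getitem__(self, idx):
        # Return a single (input, label) pair; batching is the DataLoader's job
        x = torch.tensor([float(idx)])
        y = torch.tensor(idx * idx)
        return x, y


loader = DataLoader(SquaresDataset(10), batch_size=4, shuffle=False)
batches = list(loader)

for inputs, labels in batches:
    # The default collate function stacks instances along a new batch dimension
    print(inputs.shape, labels.tolist())
```

With 10 instances and batch_size=4, the last batch holds only the 2 leftover instances. Passing drop_last=True to the DataLoader discards that partial batch instead.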

For this tutorial, we’ll be using the Fashion-MNIST dataset provided by TorchVision. We convert each image to a float32 tensor with values scaled to [0, 1], use torchvision.transforms.v2.Normalize() to zero-center the pixel values and rescale them to the range [-1, 1], and download both the training and validation splits.

import torch
import torchvision
from torchvision.transforms import v2

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize((0.5,), (0.5,))
])

# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)

# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)

# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
        'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

# Report split sizes
print(f'Training set has {len(training_set)} instances')
print(f'Validation set has {len(validation_set)} instances')
Training set has 60000 instances
Validation set has 10000 instances
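The arithmetic behind the transform pipeline is easy to check by hand. For 8-bit images, ToDtype(torch.float32, scale=True) divides pixel values by 255, and Normalize((0.5,), (0.5,)) then computes (x - 0.5) / 0.5. The plain-Python helpers below are hypothetical stand-ins that mirror this math for a single value; unnormalize() is the same img / 2 + 0.5 step the display helper in the next section uses.

```python
def to_unit_range(pixel):
    """Mimic ToDtype(torch.float32, scale=True) for one uint8 value: [0, 255] -> [0, 1]."""
    return pixel / 255.0


def normalize(x, mean=0.5, std=0.5):
    """Mimic Normalize((0.5,), (0.5,)) for one value: [0, 1] -> [-1, 1]."""
    return (x - mean) / std


def unnormalize(x):
    """Invert normalize() the same way the display helper does: x / 2 + 0.5."""
    return x / 2 + 0.5


for pixel in (0, 255):
    # Black (0) maps to -1.0 and white (255) maps to 1.0
    print(pixel, normalize(to_unit_range(pixel)))
```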

As always, let’s visualize the data as a sanity check:

import matplotlib.pyplot as plt
import numpy as np

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(training_loader)
images, labels = next(dataiter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print('  '.join(classes[labels[j]] for j in range(4)))
Ankle Boot  Pullover  Sandal  Bag

The Model

The model we’ll use in this example is a variant of LeNet-5, and it should be familiar if you’ve watched the previous videos in this series.

import torch.nn as nn
import torch.nn.functional as F

# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = GarmentClassifier()
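The 16 * 4 * 4 input size of fc1 follows from the image size. Fashion-MNIST images are 28x28, each unpadded 5x5 convolution trims 4 pixels from each side length, and each 2x2 max-pool with stride 2 halves it. The helper below is a hypothetical sketch of the standard output-size formula, floor((size + 2 * padding - kernel) / stride) + 1, applied layer by layer (dilation is ignored).

```python
def out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                                   # Fashion-MNIST images are 28x28
size = out_size(size, kernel=5)             # conv1: 28 -> 24
size = out_size(size, kernel=2, stride=2)   # pool:  24 -> 12
size = out_size(size, kernel=5)             # conv2: 12 -> 8
size = out_size(size, kernel=2, stride=2)   # pool:   8 -> 4

flat_features = 16 * size * size            # conv2 produces 16 channels
print(size, flat_features)
```

Changing the input resolution or the kernel sizes changes this number, so fc1 and the view() call in forward() must both be updated to match.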

Loss Function

For this example, we’ll be using a cross-entropy loss. For demonstration purposes, we’ll create batches of dummy output and label values, run them through the loss function, and examine the result.

loss_fn = torch.nn.CrossEntropyLoss()

# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's raw, unnormalized scores (logits) for each of the 10 classes
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])

print(dummy_outputs)
print(dummy_labels)

loss = loss_fn(dummy_outputs, dummy_labels)
print(f'Total loss for this batch: {loss.item()}')
tensor([[0.3090, 0.8950, 0.6749, 0.6726, 0.8422, 0.3280, 0.1156, 0.6951, 0.8365,
         0.4468],
        [0.4872, 0.4600, 0.1971, 0.2012, 0.3649, 0.2560, 0.8843, 0.7869, 0.5449,
         0.6888],
        [0.5280, 0.3079, 0.8036, 0.3795, 0.6707, 0.8581, 0.8378, 0.0882, 0.2185,
         0.9666],
        [0.2532, 0.5359, 0.8772, 0.2511, 0.0036, 0.7973, 0.7736, 0.1425, 0.4803,
         0.3753]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.4396369457244873
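CrossEntropyLoss treats each row of outputs as unnormalized scores (logits): it applies a softmax, takes the negative log-probability of the correct class, and, with the default reduction, averages over the batch. The plain-Python function below is a sketch of that computation written out for illustration, not PyTorch’s actual implementation.

```python
import math


def cross_entropy(logits, labels):
    """Mean of -log(softmax(row)[label]) over a batch, like reduction='mean'."""
    total = 0.0
    for row, label in zip(logits, labels):
        # Subtract the max before exponentiating for numerical stability
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[label]
    return total / len(labels)


# Equal scores for all 10 classes give every class probability 1/10
uniform = cross_entropy([[0.0] * 10], [3])
print(uniform, math.log(10))
```

This also explains the dummy result above: scores drawn uniformly from [0, 1) produce a nearly flat softmax over 10 classes, so the loss lands near ln(10), about 2.30.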

Optimizer

For this example, we’ll be using simple stochastic gradient descent with momentum.

It can be instructive to try some variations on this optimization scheme:

  • Learning rate determines the size of the steps the optimizer takes. What does a different learning rate do to your training results, in terms of accuracy and convergence time?

  • Momentum accumulates past gradients, nudging the optimizer to keep moving in directions where the gradient has been consistent over multiple steps. What does changing this value do to your results?

  • Try some different optimization algorithms, such as averaged SGD, Adagrad, or Adam. How do your results differ?

# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
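For reference, torch.optim.SGD with momentum (and its defaults dampening=0 and nesterov=False) keeps a velocity buffer per parameter and updates it as v = momentum * v + grad, then p = p - lr * v. The sketch below applies that rule by hand to the one-parameter function f(x) = x**2, whose gradient is 2 * x. The function, starting point, and hyperparameters are arbitrary assumptions chosen so the effect of lr and momentum is easy to experiment with.

```python
def sgd_momentum(x, lr, momentum, steps):
    """Minimize f(x) = x**2 with the SGD-with-momentum update; return the trajectory."""
    velocity = 0.0
    trajectory = [x]
    for _ in range(steps):
        grad = 2 * x                          # derivative of x**2
        velocity = momentum * velocity + grad
        x = x - lr * velocity
        trajectory.append(x)
    return trajectory


plain = sgd_momentum(5.0, lr=0.1, momentum=0.0, steps=300)
with_momentum = sgd_momentum(5.0, lr=0.1, momentum=0.9, steps=300)
print(plain[1], with_momentum[1], with_momentum[-1])
```

Plain SGD here shrinks x by a factor of 0.8 on every step and never crosses zero, while the momentum run overshoots and oscillates around the minimum before settling. Trying other lr and momentum values shows how the two interact.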

The Training Loop

Below, we have a function that performs one training epoch. It enumerates data from the DataLoader, and on each pass of the loop does the following:

  • Gets a batch of training data from the DataLoader

  • Zeros the optimizer’s gradients

  • Performs inference, that is, gets predictions from the model for an input batch

  • Calculates the loss for that set of predictions vs. the labels on the dataset

  • Calculates the backward gradients over the learning weights

  • Tells the optimizer to perform one learning step, that is, adjust the model’s learning weights based on the observed gradients for this batch, according to the optimization algorithm we chose

  • Reports the average per-batch loss every 1000 batches

  • Finally, returns the average per-batch loss for the last 1000 batches, for comparison with a validation run

def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print(f'  batch {i + 1} loss: {last_loss}')
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss
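To see the same per-batch steps in isolation, here is a minimal sketch that runs them on a toy problem. The synthetic data (y = 3x + 1 plus noise), the single nn.Linear layer, MSELoss, and all hyperparameters are assumptions for illustration; only the zero_grad / forward / loss / backward / step sequence mirrors train_one_epoch(). The toy_ prefixes keep it from overwriting the tutorial’s model, loss_fn, and optimizer.

```python
import torch

torch.manual_seed(0)

# Hypothetical synthetic regression data: y = 3x + 1 plus a little noise
X = torch.randn(256, 1)
y = 3 * X + 1 + 0.1 * torch.randn(256, 1)
toy_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(X, y), batch_size=32, shuffle=True)

toy_model = torch.nn.Linear(1, 1)
toy_loss_fn = torch.nn.MSELoss()
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1, momentum=0.9)


def full_dataset_loss():
    # Evaluate on all 256 points without building a computation graph
    with torch.no_grad():
        return toy_loss_fn(toy_model(X), y).item()


initial_loss = full_dataset_loss()
for _ in range(20):
    for toy_inputs, toy_labels in toy_loader:
        toy_optimizer.zero_grad()                                # clear old gradients
        toy_loss = toy_loss_fn(toy_model(toy_inputs), toy_labels)  # forward pass + loss
        toy_loss.backward()                                      # compute gradients
        toy_optimizer.step()                                     # update weights
final_loss = full_dataset_loss()
print(f'loss before: {initial_loss:.4f}  after: {final_loss:.4f}')
```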

Per-Epoch Activity

There are a couple of things we’ll want to do once per epoch:

  • Perform validation by computing the loss on a set of data that was not used for training, and report it

  • Save a copy of the model

Here, we’ll do our reporting in TensorBoard. This requires starting TensorBoard from the command line (for example, tensorboard --logdir=runs, since the SummaryWriter below writes to the runs directory) and opening it in another browser tab.

# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter(f'runs/fashion_trainer_{timestamp}')
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print(f'EPOCH {epoch_number + 1}:')

    # Put the model in training mode (enables dropout and batch-norm statistic
    # updates), then do a pass over the training data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)


    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss.item()

    avg_vloss = running_vloss / (i + 1)
    print(f'LOSS train {avg_loss} valid {avg_vloss}')

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = f'model_{timestamp}_{epoch_number}'
        torch.save(model.state_dict(), model_path)

    epoch_number += 1
EPOCH 1:
  batch 1000 loss: 1.7137023788839578
  batch 2000 loss: 0.7973894174164161
  batch 3000 loss: 0.7231586747355759
  batch 4000 loss: 0.6248975814115256
  batch 5000 loss: 0.5981162405619398
  batch 6000 loss: 0.5691179783265107
  batch 7000 loss: 0.5275342330737621
  batch 8000 loss: 0.5013267770194215
  batch 9000 loss: 0.49572225263924335
  batch 10000 loss: 0.4725613957467722
  batch 11000 loss: 0.46171742442133834
  batch 12000 loss: 0.4368839423506579
  batch 13000 loss: 0.4545344554230105
  batch 14000 loss: 0.43802149582689165
  batch 15000 loss: 0.40149797978081914
LOSS train 0.40149797978081914 valid 0.4494331181049347
EPOCH 2:
  batch 1000 loss: 0.41104386712331326
  batch 2000 loss: 0.393669776138122
  batch 3000 loss: 0.3897873696521274
  batch 4000 loss: 0.38900074754923114
  batch 5000 loss: 0.38853561899089256
  batch 6000 loss: 0.37409437633503695
  batch 7000 loss: 0.3779398200036958
  batch 8000 loss: 0.3722098283530213
  batch 9000 loss: 0.3666050886940211
  batch 10000 loss: 0.3467877195001056
  batch 11000 loss: 0.3680948962146649
  batch 12000 loss: 0.35175823248620147
  batch 13000 loss: 0.3316007589644869
  batch 14000 loss: 0.35571451572725343
  batch 15000 loss: 0.3588475529536954
LOSS train 0.3588475529536954 valid 0.3705717623233795
EPOCH 3:
  batch 1000 loss: 0.3448412865669379
  batch 2000 loss: 0.336657260659209
  batch 3000 loss: 0.34794902718451337
  batch 4000 loss: 0.31277311360003657
  batch 5000 loss: 0.3354689367398896
  batch 6000 loss: 0.3160987413919065
  batch 7000 loss: 0.3163412383149625
  batch 8000 loss: 0.32091431439724694
  batch 9000 loss: 0.31169858617527646
  batch 10000 loss: 0.2994646521287359
  batch 11000 loss: 0.32518774883786683
  batch 12000 loss: 0.3148824741213175
  batch 13000 loss: 0.3024750067451096
  batch 14000 loss: 0.31182337841358093
  batch 15000 loss: 0.32070566362967656
LOSS train 0.32070566362967656 valid 0.3499591648578644
EPOCH 4:
  batch 1000 loss: 0.29151718638937746
  batch 2000 loss: 0.27663412005603094
  batch 3000 loss: 0.30008966725850766
  batch 4000 loss: 0.3025729545346403
  batch 5000 loss: 0.3001110927646623
  batch 6000 loss: 0.29214665762001824
  batch 7000 loss: 0.27628296986312123
  batch 8000 loss: 0.29913593605219646
  batch 9000 loss: 0.3108261799438042
  batch 10000 loss: 0.3093763867207599
  batch 11000 loss: 0.3000419056410974
  batch 12000 loss: 0.28552318386983827
  batch 13000 loss: 0.29356705640597286
  batch 14000 loss: 0.29047300382776303
  batch 15000 loss: 0.29394569155937644
LOSS train 0.29394569155937644 valid 0.3427586555480957
EPOCH 5:
  batch 1000 loss: 0.26650003100578035
  batch 2000 loss: 0.2695607030618539
  batch 3000 loss: 0.2745289891322936
  batch 4000 loss: 0.27267806035420655
  batch 5000 loss: 0.2638148403369123
  batch 6000 loss: 0.3078928508891695
  batch 7000 loss: 0.2678979963034908
  batch 8000 loss: 0.29600719419810045
  batch 9000 loss: 0.27045275746281183
  batch 10000 loss: 0.2714199966633696
  batch 11000 loss: 0.2717091838987835
  batch 12000 loss: 0.2736941087890068
  batch 13000 loss: 0.2599813518770698
  batch 14000 loss: 0.26593902035593053
  batch 15000 loss: 0.2868527134830292
LOSS train 0.2868527134830292 valid 0.30729931592941284
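The log above stops at batch 15000 in each epoch because the 60,000 training instances are split into batches of 4. The sketch below reproduces that count and the tb_x x-axis value that train_one_epoch() passes to TensorBoard; the helper names are hypothetical.

```python
import math


def batches_per_epoch(num_instances, batch_size, drop_last=False):
    """Number of batches a DataLoader yields per epoch (the last may be partial)."""
    if drop_last:
        return num_instances // batch_size
    return math.ceil(num_instances / batch_size)


def global_step(epoch_index, batch_index, num_batches):
    """Same formula as tb_x in train_one_epoch(): batches seen so far across epochs."""
    return epoch_index * num_batches + batch_index + 1


n = batches_per_epoch(60000, 4)
print(n)                          # batches per epoch
print(global_step(0, 999, n))     # first report of EPOCH 1
print(global_step(1, 999, n))     # first report of EPOCH 2
```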

To load a saved version of the model:

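# Path to the best checkpoint written by the training loop above
PATH = model_path
saved_model = GarmentClassifier()
saved_model.load_state_dict(torch.load(PATH, weights_only=True))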

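Once you’ve loaded the model, it’s ready for whatever you need it for: more training, inference, or analysis. Before running inference, call saved_model.eval() so that layers such as dropout and batch normalization behave in evaluation mode.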

Note that if your model has constructor parameters that affect model structure, you’ll need to provide them and configure the model identically to the state in which it was saved.
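As a minimal sketch of why the configuration must match, the hypothetical TinyClassifier below takes a hidden constructor argument that sets a layer width. A state_dict saved from one width loads cleanly into a model built with the same width, but loading it into a model built with a different width raises a RuntimeError that reports the size mismatch.

```python
import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical model whose hidden-layer width is a constructor argument."""

    def __init__(self, hidden=32):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, hidden)
        self.fc2 = nn.Linear(hidden, 10)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x.flatten(1))))


# State saved from a model built with hidden=64
state = TinyClassifier(hidden=64).state_dict()

# Same configuration: every parameter shape matches, so loading succeeds
restored = TinyClassifier(hidden=64)
restored.load_state_dict(state)

# Different configuration: fc1 and fc2 have other shapes, so loading fails
mismatch_error = None
try:
    TinyClassifier(hidden=32).load_state_dict(state)
except RuntimeError as err:
    mismatch_error = err

print(type(mismatch_error).__name__)
```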


Total running time of the script: (3 minutes 20.604 seconds)