.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "beginner/basics/quickstart_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_beginner_basics_quickstart_tutorial.py:


`Learn the Basics `_ ||
**Quickstart** ||
`Tensors `_ ||
`Datasets & DataLoaders `_ ||
`Transforms `_ ||
`Build Model `_ ||
`Autograd `_ ||
`Optimization `_ ||
`Save & Load Model `_

Quickstart
===================

This section runs through the API for common tasks in machine learning. Refer to the links in each section to dive deeper.

Working with data
-----------------
PyTorch has two `primitives to work with data `_:
``torch.utils.data.DataLoader`` and ``torch.utils.data.Dataset``.
``Dataset`` stores the samples and their corresponding labels, and ``DataLoader`` wraps an iterable around
the ``Dataset``.

.. GENERATED FROM PYTHON SOURCE LINES 24-31

.. code-block:: Python


    import torch
    from torch import nn
    from torch.utils.data import DataLoader
    from torchvision import datasets
    from torchvision.transforms import ToTensor


.. GENERATED FROM PYTHON SOURCE LINES 32-40

PyTorch offers domain-specific libraries such as `TorchText `_,
`TorchVision `_, and `TorchAudio `_,
all of which include datasets. For this tutorial, we use a TorchVision dataset.

The ``torchvision.datasets`` module contains ``Dataset`` objects for many real-world vision datasets, such as
CIFAR and COCO (`full list here `_). In this tutorial, we
use the FashionMNIST dataset. Every TorchVision ``Dataset`` accepts two arguments, ``transform`` and
``target_transform``, which modify the samples and the labels, respectively.

.. GENERATED FROM PYTHON SOURCE LINES 40-57

.. code-block:: Python


    # Download training data from open datasets.
    training_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor(),
    )

    # Download test data from open datasets.
    test_data = datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=ToTensor(),
    )


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    0%| | 0.00/26.4M [00:00


.. GENERATED FROM PYTHON SOURCE LINES 78-80

--------------


.. GENERATED FROM PYTHON SOURCE LINES 82-89

Creating Models
------------------
To define a neural network in PyTorch, we create a class that inherits
from `nn.Module `_. We define the layers of the network
in the ``__init__`` function and specify how data will pass through the network in the ``forward`` function. To accelerate
operations in the neural network, we move it to an `accelerator `__
such as CUDA, MPS, MTIA, or XPU. If an accelerator is available, we use it; otherwise, we use the CPU.

.. GENERATED FROM PYTHON SOURCE LINES 89-114

.. code-block:: Python


    device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
    print(f"Using {device} device")

    # Define model
    class NeuralNetwork(nn.Module):
        def __init__(self):
            super().__init__()
            self.flatten = nn.Flatten()
            self.linear_relu_stack = nn.Sequential(
                nn.Linear(28*28, 512),
                nn.ReLU(),
                nn.Linear(512, 512),
                nn.ReLU(),
                nn.Linear(512, 10)
            )

        def forward(self, x):
            x = self.flatten(x)
            logits = self.linear_relu_stack(x)
            return logits

    model = NeuralNetwork().to(device)
    print(model)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Using cuda device
    NeuralNetwork(
      (flatten): Flatten(start_dim=1, end_dim=-1)
      (linear_relu_stack): Sequential(
        (0): Linear(in_features=784, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=512, bias=True)
        (3): ReLU()
        (4): Linear(in_features=512, out_features=10, bias=True)
      )
    )


.. GENERATED FROM PYTHON SOURCE LINES 115-117

Read more about `building neural networks in PyTorch `_.
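
The ``train`` and ``test`` functions in the next section iterate over ``DataLoader`` objects named ``train_dataloader`` and ``test_dataloader``. As an illustrative sketch (not part of the original tutorial), the block below wraps random tensors shaped like FashionMNIST images in a ``TensorDataset`` and a ``DataLoader`` with a batch size of 64, an assumed value consistent with the 64-sample steps in the training log, and passes one batch through the same ``NeuralNetwork`` definition on the CPU. ``nn.Flatten`` turns each 1x28x28 image into a 784-element vector, and the last ``nn.Linear`` layer returns 10 logits per image, one per class. The names ``images``, ``labels``, ``loader``, and ``probs`` are hypothetical.

.. code-block:: Python

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset


    # Same architecture as NeuralNetwork above, repeated so this block runs on its own.
    class NeuralNetwork(nn.Module):
        def __init__(self):
            super().__init__()
            self.flatten = nn.Flatten()
            self.linear_relu_stack = nn.Sequential(
                nn.Linear(28*28, 512),
                nn.ReLU(),
                nn.Linear(512, 512),
                nn.ReLU(),
                nn.Linear(512, 10)
            )

        def forward(self, x):
            x = self.flatten(x)
            return self.linear_relu_stack(x)


    # Hypothetical stand-in data: 256 random 1x28x28 "images" with labels in 0..9.
    images = torch.rand(256, 1, 28, 28)
    labels = torch.randint(0, 10, (256,))
    loader = DataLoader(TensorDataset(images, labels), batch_size=64)

    model = NeuralNetwork()  # stays on the CPU for this shape check
    X, y = next(iter(loader))
    with torch.no_grad():
        logits = model(X)
    probs = nn.Softmax(dim=1)(logits)  # per-class probabilities; each row sums to 1

    print(f"Shape of X [N, C, H, W]: {X.shape}")  # torch.Size([64, 1, 28, 28])
    print(f"Shape of logits: {logits.shape}")     # torch.Size([64, 10])
    print(f"Batches per epoch: {len(loader)}")    # 4

To run the same check against the tutorial's ``model``, which lives on ``device``, move the batch first with ``X = X.to(device)``.
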

.. GENERATED FROM PYTHON SOURCE LINES 120-122

--------------


.. GENERATED FROM PYTHON SOURCE LINES 125-129

Optimizing the Model Parameters
----------------------------------------
To train a model, we need a `loss function `_
and an `optimizer `_.

.. GENERATED FROM PYTHON SOURCE LINES 129-134

.. code-block:: Python


    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)


.. GENERATED FROM PYTHON SOURCE LINES 135-137

In a single training loop, the model makes predictions on the training dataset (fed to it in batches), and
backpropagates the prediction error to adjust the model's parameters.

.. GENERATED FROM PYTHON SOURCE LINES 137-157

.. code-block:: Python


    def train(dataloader, model, loss_fn, optimizer):
        size = len(dataloader.dataset)
        model.train()
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)

            # Compute prediction error
            pred = model(X)
            loss = loss_fn(pred, y)

            # Backpropagation
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if batch % 100 == 0:
                loss, current = loss.item(), (batch + 1) * len(X)
                print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")


.. GENERATED FROM PYTHON SOURCE LINES 158-159

We also check the model's performance against the test dataset to ensure it is learning.

.. GENERATED FROM PYTHON SOURCE LINES 159-175

.. code-block:: Python


    def test(dataloader, model, loss_fn):
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        model.eval()
        test_loss, correct = 0, 0
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)
                pred = model(X)
                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        test_loss /= num_batches
        correct /= size
        print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


.. GENERATED FROM PYTHON SOURCE LINES 176-179

The training process is conducted over several iterations (*epochs*). During each epoch, the model learns
parameters to make better predictions.
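
One epoch is one full pass over the training data, so the number of parameter updates per epoch follows from the dataset size and the batch size. The arithmetic below is a sketch that assumes 60,000 training images and a batch size of 64 (both consistent with the ``64/60000`` entries in the training log) and the ``DataLoader`` default of keeping a smaller final batch; it reproduces the sample counts that ``train`` prints every 100 batches. The variable names are illustrative.

.. code-block:: Python

    import math

    # Assumed values: 60,000 training samples, 64 samples per batch.
    dataset_size = 60_000
    batch_size = 64

    # With DataLoader's default drop_last=False, the final partial batch is kept.
    num_batches = math.ceil(dataset_size / batch_size)
    last_batch_size = dataset_size - (num_batches - 1) * batch_size

    # train() prints when batch % 100 == 0, reporting (batch + 1) * len(X) samples seen.
    # Every logged batch here is a full batch, so len(X) == batch_size.
    logged = [(batch + 1) * batch_size for batch in range(num_batches) if batch % 100 == 0]

    print(num_batches)      # 938
    print(last_batch_size)  # 32
    print(logged)           # [64, 6464, 12864, 19264, 25664, 32064, 38464, 44864, 51264, 57664]

This is why each epoch in the log shows ten ``loss`` lines, the last one at 57664 samples rather than 60000.
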

We print the model's accuracy and loss at each epoch; we'd like to see the
accuracy increase and the loss decrease with every epoch.

.. GENERATED FROM PYTHON SOURCE LINES 179-187

.. code-block:: Python


    epochs = 5
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(train_dataloader, model, loss_fn, optimizer)
        test(test_dataloader, model, loss_fn)
    print("Done!")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Epoch 1
    -------------------------------
    loss: 2.298264 [   64/60000]
    loss: 2.295813 [ 6464/60000]
    loss: 2.275808 [12864/60000]
    loss: 2.272126 [19264/60000]
    loss: 2.251394 [25664/60000]
    loss: 2.213269 [32064/60000]
    loss: 2.238848 [38464/60000]
    loss: 2.202751 [44864/60000]
    loss: 2.186646 [51264/60000]
    loss: 2.157385 [57664/60000]
    Test Error:
     Accuracy: 32.9%, Avg loss: 2.154855

    Epoch 2
    -------------------------------
    loss: 2.159304 [   64/60000]
    loss: 2.161748 [ 6464/60000]
    loss: 2.104128 [12864/60000]
    loss: 2.120124 [19264/60000]
    loss: 2.071695 [25664/60000]
    loss: 2.001761 [32064/60000]
    loss: 2.041755 [38464/60000]
    loss: 1.964106 [44864/60000]
    loss: 1.949759 [51264/60000]
    loss: 1.893443 [57664/60000]
    Test Error:
     Accuracy: 59.0%, Avg loss: 1.890260

    Epoch 3
    -------------------------------
    loss: 1.911582 [   64/60000]
    loss: 1.901653 [ 6464/60000]
    loss: 1.781102 [12864/60000]
    loss: 1.819638 [19264/60000]
    loss: 1.718602 [25664/60000]
    loss: 1.652971 [32064/60000]
    loss: 1.683862 [38464/60000]
    loss: 1.585548 [44864/60000]
    loss: 1.591094 [51264/60000]
    loss: 1.499313 [57664/60000]
    Test Error:
     Accuracy: 61.9%, Avg loss: 1.514720

    Epoch 4
    -------------------------------
    loss: 1.567692 [   64/60000]
    loss: 1.550673 [ 6464/60000]
    loss: 1.392059 [12864/60000]
    loss: 1.470251 [19264/60000]
    loss: 1.356404 [25664/60000]
    loss: 1.328737 [32064/60000]
    loss: 1.361924 [38464/60000]
    loss: 1.284642 [44864/60000]
    loss: 1.308086 [51264/60000]
    loss: 1.216273 [57664/60000]
    Test Error:
     Accuracy: 64.3%, Avg loss: 1.242121

    Epoch 5
    -------------------------------
    loss: 1.308153 [   64/60000]
    loss: 1.302448 [ 6464/60000]
    loss: 1.129699 [12864/60000]
    loss: 1.241600 [19264/60000]
    loss: 1.125959 [25664/60000]
    loss: 1.125860 [32064/60000]
    loss: 1.169651 [38464/60000]
    loss: 1.104044 [44864/60000]
    loss: 1.134513 [51264/60000]
    loss: 1.055284 [57664/60000]
    Test Error:
     Accuracy: 65.2%, Avg loss: 1.075801

    Done!


.. GENERATED FROM PYTHON SOURCE LINES 188-190

Read more about `Training your model `_.

.. GENERATED FROM PYTHON SOURCE LINES 192-194

--------------


.. GENERATED FROM PYTHON SOURCE LINES 196-199

Saving Models
-------------
A common way to save a model is to serialize the internal state dictionary (containing the model parameters).

.. GENERATED FROM PYTHON SOURCE LINES 199-205

.. code-block:: Python


    torch.save(model.state_dict(), "model.pth")
    print("Saved PyTorch Model State to model.pth")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Saved PyTorch Model State to model.pth


.. GENERATED FROM PYTHON SOURCE LINES 206-211

Loading Models
----------------------------

The process for loading a model includes re-creating the model structure and loading
the state dictionary into it.

.. GENERATED FROM PYTHON SOURCE LINES 211-215

.. code-block:: Python


    model = NeuralNetwork().to(device)
    model.load_state_dict(torch.load("model.pth", weights_only=True))


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    <All keys matched successfully>


.. GENERATED FROM PYTHON SOURCE LINES 216-217

This model can now be used to make predictions.

.. GENERATED FROM PYTHON SOURCE LINES 217-240

.. code-block:: Python


    classes = [
        "T-shirt/top",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Shirt",
        "Sneaker",
        "Bag",
        "Ankle boot",
    ]

    model.eval()
    x, y = test_data[0][0], test_data[0][1]
    with torch.no_grad():
        x = x.to(device)
        pred = model(x)
        predicted, actual = classes[pred[0].argmax(0)], classes[y]
        print(f'Predicted: "{predicted}", Actual: "{actual}"')


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Predicted: "Ankle boot", Actual: "Ankle boot"
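
A round trip is a quick way to confirm that saving and loading the state dictionary preserves the parameters. The sketch below is illustrative and makes two substitutions: a small hypothetical ``make_model`` helper stands in for ``NeuralNetwork``, and an in-memory ``io.BytesIO`` buffer replaces the ``model.pth`` file. It saves the ``state_dict()``, loads it into a freshly constructed model of the same structure, and checks that both models produce identical outputs on the same input.

.. code-block:: Python

    import io

    import torch
    from torch import nn


    def make_model():
        # Hypothetical small model; the same steps apply to NeuralNetwork.
        return nn.Sequential(nn.Flatten(), nn.Linear(28*28, 10))


    torch.manual_seed(0)
    model = make_model()

    # Serialize only the parameters (the state dictionary) to an in-memory buffer.
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    buffer.seek(0)

    # Re-create the structure, then load the saved parameters into it.
    restored = make_model()
    result = restored.load_state_dict(torch.load(buffer, weights_only=True))
    print(result)  # <All keys matched successfully>

    # Both models should now compute exactly the same outputs.
    model.eval()
    restored.eval()
    x = torch.rand(8, 1, 28, 28)
    with torch.no_grad():
        same = torch.equal(model(x), restored(x))
    print(same)  # True

Because only the parameters are stored, loading always starts from a model built with the same architecture, just as the Loading Models step does with ``NeuralNetwork()``.
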

.. GENERATED FROM PYTHON SOURCE LINES 241-243

Read more about `Saving & Loading your model `_.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 34.317 seconds)


.. _sphx_glr_download_beginner_basics_quickstart_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: quickstart_tutorial.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: quickstart_tutorial.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: quickstart_tutorial.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_