.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "beginner/introyt/trainingyt.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_beginner_introyt_trainingyt.py:

Training with PyTorch
=====================

Follow along with the video on `youtube `__.


Introduction
------------

In past videos, we’ve discussed and demonstrated:

- Building models with the neural network layers and functions of the ``torch.nn`` module
- The mechanics of automatic gradient computation, which is central to gradient-based model training
- Using TensorBoard to visualize training progress and other activities

In this video, we’ll be adding some new tools to your inventory:

- We’ll get familiar with the dataset and dataloader abstractions, and how they ease the process of feeding data to your model during a training loop
- We’ll discuss specific loss functions and when to use them
- We’ll look at PyTorch optimizers, which implement algorithms to adjust model weights based on the outcome of a loss function

Finally, we’ll pull all of these together and see a full PyTorch training loop in action.

Dataset and DataLoader
----------------------

The ``Dataset`` and ``DataLoader`` classes encapsulate the process of pulling your data from storage and exposing it to your training loop in batches.

The ``Dataset`` is responsible for accessing and processing single instances of data.

The ``DataLoader`` pulls instances of data from the ``Dataset`` (either automatically or with a sampler that you define), collects them in batches, and returns them for consumption by your training loop. The ``DataLoader`` works with all kinds of datasets, regardless of the type of data they contain.

For this tutorial, we’ll be using the Fashion-MNIST dataset provided by TorchVision. We use ``torchvision.transforms.ToTensor()`` to convert each image into a tensor with values in the range [0, 1], then ``torchvision.transforms.Normalize((0.5,), (0.5,))`` to shift and rescale those values to the range [-1, 1], centering them on zero. We also download both the training and validation data splits.

.. GENERATED FROM PYTHON SOURCE LINES 65-96

.. code-block:: Python

    import torch
    import torchvision
    import torchvision.transforms as transforms

    # PyTorch TensorBoard support
    from torch.utils.tensorboard import SummaryWriter
    from datetime import datetime


    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))])

    # Create datasets for training & validation, download if necessary
    training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
    validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)

    # Create data loaders for our datasets; shuffle for training, not for validation
    training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
    validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)

    # Class labels
    classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
            'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

    # Report split sizes
    print('Training set has {} instances'.format(len(training_set)))
    print('Validation set has {} instances'.format(len(validation_set)))


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    0%| | 0.00/26.4M [00:00

`__ with momentum.

It can be instructive to try some variations on this optimization scheme:

- Learning rate determines the size of the steps the optimizer takes. What does a different learning rate do to your training results, in terms of accuracy and convergence time?
- Momentum accumulates an exponentially decaying sum of past gradients, so the optimizer keeps moving in directions that stay consistent over multiple steps. What does changing this value do to your results?
- Try some different optimization algorithms, such as averaged SGD (``torch.optim.ASGD``), Adagrad, or Adam. How do your results differ?

.. GENERATED FROM PYTHON SOURCE LINES 200-205

.. code-block:: Python

    # Optimizers specified in the torch.optim package
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)


.. GENERATED FROM PYTHON SOURCE LINES 206-225

The Training Loop
-----------------

Below, we have a function that performs one training epoch. It enumerates data from the DataLoader, and on each pass of the loop does the following:

- Gets a batch of training data from the DataLoader
- Zeros the optimizer’s gradients
- Performs a forward pass, that is, gets predictions from the model for an input batch
- Calculates the loss for that set of predictions vs. the labels on the dataset
- Calculates the backward gradients over the learning weights
- Tells the optimizer to perform one learning step, that is, adjusts the model’s learning weights based on the observed gradients for this batch, according to the optimization algorithm we chose
- Every 1000 batches, prints the average per-batch loss over those batches and logs it to TensorBoard

Finally, the function returns the average per-batch loss for the last 1000 batches, for comparison with a validation run.

.. GENERATED FROM PYTHON SOURCE LINES 225-262

.. code-block:: Python

    def train_one_epoch(epoch_index, tb_writer):
        running_loss = 0.
        last_loss = 0.

        # Here, we use enumerate(training_loader) instead of
        # iter(training_loader) so that we can track the batch
        # index and do some intra-epoch reporting
        for i, data in enumerate(training_loader):
            # Every data instance is an input + label pair
            inputs, labels = data

            # Zero your gradients for every batch!
            optimizer.zero_grad()

            # Make predictions for this batch
            outputs = model(inputs)

            # Compute the loss and its gradients
            loss = loss_fn(outputs, labels)
            loss.backward()

            # Adjust learning weights
            optimizer.step()

            # Gather data and report
            running_loss += loss.item()
            if i % 1000 == 999:
                last_loss = running_loss / 1000 # loss per batch
                print('  batch {} loss: {}'.format(i + 1, last_loss))
                tb_x = epoch_index * len(training_loader) + i + 1
                tb_writer.add_scalar('Loss/train', last_loss, tb_x)
                running_loss = 0.

        return last_loss


.. GENERATED FROM PYTHON SOURCE LINES 263-276

Per-Epoch Activity
~~~~~~~~~~~~~~~~~~

There are a couple of things we’ll want to do once per epoch:

- Perform validation by computing the loss on a set of data that was not used for training, and report it
- Save a copy of the model whenever its validation loss improves on the best seen so far

Here, we’ll do our reporting in TensorBoard. This requires starting TensorBoard from the command line (for example, ``tensorboard --logdir=runs``, since the ``SummaryWriter`` below logs under ``runs/``) and opening it in another browser tab.

.. GENERATED FROM PYTHON SOURCE LINES 276-326

.. code-block:: Python

    # Initializing in a separate cell so we can easily add more epochs to the same run
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
    epoch_number = 0

    EPOCHS = 5

    best_vloss = 1_000_000.

    for epoch in range(EPOCHS):
        print('EPOCH {}:'.format(epoch_number + 1))

        # Put the model in training mode (enables dropout and batch-norm
        # statistics updates), then do one pass over the training data
        model.train(True)
        avg_loss = train_one_epoch(epoch_number, writer)


        running_vloss = 0.0
        # Set the model to evaluation mode, disabling dropout and using population
        # statistics for batch normalization.
        model.eval()

        # Disable gradient computation and reduce memory consumption.
        with torch.no_grad():
            for i, vdata in enumerate(validation_loader):
                vinputs, vlabels = vdata
                voutputs = model(vinputs)
                vloss = loss_fn(voutputs, vlabels)
                running_vloss += vloss

        avg_vloss = running_vloss / (i + 1)
        print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

        # Log the running loss averaged per batch
        # for both training and validation
        writer.add_scalars('Training vs. Validation Loss',
                        { 'Training' : avg_loss, 'Validation' : avg_vloss },
                        epoch_number + 1)
        writer.flush()

        # Track best performance, and save the model's state
        if avg_vloss < best_vloss:
            best_vloss = avg_vloss
            model_path = 'model_{}_{}'.format(timestamp, epoch_number)
            torch.save(model.state_dict(), model_path)

        epoch_number += 1


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    EPOCH 1:
      batch 1000 loss: 1.9576318949013949
      batch 2000 loss: 0.9126187145933509
      batch 3000 loss: 0.7570243945810944
      batch 4000 loss: 0.6829133056234569
      batch 5000 loss: 0.6194519917643629
      batch 6000 loss: 0.5836719829617069
      batch 7000 loss: 0.5557402149150148
      batch 8000 loss: 0.535657885234803
      batch 9000 loss: 0.5052011218289845
      batch 10000 loss: 0.45554780365002806
      batch 11000 loss: 0.46591379248257725
      batch 12000 loss: 0.43484579073067287
      batch 13000 loss: 0.4310119082017918
      batch 14000 loss: 0.44020615833428745
      batch 15000 loss: 0.402784006381451
    LOSS train 0.402784006381451 valid 0.44769349694252014
    EPOCH 2:
      batch 1000 loss: 0.3899975952416134
      batch 2000 loss: 0.4153860620104242
      batch 3000 loss: 0.3800500175692141
      batch 4000 loss: 0.4061927368853649
      batch 5000 loss: 0.37216635681514165
      batch 6000 loss: 0.3869452087971731
      batch 7000 loss: 0.36553758599137653
      batch 8000 loss: 0.36490986471352516
      batch 9000 loss: 0.3708160625002347
      batch 10000 loss: 0.3504008587963763
      batch 11000 loss: 0.3758990603135899
      batch 12000 loss: 0.3704308765029418
      batch 13000 loss: 0.36135593060984683
      batch 14000 loss: 0.35064337893083575
      batch 15000 loss: 0.3400108248385659
    LOSS train 0.3400108248385659 valid 0.34501272439956665
    EPOCH 3:
      batch 1000 loss: 0.3174629740609089
      batch 2000 loss: 0.32524551963224074
      batch 3000 loss: 0.31630020878865617
      batch 4000 loss: 0.32892000097072743
      batch 5000 loss: 0.32677057264751785
      batch 6000 loss: 0.3121449965004431
      batch 7000 loss: 0.3503868164533469
      batch 8000 loss: 0.322912187736365
      batch 9000 loss: 0.33894188481817766
      batch 10000 loss: 0.3280733678800898
      batch 11000 loss: 0.3263176130118118
      batch 12000 loss: 0.32762141215529117
      batch 13000 loss: 0.31733203681408484
      batch 14000 loss: 0.326167136561824
      batch 15000 loss: 0.3140320948020485
    LOSS train 0.3140320948020485 valid 0.33181536197662354
    EPOCH 4:
      batch 1000 loss: 0.282012564426841
      batch 2000 loss: 0.30304097851185224
      batch 3000 loss: 0.30601019736530727
      batch 4000 loss: 0.3027197667012515
      batch 5000 loss: 0.29386470701214423
      batch 6000 loss: 0.3085742753381419
      batch 7000 loss: 0.30387153060534183
      batch 8000 loss: 0.30424517345192725
      batch 9000 loss: 0.3060073038264309
      batch 10000 loss: 0.3032905082749348
      batch 11000 loss: 0.2930876870978027
      batch 12000 loss: 0.283281282274831
      batch 13000 loss: 0.29611325267382926
      batch 14000 loss: 0.30703694381329116
      batch 15000 loss: 0.27550252123206154
    LOSS train 0.27550252123206154 valid 0.3137122392654419
    EPOCH 5:
      batch 1000 loss: 0.27832344647739227
      batch 2000 loss: 0.28029874832210405
      batch 3000 loss: 0.2761170767475196
      batch 4000 loss: 0.269687148014269
      batch 5000 loss: 0.2757682931441468
      batch 6000 loss: 0.28846762490483707
      batch 7000 loss: 0.269388484995281
      batch 8000 loss: 0.2872645212863208
      batch 9000 loss: 0.29315009839697087
      batch 10000 loss: 0.2734479327071067
      batch 11000 loss: 0.2710334567155878
      batch 12000 loss: 0.28115201150160285
      batch 13000 loss: 0.26809077356300987
      batch 14000 loss: 0.2664093120649536
      batch 15000 loss: 0.27590450757789947
    LOSS train 0.27590450757789947 valid 0.3077177107334137


.. GENERATED FROM PYTHON SOURCE LINES 327-369

To load a saved version of the model:

.. code:: python

    # PATH is the checkpoint file written by torch.save() above,
    # i.e. one of the model_path values from the training loop
    saved_model = GarmentClassifier()
    saved_model.load_state_dict(torch.load(PATH))

Once you’ve loaded the model, it’s ready for whatever you need it for: more training, inference, or analysis. Before running inference, call ``saved_model.eval()`` so that layers such as dropout and batch normalization behave as they did during validation.

Note that if your model has constructor parameters that affect model structure, you’ll need to provide them and configure the model identically to the state in which it was saved.

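
The save-and-load round trip described above can be sketched end to end. This is a minimal, self-contained sketch under stated assumptions: ``TinyClassifier`` is a hypothetical stand-in for the tutorial's ``GarmentClassifier`` (whose definition is not reproduced in this section), its ``hidden_size`` argument is an illustrative constructor parameter that changes layer shapes, and the checkpoint is written to a temporary directory rather than to the timestamped ``model_path`` used by the training loop. The sketch checks that the reloaded model reproduces the original model's outputs, and that a model built with a different ``hidden_size`` rejects the checkpoint.

.. code-block:: python

    import os
    import tempfile

    import torch
    import torch.nn as nn


    class TinyClassifier(nn.Module):
        """Hypothetical stand-in for GarmentClassifier; hidden_size sets layer shapes."""

        def __init__(self, hidden_size=32):
            super().__init__()
            self.flatten = nn.Flatten()
            self.fc1 = nn.Linear(28 * 28, hidden_size)
            self.fc2 = nn.Linear(hidden_size, 10)  # 10 Fashion-MNIST classes

        def forward(self, x):
            return self.fc2(torch.relu(self.fc1(self.flatten(x))))


    torch.manual_seed(0)
    model = TinyClassifier(hidden_size=32)
    batch = torch.randn(4, 1, 28, 28)  # shaped like one batch from training_loader

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, 'tiny_model.pt')

        # Save only the learned tensors, as the training loop does with model_path
        torch.save(model.state_dict(), path)

        # Rebuild with the SAME constructor arguments, then load the weights
        saved_model = TinyClassifier(hidden_size=32)
        saved_model.load_state_dict(torch.load(path))
        saved_model.eval()  # inference behavior for layers like dropout/batch norm

        # A model with a different structure cannot accept this checkpoint
        mismatched = TinyClassifier(hidden_size=64)
        try:
            mismatched.load_state_dict(torch.load(path))
            mismatch_error = None
        except RuntimeError as err:
            mismatch_error = err

    model.eval()
    with torch.no_grad():
        outputs_match = torch.equal(model(batch), saved_model(batch))

    print('outputs match:', outputs_match)
    print('mismatched model rejected:', mismatch_error is not None)

A ``state_dict`` stores only tensors keyed by parameter name (for example ``fc1.weight``), not the model's architecture, so the loading code has to rebuild the same module structure before calling ``load_state_dict()``. That is why the ``hidden_size=64`` model in the sketch fails with a ``RuntimeError`` reporting a size mismatch.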

Other Resources
---------------

- Documentation on the `data utilities `__, including ``Dataset`` and ``DataLoader``, at pytorch.org
- A `note on the use of pinned memory `__ for GPU training
- Documentation on the datasets available in `TorchVision `__, `TorchText `__, and `TorchAudio `__
- Documentation on the `loss functions `__ available in PyTorch
- Documentation on the `torch.optim package `__, which includes optimizers and related tools, such as learning rate scheduling
- A detailed `tutorial on saving and loading models `__
- The `Tutorials section of pytorch.org `__ contains tutorials on a broad variety of training tasks, including classification in different domains, generative adversarial networks, reinforcement learning, and more

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (2 minutes 59.103 seconds)