.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/tensorclass_fashion.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_tensorclass_fashion.py:


Using tensorclasses for datasets
================================

.. GENERATED FROM PYTHON SOURCE LINES 7-13

In this tutorial we demonstrate how tensorclasses can be used to
efficiently and transparently load and manage data inside a training
pipeline. The tutorial is based heavily on the `PyTorch Quickstart Tutorial `__,
but modified to demonstrate the use of tensorclass. See the related
tutorial using ``TensorDict``.

.. GENERATED FROM PYTHON SOURCE LINES 13-27

.. code-block:: Python


    import torch
    import torch.nn as nn
    from tensordict import MemoryMappedTensor, tensorclass
    from torch.utils.data import DataLoader
    from torchvision import datasets
    from torchvision.transforms import ToTensor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Using device: cpu

.. GENERATED FROM PYTHON SOURCE LINES 28-32

The ``torchvision.datasets`` module contains a number of convenient
pre-prepared datasets. In this tutorial we'll use the relatively simple
FashionMNIST dataset. Each image shows an item of clothing, and the
objective is to classify the type of clothing in the image (e.g. "Bag",
"Sneaker", etc.).

.. GENERATED FROM PYTHON SOURCE LINES 32-46

.. code-block:: Python


    training_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor(),
    )
    test_data = datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=ToTensor(),
    )


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

      0%|          | 0.00/26.4M [00:00`__.

.. GENERATED FROM PYTHON SOURCE LINES 120-144
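The model used here is the Quickstart's fully connected network; its printed ``repr`` lists three ``nn.Linear`` layers (784 -> 512, 512 -> 512, 512 -> 10). Each linear layer holds ``in_features * out_features`` weights plus ``out_features`` biases, while ``Flatten`` and ``ReLU`` hold no parameters, so the model's size can be checked by hand. The plain-Python sketch below does that arithmetic; ``linear_params`` is a hypothetical helper for illustration, not a PyTorch function.

.. code-block:: Python

    # Hypothetical helper: a fully connected layer has a weight matrix
    # of shape (out_features, in_features) plus one bias per output.
    def linear_params(in_features, out_features):
        return in_features * out_features + out_features


    # Layer sizes of ``Net``: 28 * 28 -> 512 -> 512 -> 10.
    # Flatten and ReLU contribute no parameters, so they are omitted.
    layer_sizes = [(28 * 28, 512), (512, 512), (512, 10)]
    per_layer = [linear_params(i, o) for i, o in layer_sizes]
    total = sum(per_layer)
    print(per_layer, total)  # [401920, 262656, 5130] 669706

If PyTorch is available, ``sum(p.numel() for p in model.parameters())`` on the ``model`` defined below should report the same total.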
.. code-block:: Python


    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.flatten = nn.Flatten()
            self.linear_relu_stack = nn.Sequential(
                nn.Linear(28 * 28, 512),
                nn.ReLU(),
                nn.Linear(512, 512),
                nn.ReLU(),
                nn.Linear(512, 10),
            )

        def forward(self, x):
            x = self.flatten(x)
            logits = self.linear_relu_stack(x)
            return logits


    model = Net().to(device)
    model_tc = Net().to(device)
    model, model_tc


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (Net(
      (flatten): Flatten(start_dim=1, end_dim=-1)
      (linear_relu_stack): Sequential(
        (0): Linear(in_features=784, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=512, bias=True)
        (3): ReLU()
        (4): Linear(in_features=512, out_features=10, bias=True)
      )
    ), Net(
      (flatten): Flatten(start_dim=1, end_dim=-1)
      (linear_relu_stack): Sequential(
        (0): Linear(in_features=784, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=512, bias=True)
        (3): ReLU()
        (4): Linear(in_features=512, out_features=10, bias=True)
      )
    ))

.. GENERATED FROM PYTHON SOURCE LINES 145-151

Optimizing the parameters
-------------------------

We'll optimise the parameters of the model using stochastic gradient
descent and cross-entropy loss.

.. GENERATED FROM PYTHON SOURCE LINES 151-176

.. code-block:: Python


    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    optimizer_tc = torch.optim.SGD(model_tc.parameters(), lr=1e-3)


    def train(dataloader, model, loss_fn, optimizer):
        size = len(dataloader.dataset)
        model.train()
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)

            pred = model(X)
            loss = loss_fn(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch % 100 == 0:
                loss, current = loss.item(), batch * len(X)
                print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

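The first loss values logged during training (about 2.29 to 2.31) sit close to ln 10 ≈ 2.3026. Cross-entropy is the negative log of the softmax probability assigned to the true class, so an untrained network whose 10 logits are roughly equal starts near -ln(1/10) = ln 10. The plain-Python sketch below reproduces that arithmetic for a single sample; it should match what ``nn.CrossEntropyLoss`` computes with its default settings (no class weights, no label smoothing), but ``cross_entropy`` here is our own illustrative helper, not a PyTorch function.

.. code-block:: Python

    import math


    def cross_entropy(logits, target):
        """Negative log-softmax probability of ``target`` for one sample."""
        # Subtract the largest logit before exponentiating for numerical stability.
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        return log_sum_exp - logits[target]


    # Equal logits over 10 classes: each class gets probability 1/10.
    uniform = [0.0] * 10
    print(cross_entropy(uniform, target=3))  # ~2.302585, i.e. ln(10)

    # A confident, correct prediction pushes the loss toward 0.
    confident = [0.0] * 10
    confident[3] = 10.0
    print(cross_entropy(confident, target=3))  # close to 0

As training makes the model more confident on the correct class, the loss moves from the first regime toward the second, which is the downward trend visible in the logs.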
.. GENERATED FROM PYTHON SOURCE LINES 177-181

The training loop for our tensorclass-based DataLoader is very similar to
``train``; we only change how each batch is unpacked, using the explicit
attribute-based access offered by the tensorclass (``data.images`` and
``data.targets``) instead of tuple unpacking. The ``.contiguous()`` method
loads the data stored in the memmap tensor.

.. GENERATED FROM PYTHON SOURCE LINES 181-267

.. code-block:: Python


    def train_tc(dataloader, model, loss_fn, optimizer):
        size = len(dataloader.dataset)
        model.train()
        for batch, data in enumerate(dataloader):
            X, y = data.images.contiguous(), data.targets.contiguous()

            pred = model(X)
            loss = loss_fn(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch % 100 == 0:
                loss, current = loss.item(), batch * len(X)
                print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")


    def test(dataloader, model, loss_fn):
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        model.eval()
        test_loss, correct = 0, 0
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)

                pred = model(X)

                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()

        test_loss /= num_batches
        correct /= size

        print(
            f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
        )


    def test_tc(dataloader, model, loss_fn):
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        model.eval()
        test_loss, correct = 0, 0
        with torch.no_grad():
            for batch in dataloader:
                X, y = batch.images.contiguous(), batch.targets.contiguous()

                pred = model(X)

                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()

        test_loss /= num_batches
        correct /= size

        print(
            f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
        )


    for d in train_dataloader_tc:
        print(d)
        break

    import time

    t0 = time.time()
    epochs = 5
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------")
        train_tc(train_dataloader_tc, model_tc, loss_fn, optimizer_tc)
        test_tc(test_dataloader_tc, model_tc, loss_fn)
    print(f"Tensorclass training done! time: {time.time() - t0: 4.4f} s")

    t0 = time.time()
    epochs = 5
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------")
        train(train_dataloader, model, loss_fn, optimizer)
        test(test_dataloader, model, loss_fn)
    print(f"Training done! time: {time.time() - t0: 4.4f} s")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    FashionMNISTData(
        images=Tensor(shape=torch.Size([64, 28, 28]), device=cpu, dtype=torch.float32, is_shared=False),
        targets=Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False),
        batch_size=torch.Size([64]),
        device=cpu,
        is_shared=False)
    Epoch 1
    -------------------------
    loss: 2.313742 [    0/60000]
    loss: 2.292799 [ 6400/60000]
    loss: 2.273777 [12800/60000]
    loss: 2.262293 [19200/60000]
    loss: 2.247941 [25600/60000]
    loss: 2.215127 [32000/60000]
    loss: 2.227312 [38400/60000]
    loss: 2.190220 [44800/60000]
    loss: 2.197131 [51200/60000]
    loss: 2.159587 [57600/60000]
    Test Error:
     Accuracy: 43.4%, Avg loss: 2.153076

    Epoch 2
    -------------------------
    loss: 2.173764 [    0/60000]
    loss: 2.164063 [ 6400/60000]
    loss: 2.105219 [12800/60000]
    loss: 2.114054 [19200/60000]
    loss: 2.078059 [25600/60000]
    loss: 2.006772 [32000/60000]
    loss: 2.038938 [38400/60000]
    loss: 1.959138 [44800/60000]
    loss: 1.970101 [51200/60000]
    loss: 1.900703 [57600/60000]
    Test Error:
     Accuracy: 54.0%, Avg loss: 1.898236

    Epoch 3
    -------------------------
    loss: 1.938516 [    0/60000]
    loss: 1.915230 [ 6400/60000]
    loss: 1.795532 [12800/60000]
    loss: 1.824228 [19200/60000]
    loss: 1.741094 [25600/60000]
    loss: 1.673038 [32000/60000]
    loss: 1.700301 [38400/60000]
    loss: 1.593862 [44800/60000]
    loss: 1.624166 [51200/60000]
    loss: 1.517918 [57600/60000]
    Test Error:
     Accuracy: 58.3%, Avg loss: 1.534553

    Epoch 4
    -------------------------
    loss: 1.606797 [    0/60000]
    loss: 1.572791 [ 6400/60000]
    loss: 1.417350 [12800/60000]
    loss: 1.481222 [19200/60000]
    loss: 1.380749 [25600/60000]
    loss: 1.354327 [32000/60000]
    loss: 1.376374 [38400/60000]
    loss: 1.288244 [44800/60000]
    loss: 1.332447 [51200/60000]
    loss: 1.231878 [57600/60000]
    Test Error:
     Accuracy: 62.5%, Avg loss: 1.260673

    Epoch 5
    -------------------------
    loss: 1.339972 [    0/60000]
    loss: 1.320827 [ 6400/60000]
    loss: 1.154497 [12800/60000]
    loss: 1.254584 [19200/60000]
    loss: 1.143303 [25600/60000]
    loss: 1.151989 [32000/60000]
    loss: 1.179488 [38400/60000]
    loss: 1.103132 [44800/60000]
    loss: 1.149763 [51200/60000]
    loss: 1.070885 [57600/60000]
    Test Error:
     Accuracy: 64.2%, Avg loss: 1.093797

    Tensorclass training done! time:  8.9774 s
    Epoch 1
    -------------------------
    loss: 2.290750 [    0/60000]
    loss: 2.279749 [ 6400/60000]
    loss: 2.263475 [12800/60000]
    loss: 2.272084 [19200/60000]
    loss: 2.240863 [25600/60000]
    loss: 2.212507 [32000/60000]
    loss: 2.217540 [38400/60000]
    loss: 2.184312 [44800/60000]
    loss: 2.187586 [51200/60000]
    loss: 2.150691 [57600/60000]
    Test Error:
     Accuracy: 36.3%, Avg loss: 2.144585

    Epoch 2
    -------------------------
    loss: 2.149138 [    0/60000]
    loss: 2.136008 [ 6400/60000]
    loss: 2.081745 [12800/60000]
    loss: 2.106633 [19200/60000]
    loss: 2.041156 [25600/60000]
    loss: 1.983334 [32000/60000]
    loss: 2.008341 [38400/60000]
    loss: 1.930886 [44800/60000]
    loss: 1.933684 [51200/60000]
    loss: 1.852782 [57600/60000]
    Test Error:
     Accuracy: 57.1%, Avg loss: 1.857402

    Epoch 3
    -------------------------
    loss: 1.890338 [    0/60000]
    loss: 1.852019 [ 6400/60000]
    loss: 1.738609 [12800/60000]
    loss: 1.783138 [19200/60000]
    loss: 1.677351 [25600/60000]
    loss: 1.630616 [32000/60000]
    loss: 1.650019 [38400/60000]
    loss: 1.562767 [44800/60000]
    loss: 1.580777 [51200/60000]
    loss: 1.469991 [57600/60000]
    Test Error:
     Accuracy: 63.2%, Avg loss: 1.498744

    Epoch 4
    -------------------------
    loss: 1.562999 [    0/60000]
    loss: 1.527891 [ 6400/60000]
    loss: 1.382990 [12800/60000]
    loss: 1.454048 [19200/60000]
    loss: 1.351292 [25600/60000]
    loss: 1.340383 [32000/60000]
    loss: 1.353094 [38400/60000]
    loss: 1.288085 [44800/60000]
    loss: 1.318156 [51200/60000]
    loss: 1.213383 [57600/60000]
    Test Error:
     Accuracy: 64.3%, Avg loss: 1.246422

    Epoch 5
    -------------------------
    loss: 1.317010 [    0/60000]
    loss: 1.303447 [ 6400/60000]
    loss: 1.140213 [12800/60000]
    loss: 1.241575 [19200/60000]
    loss: 1.133798 [25600/60000]
    loss: 1.147807 [32000/60000]
    loss: 1.167390 [38400/60000]
    loss: 1.111624 [44800/60000]
    loss: 1.149838 [51200/60000]
    loss: 1.060279 [57600/60000]
    Test Error:
     Accuracy: 65.1%, Avg loss: 1.086045

    Training done! time:  34.2014 s


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (1 minutes 0.842 seconds)


.. _sphx_glr_download_tutorials_tensorclass_fashion.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: tensorclass_fashion.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: tensorclass_fashion.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: tensorclass_fashion.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_