.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/data_fashion.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_data_fashion.py:

Using TensorDict for datasets
=============================

.. GENERATED FROM PYTHON SOURCE LINES 7-12

In this tutorial we demonstrate how ``TensorDict`` can be used to efficiently
and transparently load and manage data inside a training pipeline. The
tutorial is based heavily on the PyTorch Quickstart Tutorial, modified to
demonstrate the use of ``TensorDict``.

.. GENERATED FROM PYTHON SOURCE LINES 12-26

.. code-block:: Python

    import torch
    import torch.nn as nn

    from tensordict import MemoryMappedTensor, TensorDict
    from torch.utils.data import DataLoader
    from torchvision import datasets
    from torchvision.transforms import ToTensor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Using device: cpu

.. GENERATED FROM PYTHON SOURCE LINES 27-31

The ``torchvision.datasets`` module contains a number of convenient
pre-prepared datasets. In this tutorial we'll use the relatively simple
FashionMNIST dataset. Each image shows an item of clothing, and the objective
is to classify the type of clothing in the image (e.g. "Bag", "Sneaker", etc.).

.. GENERATED FROM PYTHON SOURCE LINES 31-45

.. code-block:: Python

    training_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor(),
    )
    test_data = datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=ToTensor(),
    )

.. GENERATED FROM PYTHON SOURCE LINES 46-52

We will create two tensordicts, one for the training data and one for the test
data, and use memory-mapped tensors to hold the data. This allows us to
efficiently load batches of already-transformed data from disk rather than
repeatedly loading and transforming individual images.

First we create the :class:`~tensordict.MemoryMappedTensor` containers.

.. GENERATED FROM PYTHON SOURCE LINES 52-76

.. code-block:: Python

    training_data_td = TensorDict(
        {
            "images": MemoryMappedTensor.empty(
                (len(training_data), *training_data[0][0].squeeze().shape),
                dtype=torch.float32,
            ),
            "targets": MemoryMappedTensor.empty((len(training_data),), dtype=torch.int64),
        },
        batch_size=[len(training_data)],
        device=device,
    )
    test_data_td = TensorDict(
        {
            "images": MemoryMappedTensor.empty(
                (len(test_data), *test_data[0][0].squeeze().shape), dtype=torch.float32
            ),
            "targets": MemoryMappedTensor.empty((len(test_data),), dtype=torch.int64),
        },
        batch_size=[len(test_data)],
        device=device,
    )

.. GENERATED FROM PYTHON SOURCE LINES 77-80

Then we iterate over the data to populate the memory-mapped tensors. This
takes a bit of time, but performing the transforms up-front saves repeated
effort during training later.

.. GENERATED FROM PYTHON SOURCE LINES 80-87

.. code-block:: Python

    for i, (img, label) in enumerate(training_data):
        training_data_td[i] = TensorDict({"images": img, "targets": label}, [])

    for i, (img, label) in enumerate(test_data):
        test_data_td[i] = TensorDict({"images": img, "targets": label}, [])

.. GENERATED FROM PYTHON SOURCE LINES 88-98

DataLoaders
-----------

We'll create DataLoaders from the ``torchvision``-provided Datasets, as well
as from our memory-mapped TensorDicts.

Since ``TensorDict`` implements ``__len__`` and ``__getitem__`` (and also
``__getitems__``), we can use it like a map-style Dataset and create a
``DataLoader`` directly from it. Because ``TensorDict`` can already handle
batched indices, each batch arrives as a single batched ``TensorDict`` and
there is no need for collation, so we pass the identity function as
``collate_fn``.

.. GENERATED FROM PYTHON SOURCE LINES 98-111

.. code-block:: Python

    batch_size = 64

    train_dataloader = DataLoader(training_data, batch_size=batch_size)  # noqa: TOR401
    test_dataloader = DataLoader(test_data, batch_size=batch_size)  # noqa: TOR401

    train_dataloader_td = DataLoader(  # noqa: TOR401
        training_data_td, batch_size=batch_size, collate_fn=lambda x: x
    )
    test_dataloader_td = DataLoader(  # noqa: TOR401
        test_data_td, batch_size=batch_size, collate_fn=lambda x: x
    )

.. GENERATED FROM PYTHON SOURCE LINES 112-118

Model
-----

We use the same model as in the Quickstart Tutorial.

.. GENERATED FROM PYTHON SOURCE LINES 118-142

.. code-block:: Python

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.flatten = nn.Flatten()
            self.linear_relu_stack = nn.Sequential(
                nn.Linear(28 * 28, 512),
                nn.ReLU(),
                nn.Linear(512, 512),
                nn.ReLU(),
                nn.Linear(512, 10),
            )

        def forward(self, x):
            x = self.flatten(x)
            logits = self.linear_relu_stack(x)

            return logits


    model = Net().to(device)
    model_td = Net().to(device)
    model, model_td

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    (Net(
      (flatten): Flatten(start_dim=1, end_dim=-1)
      (linear_relu_stack): Sequential(
        (0): Linear(in_features=784, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=512, bias=True)
        (3): ReLU()
        (4): Linear(in_features=512, out_features=10, bias=True)
      )
    ), Net(
      (flatten): Flatten(start_dim=1, end_dim=-1)
      (linear_relu_stack): Sequential(
        (0): Linear(in_features=784, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=512, bias=True)
        (3): ReLU()
        (4): Linear(in_features=512, out_features=10, bias=True)
      )
    ))

.. GENERATED FROM PYTHON SOURCE LINES 143-149

Optimizing the parameters
-------------------------

We'll optimize the parameters of the model using stochastic gradient descent
and cross-entropy loss.

.. GENERATED FROM PYTHON SOURCE LINES 149-174

.. code-block:: Python

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    optimizer_td = torch.optim.SGD(model_td.parameters(), lr=1e-3)


    def train(dataloader, model, loss_fn, optimizer):
        size = len(dataloader.dataset)
        model.train()

        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)

            pred = model(X)
            loss = loss_fn(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch % 100 == 0:
                loss, current = loss.item(), batch * len(X)
                print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

.. GENERATED FROM PYTHON SOURCE LINES 175-178

The training loop for our ``TensorDict``-based DataLoader is very similar; we
only change how we unpack the data, using the more explicit key-based
retrieval offered by ``TensorDict``. The ``.contiguous()`` method loads the
data stored in the memory-mapped tensor.

.. GENERATED FROM PYTHON SOURCE LINES 178-264

.. code-block:: Python

    def train_td(dataloader, model, loss_fn, optimizer):
        size = len(dataloader.dataset)
        model.train()

        for batch, data in enumerate(dataloader):
            X, y = data["images"].contiguous(), data["targets"].contiguous()

            pred = model(X)
            loss = loss_fn(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch % 100 == 0:
                loss, current = loss.item(), batch * len(X)
                print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")


    def test(dataloader, model, loss_fn):
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        model.eval()
        test_loss, correct = 0, 0
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)

                pred = model(X)

                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()

        test_loss /= num_batches
        correct /= size

        print(
            f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
        )


    def test_td(dataloader, model, loss_fn):
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        model.eval()
        test_loss, correct = 0, 0
        with torch.no_grad():
            for batch in dataloader:
                X, y = batch["images"].contiguous(), batch["targets"].contiguous()

                pred = model(X)

                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()

        test_loss /= num_batches
        correct /= size

        print(
            f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
        )


    for d in train_dataloader_td:
        print(d)
        break

    import time

    t0 = time.time()
    epochs = 5
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------")
        train_td(train_dataloader_td, model_td, loss_fn, optimizer_td)
        test_td(test_dataloader_td, model_td, loss_fn)
    print(f"TensorDict training done! time: {time.time() - t0: 4.4f} s")

    t0 = time.time()
    epochs = 5
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------")
        train(train_dataloader, model, loss_fn, optimizer)
        test(test_dataloader, model, loss_fn)
    print(f"Training done! time: {time.time() - t0: 4.4f} s")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    TensorDict(
        fields={
            images: Tensor(shape=torch.Size([64, 28, 28]), device=cpu, dtype=torch.float32, is_shared=False),
            targets: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False)},
        batch_size=torch.Size([64]),
        device=cpu,
        is_shared=False)
    Epoch 1
    -------------------------
    loss: 2.295408 [    0/60000]
    loss: 2.295070 [ 6400/60000]
    loss: 2.267567 [12800/60000]
    loss: 2.267833 [19200/60000]
    loss: 2.253610 [25600/60000]
    loss: 2.212193 [32000/60000]
    loss: 2.234124 [38400/60000]
    loss: 2.182492 [44800/60000]
    loss: 2.187164 [51200/60000]
    loss: 2.159223 [57600/60000]
    Test Error:
     Accuracy: 41.1%, Avg loss: 2.149606

    Epoch 2
    -------------------------
    loss: 2.157076 [    0/60000]
    loss: 2.153280 [ 6400/60000]
    loss: 2.090101 [12800/60000]
    loss: 2.108530 [19200/60000]
    loss: 2.052299 [25600/60000]
    loss: 1.985130 [32000/60000]
    loss: 2.030497 [38400/60000]
    loss: 1.938622 [44800/60000]
    loss: 1.952904 [51200/60000]
    loss: 1.871847 [57600/60000]
    Test Error:
     Accuracy: 54.4%, Avg loss: 1.870125

    Epoch 3
    -------------------------
    loss: 1.906839 [    0/60000]
    loss: 1.875080 [ 6400/60000]
    loss: 1.755180 [12800/60000]
    loss: 1.795452 [19200/60000]
    loss: 1.681108 [25600/60000]
    loss: 1.633004 [32000/60000]
    loss: 1.671840 [38400/60000]
    loss: 1.565841 [44800/60000]
    loss: 1.594627 [51200/60000]
    loss: 1.487918 [57600/60000]
    Test Error:
     Accuracy: 61.1%, Avg loss: 1.504041

    Epoch 4
    -------------------------
    loss: 1.571545 [    0/60000]
    loss: 1.538728 [ 6400/60000]
    loss: 1.387533 [12800/60000]
    loss: 1.457100 [19200/60000]
    loss: 1.343402 [25600/60000]
    loss: 1.337026 [32000/60000]
    loss: 1.367137 [38400/60000]
    loss: 1.283840 [44800/60000]
    loss: 1.317707 [51200/60000]
    loss: 1.224614 [57600/60000]
    Test Error:
     Accuracy: 63.6%, Avg loss: 1.245886

    Epoch 5
    -------------------------
    loss: 1.318024 [    0/60000]
    loss: 1.306073 [ 6400/60000]
    loss: 1.138275 [12800/60000]
    loss: 1.239986 [19200/60000]
    loss: 1.122743 [25600/60000]
    loss: 1.143268 [32000/60000]
    loss: 1.178220 [38400/60000]
    loss: 1.106306 [44800/60000]
    loss: 1.144859 [51200/60000]
    loss: 1.067169 [57600/60000]
    Test Error:
     Accuracy: 64.6%, Avg loss: 1.084245

    TensorDict training done! time:  8.6040 s
    Epoch 1
    -------------------------
    loss: 2.301039 [    0/60000]
    loss: 2.294237 [ 6400/60000]
    loss: 2.258980 [12800/60000]
    loss: 2.257846 [19200/60000]
    loss: 2.258250 [25600/60000]
    loss: 2.210056 [32000/60000]
    loss: 2.227523 [38400/60000]
    loss: 2.191480 [44800/60000]
    loss: 2.194841 [51200/60000]
    loss: 2.157639 [57600/60000]
    Test Error:
     Accuracy: 41.5%, Avg loss: 2.149883

    Epoch 2
    -------------------------
    loss: 2.162645 [    0/60000]
    loss: 2.161261 [ 6400/60000]
    loss: 2.088595 [12800/60000]
    loss: 2.110117 [19200/60000]
    loss: 2.075595 [25600/60000]
    loss: 1.998082 [32000/60000]
    loss: 2.034677 [38400/60000]
    loss: 1.953384 [44800/60000]
    loss: 1.964688 [51200/60000]
    loss: 1.889554 [57600/60000]
    Test Error:
     Accuracy: 57.2%, Avg loss: 1.882470

    Epoch 3
    -------------------------
    loss: 1.917056 [    0/60000]
    loss: 1.898173 [ 6400/60000]
    loss: 1.765086 [12800/60000]
    loss: 1.809607 [19200/60000]
    loss: 1.710262 [25600/60000]
    loss: 1.648017 [32000/60000]
    loss: 1.678248 [38400/60000]
    loss: 1.571657 [44800/60000]
    loss: 1.604326 [51200/60000]
    loss: 1.496464 [57600/60000]
    Test Error:
     Accuracy: 59.4%, Avg loss: 1.507853

    Epoch 4
    -------------------------
    loss: 1.575939 [    0/60000]
    loss: 1.552139 [ 6400/60000]
    loss: 1.383423 [12800/60000]
    loss: 1.462275 [19200/60000]
    loss: 1.353662 [25600/60000]
    loss: 1.336903 [32000/60000]
    loss: 1.360661 [38400/60000]
    loss: 1.275323 [44800/60000]
    loss: 1.318643 [51200/60000]
    loss: 1.217539 [57600/60000]
    Test Error:
     Accuracy: 62.6%, Avg loss: 1.239839

    Epoch 5
    -------------------------
    loss: 1.315363 [    0/60000]
    loss: 1.310467 [ 6400/60000]
    loss: 1.124107 [12800/60000]
    loss: 1.241619 [19200/60000]
    loss: 1.127850 [25600/60000]
    loss: 1.138872 [32000/60000]
    loss: 1.170922 [38400/60000]
    loss: 1.097719 [44800/60000]
    loss: 1.142735 [51200/60000]
    loss: 1.058941 [57600/60000]
    Test Error:
     Accuracy: 64.3%, Avg loss: 1.077324

    Training done! time:  33.9711 s

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 55.322 seconds)
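
The DataLoaders section passes the identity function as ``collate_fn``
because indexing a ``TensorDict`` with a batch of indices already returns a
batched object. The sketch below illustrates that mechanism in plain Python.
It is not ``TensorDict`` or ``DataLoader`` code: ``ToyBatchedDataset``,
``fetch_batch`` and ``iterate_batches`` are hypothetical names introduced for
illustration, and ``fetch_batch`` only approximates, as an assumption, how a
loader may prefer ``__getitems__`` over per-sample ``__getitem__`` calls.

.. code-block:: Python

    class ToyBatchedDataset:
        """Hypothetical stand-in for a TensorDict: plain lists keyed by name."""

        def __init__(self, images, targets):
            assert len(images) == len(targets)
            self.columns = {"images": list(images), "targets": list(targets)}

        def __len__(self):
            return len(self.columns["targets"])

        def __getitem__(self, index):
            # Single-sample access, as any map-style dataset provides.
            return {key: values[index] for key, values in self.columns.items()}

        def __getitems__(self, indices):
            # Batched access: one mapping whose values are already batched.
            return {
                key: [values[i] for i in indices]
                for key, values in self.columns.items()
            }


    def fetch_batch(dataset, indices, collate_fn):
        """Simplified, assumed version of a loader's fetch step."""
        if hasattr(dataset, "__getitems__"):
            # The dataset builds the whole batch itself.
            data = dataset.__getitems__(indices)
        else:
            # Fallback: a list of samples that collate_fn must merge.
            data = [dataset[i] for i in indices]
        return collate_fn(data)


    def iterate_batches(dataset, batch_size, collate_fn):
        """Yield consecutive batches without shuffling."""
        for start in range(0, len(dataset), batch_size):
            stop = min(start + batch_size, len(dataset))
            yield fetch_batch(dataset, list(range(start, stop)), collate_fn)


    dataset = ToyBatchedDataset(
        images=["img0", "img1", "img2", "img3", "img4"],
        targets=[9, 0, 0, 3, 0],
    )
    # The identity collate_fn is enough: each batch is already a single mapping.
    batches = list(iterate_batches(dataset, batch_size=2, collate_fn=lambda x: x))
    print(batches[0])

With a dataset that lacked ``__getitems__``, the same identity ``collate_fn``
would instead hand back a list of per-sample mappings, which is the case a
real collation function exists to handle.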