
torchvision.datasets

All datasets are subclasses of torch.utils.data.Dataset, i.e., they implement the __getitem__ and __len__ methods. Hence, they can all be passed to a torch.utils.data.DataLoader, which can load multiple samples in parallel using torch.multiprocessing workers. For example:

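import torch
import torchvision
import torchvision.transforms as transforms

# Batching needs equally sized tensors, so crop and convert each image first.
transform = transforms.Compose([transforms.RandomResizedCrop(224), transforms.ToTensor()])
imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/', transform=transform)
data_loader = torch.utils.data.DataLoader(imagenet_data, batch_size=4, shuffle=True,
                                          num_workers=4)  # e.g. args.nThreads from argparse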
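To make the batch_size and shuffle arguments concrete, here is a simplified pure-Python sketch of how a loader could turn dataset indices into batches. It is an illustration under stated assumptions, not the DataLoader implementation: worker processes, samplers and collation into tensors are not modeled, and `make_batches` and its `seed` argument are hypothetical names (only `drop_last` mirrors a real DataLoader option).

```python
import random


def make_batches(dataset_len, batch_size, shuffle, seed=0, drop_last=False):
    """Split the indices 0..dataset_len-1 into lists of at most batch_size indices."""
    indices = list(range(dataset_len))
    if shuffle:
        # A fresh permutation of the indices, as shuffle=True requests each epoch.
        random.Random(seed).shuffle(indices)
    batches = [indices[i:i + batch_size] for i in range(0, dataset_len, batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        # Discard the incomplete final batch.
        batches.pop()
    return batches


# 10 samples with batch_size=4 give batches of sizes 4, 4 and 2.
print([len(b) for b in make_batches(10, batch_size=4, shuffle=True)])  # [4, 4, 2]
```

Every index appears in exactly one batch per epoch; only the order changes when shuffling.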

The following datasets are available. They all share a very similar API; in particular, they all accept two common arguments, transform and target_transform, which transform the input and the target respectively.
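The shared transform/target_transform convention can be sketched with a toy map-style dataset. `ToyDataset` is a hypothetical class written for illustration, not part of torchvision, and it uses plain Python lists instead of images; a real dataset would subclass torch.utils.data.Dataset, but the protocol (`__getitem__` plus `__len__`) is the same.

```python
class ToyDataset:
    """Hypothetical map-style dataset following the torchvision calling convention."""

    def __init__(self, samples, targets, transform=None, target_transform=None):
        self.samples = samples
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        sample, target = self.samples[index], self.targets[index]
        # transform only ever sees the input; target_transform only sees the label.
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        return len(self.samples)


ds = ToyDataset([[1, 2], [3, 4]], [0, 1],
                transform=lambda x: [v * 10 for v in x],
                target_transform=str)
print(len(ds), ds[1])  # 2 ([30, 40], '1')
```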

MNIST

class torchvision.datasets.MNIST(root, train=True, transform=None, target_transform=None, download=False)

MNIST Dataset.

Parameters:
  • root (string) – Root directory of dataset where MNIST/processed/training.pt and MNIST/processed/test.pt exist.
  • train (bool, optional) – If True, creates dataset from training.pt, otherwise from test.pt.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

Fashion-MNIST

class torchvision.datasets.FashionMNIST(root, train=True, transform=None, target_transform=None, download=False)

Fashion-MNIST Dataset.

Parameters:
  • root (string) – Root directory of dataset where Fashion-MNIST/processed/training.pt and Fashion-MNIST/processed/test.pt exist.
  • train (bool, optional) – If True, creates dataset from training.pt, otherwise from test.pt.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

KMNIST

class torchvision.datasets.KMNIST(root, train=True, transform=None, target_transform=None, download=False)

Kuzushiji-MNIST Dataset.

Parameters:
  • root (string) – Root directory of dataset where KMNIST/processed/training.pt and KMNIST/processed/test.pt exist.
  • train (bool, optional) – If True, creates dataset from training.pt, otherwise from test.pt.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

EMNIST

class torchvision.datasets.EMNIST(root, split, **kwargs)

EMNIST Dataset.

Parameters:
  • root (string) – Root directory of dataset where EMNIST/processed/training.pt and EMNIST/processed/test.pt exist.
  • split (string) – The dataset has 6 different splits: byclass, bymerge, balanced, letters, digits and mnist. This argument specifies which one to use.
  • train (bool, optional) – If True, creates dataset from training.pt, otherwise from test.pt.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

FakeData

class torchvision.datasets.FakeData(size=1000, image_size=(3, 224, 224), num_classes=10, transform=None, target_transform=None, random_offset=0)

A fake dataset that returns randomly generated images as PIL images.

Parameters:
  • size (int, optional) – Size of the dataset. Default: 1000 images
  • image_size (tuple, optional) – Size of the returned images. Default: (3, 224, 224)
  • num_classes (int, optional) – Number of classes in the dataset. Default: 10
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
  • random_offset (int) – Offsets the index-based random seed used to generate each image. Default: 0
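Index-based seeding means that reading the same index twice yields the same image, while changing random_offset yields a different but equally reproducible dataset. The sketch below shows that idea with Python's `random` module and nested lists standing in for image tensors; it is an assumption-level illustration, not torchvision's implementation, `fake_sample` is a hypothetical name, and its tiny default image_size only keeps the lists small.

```python
import random


def fake_sample(index, image_size=(3, 4, 4), num_classes=10, random_offset=0):
    """Return a deterministic (image, target) pair for the given index."""
    # A private generator seeded from the index leaves the global RNG untouched.
    rng = random.Random(index + random_offset)
    channels, height, width = image_size
    image = [[[rng.gauss(0.0, 1.0) for _ in range(width)]
              for _ in range(height)]
             for _ in range(channels)]
    target = rng.randrange(num_classes)
    return image, target


print(fake_sample(5) == fake_sample(5))                   # True: same index, same sample
print(fake_sample(5, random_offset=1) == fake_sample(6))  # True: both use seed 6
```

Because the seed is simply index + random_offset, shifting the offset by k is equivalent to reading the sample k positions further along.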

COCO

Note

These require the COCO API (pycocotools) to be installed.

Captions

class torchvision.datasets.CocoCaptions(root, annFile, transform=None, target_transform=None)

MS Coco Captions Dataset.

Parameters:
  • root (string) – Root directory where images are downloaded to.
  • annFile (string) – Path to json annotation file.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.ToTensor
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

Example

import torchvision.datasets as dset
import torchvision.transforms as transforms
cap = dset.CocoCaptions(root = 'dir where images are',
                        annFile = 'json annotation file',
                        transform=transforms.ToTensor())

print('Number of samples: ', len(cap))
img, target = cap[3] # load 4th sample

print("Image Size: ", img.size())
print(target)

Output:

Number of samples: 82783
Image Size: (3L, 427L, 640L)
[u'A plane emitting smoke stream flying over a mountain.',
u'A plane darts across a bright blue sky behind a mountain covered in snow',
u'A plane leaves a contrail above the snowy mountain top.',
u'A mountain that has a plane flying overheard in the distance.',
u'A mountain view with a plume of smoke in the background']

__getitem__(index)
Parameters: index (int) – Index
Returns: Tuple (image, target). target is a list of captions for the image.
Return type: tuple
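Because each image in the captions annotation file has several captions, the target is a list. The sketch below assumes the standard COCO captions JSON layout (an `images` list with `id`/`file_name` and an `annotations` list with `image_id`/`caption`) and groups captions per image with only the standard library. The real dataset delegates this to the COCO API; `group_captions` is a hypothetical helper.

```python
import json
from collections import defaultdict


def group_captions(annotation_json):
    """Map each image file name to the list of its captions."""
    data = json.loads(annotation_json)
    captions = defaultdict(list)
    for ann in data["annotations"]:
        captions[ann["image_id"]].append(ann["caption"])
    # Images without any caption still appear, with an empty list.
    return {img["file_name"]: captions[img["id"]] for img in data["images"]}


example = json.dumps({
    "images": [{"id": 1, "file_name": "plane.jpg"}],
    "annotations": [
        {"image_id": 1, "caption": "A plane flying over a mountain."},
        {"image_id": 1, "caption": "A contrail above the snowy peak."},
    ],
})
print(group_captions(example))
# {'plane.jpg': ['A plane flying over a mountain.', 'A contrail above the snowy peak.']}
```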

Detection

class torchvision.datasets.CocoDetection(root, annFile, transform=None, target_transform=None)

MS Coco Detection Dataset.

Parameters:
  • root (string) – Root directory where images are downloaded to.
  • annFile (string) – Path to json annotation file.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.ToTensor
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
__getitem__(index)
Parameters: index (int) – Index
Returns: Tuple (image, target). target is the object returned by coco.loadAnns.
Return type: tuple
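Since target is the list of annotation dictionaries returned by coco.loadAnns, a target_transform is a natural place to reduce it to what a model needs. The sketch assumes the usual COCO detection fields, `bbox` as `[x, y, width, height]` and `category_id`; `to_boxes` is a hypothetical helper, and converting to corner format is just one possible choice.

```python
def to_boxes(annotations):
    """Convert COCO-style annotations into (corner boxes, category ids)."""
    boxes, labels = [], []
    for ann in annotations:
        x, y, w, h = ann["bbox"]
        # COCO stores [x, y, width, height]; many detectors expect [x1, y1, x2, y2].
        boxes.append([x, y, x + w, y + h])
        labels.append(ann["category_id"])
    return boxes, labels


anns = [{"bbox": [10.0, 20.0, 30.0, 40.0], "category_id": 3}]
print(to_boxes(anns))  # ([[10.0, 20.0, 40.0, 60.0]], [3])
```

Such a function could then be passed as `target_transform=to_boxes` when constructing CocoDetection.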

LSUN

class torchvision.datasets.LSUN(root, classes='train', transform=None, target_transform=None)

LSUN dataset.

Parameters:
  • root (string) – Root directory for the database files.
  • classes (string or list) – One of {‘train’, ‘val’, ‘test’} or a list of categories to load, e.g. [‘bedroom_train’, ‘church_train’].
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
__getitem__(index)
Parameters: index (int) – Index
Returns: Tuple (image, target), where target is the index of the target category.
Return type: tuple
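The classes argument either names a whole split or lists individual category databases, and the target is the position of the sample's database in the resolved list. The sketch below illustrates that resolution step using the ten LSUN scene categories; the exact expansion rule (appending `_train`/`_val`, keeping `test` as a single database) is an assumption about the implementation rather than documented behavior, and `resolve_lsun_classes` is hypothetical.

```python
CATEGORIES = ["bedroom", "bridge", "church_outdoor", "classroom",
              "conference_room", "dining_room", "kitchen",
              "living_room", "restaurant", "tower"]


def resolve_lsun_classes(classes):
    """Return the database names to load; a sample's target is its index here."""
    if classes == "test":
        return ["test"]
    if classes in ("train", "val"):
        # Assumed rule: one database per category, suffixed with the split name.
        return [c + "_" + classes for c in CATEGORIES]
    if isinstance(classes, str):
        raise ValueError("unknown split: " + classes)
    return list(classes)


dbs = resolve_lsun_classes(["bedroom_train", "tower_train"])
print(dbs.index("tower_train"))  # 1, the target of every tower image
```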

ImageFolder

class torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>)

A generic data loader where the images are arranged in this way:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

Parameters:
  • root (string) – Root directory path.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
  • loader (callable, optional) – A function to load an image given its path.
__getitem__(index)
Parameters: index (int) – Index
Returns: (sample, target), where target is the class index of the target class.
Return type: tuple
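The class index in the target comes from the folder layout. The sketch below shows one plausible derivation, assuming class names are the sorted names of the immediate subdirectories of root, so that cat gets 0 and dog gets 1 in the layout above. ImageFolder exposes a comparable mapping as its `class_to_idx` attribute, but `find_classes` here is a standalone illustration, not the library function.

```python
import os
import tempfile


def find_classes(root):
    """Return (sorted class names, mapping from class name to index)."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return classes, {name: idx for idx, name in enumerate(classes)}


with tempfile.TemporaryDirectory() as root:
    for name in ("dog", "cat"):
        os.makedirs(os.path.join(root, name))
    classes, class_to_idx = find_classes(root)
    print(classes, class_to_idx)  # ['cat', 'dog'] {'cat': 0, 'dog': 1}
```

Sorting makes the mapping independent of the order in which the file system lists directories.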

DatasetFolder

class torchvision.datasets.DatasetFolder(root, loader, extensions, transform=None, target_transform=None)

A generic data loader where the samples are arranged in this way:

root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/xxz.ext

root/class_y/123.ext
root/class_y/nsdf3.ext
root/class_y/asd932_.ext

Parameters:
  • root (string) – Root directory path.
  • loader (callable) – A function to load a sample given its path.
  • extensions (list[string]) – A list of allowed extensions.
  • transform (callable, optional) – A function/transform that takes in a sample and returns a transformed version. E.g., transforms.RandomCrop for images.
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
__getitem__(index)
Parameters: index (int) – Index
Returns: (sample, target), where target is the class index of the target class.
Return type: tuple
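Only files whose extension appears in extensions become samples. The sketch below walks each class folder, keeps matching files (compared case-insensitively, which is an assumption about the library's behavior), and pairs each path with its class index. `make_samples` is a hypothetical helper, and the `.txt` files only stand in for real data.

```python
import os
import tempfile


def make_samples(root, class_to_idx, extensions):
    """Collect (path, class_index) pairs for files with an allowed extension."""
    allowed = tuple(ext.lower() for ext in extensions)
    samples = []
    for class_name in sorted(class_to_idx):
        class_dir = os.path.join(root, class_name)
        for dirpath, _, filenames in sorted(os.walk(class_dir)):
            for fname in sorted(filenames):
                if fname.lower().endswith(allowed):
                    samples.append((os.path.join(dirpath, fname), class_to_idx[class_name]))
    return samples


with tempfile.TemporaryDirectory() as root:
    files = [("class_x", "a.txt"), ("class_x", "skip.bin"), ("class_y", "B.TXT")]
    for class_name, fname in files:
        os.makedirs(os.path.join(root, class_name), exist_ok=True)
        open(os.path.join(root, class_name, fname), "w").close()
    samples = make_samples(root, {"class_x": 0, "class_y": 1}, [".txt"])
    print([(os.path.basename(p), t) for p, t in samples])  # [('a.txt', 0), ('B.TXT', 1)]
```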

Imagenet-12

This should simply be implemented with an ImageFolder dataset. The data is preprocessed as described here

Here is an example.

CIFAR

class torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)

CIFAR10 Dataset.

Parameters:
  • root (string) – Root directory of dataset where directory cifar-10-batches-py exists or will be saved to if download is set to True.
  • train (bool, optional) – If True, creates dataset from training set, otherwise creates from test set.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
__getitem__(index)
Parameters: index (int) – Index
Returns: (image, target), where target is the index of the target class.
Return type: tuple
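Since the CIFAR target is a plain class index, a target_transform can re-encode it when a loss expects another format. As an illustration, the hypothetical helper below produces a one-hot list; the class-name tuple follows the standard CIFAR-10 label order and is included only to make the printed example readable.

```python
CIFAR10_CLASSES = ("airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck")


def one_hot(target, num_classes=10):
    """Turn an integer class index into a one-hot list of floats."""
    if not 0 <= target < num_classes:
        raise ValueError("target out of range")
    encoded = [0.0] * num_classes
    encoded[target] = 1.0
    return encoded


print(CIFAR10_CLASSES[3], one_hot(3))
# cat [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

Loss functions such as nn.CrossEntropyLoss accept the plain index directly, so an encoding like this is only worth adding when a different loss needs it.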
class torchvision.datasets.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)

CIFAR100 Dataset.

This is a subclass of the CIFAR10 Dataset.

STL10

class torchvision.datasets.STL10(root, split='train', transform=None, target_transform=None, download=False)

STL10 Dataset.

Parameters:
  • root (string) – Root directory of dataset where directory stl10_binary exists.
  • split (string) – One of {‘train’, ‘test’, ‘unlabeled’, ‘train+unlabeled’}. The corresponding subset is loaded.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
__getitem__(index)
Parameters: index (int) – Index
Returns: (image, target), where target is the index of the target class.
Return type: tuple

SVHN

class torchvision.datasets.SVHN(root, split='train', transform=None, target_transform=None, download=False)

SVHN Dataset. Note: the original SVHN dataset assigns the label 10 to the digit 0. This Dataset instead assigns the label 0 to the digit 0, to be compatible with PyTorch loss functions, which expect class labels in the range [0, C-1].

Parameters:
  • root (string) – Root directory of dataset where directory SVHN exists.
  • split (string) – One of {‘train’, ‘test’, ‘extra’}. The corresponding subset is loaded; ‘extra’ is the extra training set.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
__getitem__(index)
Parameters: index (int) – Index
Returns: (image, target), where target is the index of the target class.
Return type: tuple
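The SVHN relabeling amounts to mapping the raw label 10 to 0 while leaving 1 to 9 unchanged. The dataset already applies it, so the snippet below is only a sketch of the rule, useful if you read the raw SVHN labels yourself; `svhn_to_pytorch_label` is a hypothetical name.

```python
def svhn_to_pytorch_label(raw_label):
    """Map SVHN's raw labels (1-10, with 10 meaning digit 0) to 0-9."""
    if not 1 <= raw_label <= 10:
        raise ValueError("SVHN labels are in the range 1-10")
    # 10 % 10 == 0, while 1..9 are unchanged.
    return raw_label % 10


print([svhn_to_pytorch_label(y) for y in (10, 1, 9)])  # [0, 1, 9]
```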

PhotoTour

class torchvision.datasets.PhotoTour(root, name, train=True, transform=None, download=False)

Learning Local Image Descriptors Data Dataset.

Parameters:
  • root (string) – Root directory where images are.
  • name (string) – Name of the dataset to load.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
__getitem__(index)
Parameters: index (int) – Index
Returns: (data1, data2, matches)
Return type: tuple

SBU

class torchvision.datasets.SBU(root, transform=None, target_transform=None, download=True)

SBU Captioned Photo Dataset.

Parameters:
  • root (string) – Root directory of dataset where tarball SBUCaptionedPhotoDataset.tar.gz exists.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g, transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
__getitem__(index)
Parameters: index (int) – Index
Returns: (image, target), where target is a caption for the photo.
Return type: tuple

Flickr

class torchvision.datasets.Flickr8k(root, ann_file, transform=None, target_transform=None)

Flickr8k Entities Dataset.

Parameters:
  • root (string) – Root directory where images are downloaded to.
  • ann_file (string) – Path to annotation file.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g, transforms.ToTensor
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
__getitem__(index)
Parameters: index (int) – Index
Returns: Tuple (image, target). target is a list of captions for the image.
Return type: tuple
class torchvision.datasets.Flickr30k(root, ann_file, transform=None, target_transform=None)

Flickr30k Entities Dataset.

Parameters:
  • root (string) – Root directory where images are downloaded to.
  • ann_file (string) – Path to annotation file.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g, transforms.ToTensor
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
__getitem__(index)
Parameters: index (int) – Index
Returns: Tuple (image, target). target is a list of captions for the image.
Return type: tuple

VOC

class torchvision.datasets.VOCSegmentation(root, year='2012', image_set='train', download=False, transform=None, target_transform=None)

Pascal VOC Segmentation Dataset.

Parameters:
  • root (string) – Root directory of the VOC Dataset.
  • year (string, optional) – The dataset year, supports years 2007 to 2012.
  • image_set (string, optional) – Select the image_set to use: train, trainval or val.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
__getitem__(index)
Parameters: index (int) – Index
Returns: (image, target), where target is the image segmentation.
Return type: tuple
class torchvision.datasets.VOCDetection(root, year='2012', image_set='train', download=False, transform=None, target_transform=None)

Pascal VOC Detection Dataset.

Parameters:
  • root (string) – Root directory of the VOC Dataset.
  • year (string, optional) – The dataset year, supports years 2007 to 2012.
  • image_set (string, optional) – Select the image_set to use: train, trainval or val.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
__getitem__(index)
Parameters: index (int) – Index
Returns: (image, target), where target is a dictionary of the XML tree.
Return type: tuple
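The dictionary target mirrors the structure of the VOC annotation XML. The sketch below is a simplified, standard-library version of such a conversion, written for illustration: `xml_to_dict` is hypothetical and may differ in detail from the library's own parser. Leaf elements become strings, and `object` entries are always collected into a list because an image can contain several objects.

```python
import xml.etree.ElementTree as ET


def xml_to_dict(node):
    """Recursively convert an ElementTree node into nested dicts."""
    children = list(node)
    if not children:
        return {node.tag: (node.text or "").strip()}
    content = {}
    for child in children:
        (key, value), = xml_to_dict(child).items()
        if key == "object":
            # Several <object> elements per image, so always keep a list.
            content.setdefault("object", []).append(value)
        else:
            content[key] = value
    return {node.tag: content}


sample_xml = """<annotation><filename>000001.jpg</filename>
<object><name>dog</name><bndbox><xmin>48</xmin><ymin>240</ymin>
<xmax>195</xmax><ymax>371</ymax></bndbox></object></annotation>"""
target = xml_to_dict(ET.fromstring(sample_xml))
print(target["annotation"]["object"][0]["name"])  # dog
```

Note that coordinates stay strings here; converting them to numbers would be a job for a target_transform.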

Cityscapes

Note

Requires the Cityscapes dataset to be downloaded.

class torchvision.datasets.Cityscapes(root, split='train', mode='fine', target_type='instance', transform=None, target_transform=None)

Cityscapes Dataset.

Parameters:
  • root (string) – Root directory of dataset where directory leftImg8bit and gtFine or gtCoarse are located.
  • split (string, optional) – The image split to use: train, test or val if mode="fine", otherwise train, train_extra or val.
  • mode (string, optional) – The quality mode to use, fine or coarse.
  • target_type (string or list, optional) – Type of target to use: instance, semantic, polygon or color. Can also be a list, in which case a tuple with all specified target types is returned.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g, transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

Examples

Get semantic segmentation target

from torchvision.datasets import Cityscapes

dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                     target_type='semantic')

img, smnt = dataset[0]

Get multiple targets

dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                     target_type=['instance', 'color', 'polygon'])

img, (inst, col, poly) = dataset[0]

Validate on the “coarse” set

dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
                     target_type='semantic')

img, smnt = dataset[0]

__getitem__(index)
Parameters: index (int) – Index
Returns: (image, target), where target is a tuple of all target types if target_type is a list with more than one item. Otherwise, target is a JSON object if target_type="polygon", else the image segmentation.
Return type: tuple
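The return rule for Cityscapes (a tuple for several target types, a single object otherwise) can be sketched as follows. `pack_targets` and `fake_load` are hypothetical helpers, and the loaded targets are represented by strings; in the dataset they would be images or a parsed JSON polygon object.

```python
VALID_TARGET_TYPES = ("instance", "semantic", "polygon", "color")


def pack_targets(target_type, load):
    """Return one target, or a tuple of targets when several types are requested."""
    types = [target_type] if isinstance(target_type, str) else list(target_type)
    for t in types:
        if t not in VALID_TARGET_TYPES:
            raise ValueError("unknown target type: " + t)
    targets = [load(t) for t in types]
    # A single-item list behaves like a plain string: no tuple is built.
    return tuple(targets) if len(targets) > 1 else targets[0]


def fake_load(target_type):
    """Stand-in for reading the real target file from disk."""
    return target_type + "-target"


print(pack_targets("semantic", fake_load))             # semantic-target
print(pack_targets(["instance", "color"], fake_load))  # ('instance-target', 'color-target')
```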
