torchvision.datasets

All datasets are subclasses of torch.utils.data.Dataset, i.e., they implement the __getitem__ and __len__ methods. Hence, they can all be passed to a torch.utils.data.DataLoader, which can load multiple samples in parallel using torch.multiprocessing workers. For example:

import torch
import torchvision

imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=4)  # worker processes loading samples in parallel

The following datasets are available:

All the datasets have a similar API. They all have two common arguments, transform and target_transform, which transform the input and the target respectively.
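The shared interface described above can be illustrated without any dependencies. The following is a minimal sketch, not torchvision code: SquaresDataset is a hypothetical toy class whose samples are plain integers rather than PIL images. It only shows how __len__, __getitem__, transform and target_transform fit together.

```python
class SquaresDataset:
    """Hypothetical toy dataset: sample i is the number i, its target is i squared."""

    def __init__(self, size, transform=None, target_transform=None):
        self.size = size
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        if not 0 <= index < self.size:
            raise IndexError(index)
        sample, target = index, index * index
        # Each callable is applied only if it was supplied.
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target


dataset = SquaresDataset(5, transform=float, target_transform=str)
sample, target = dataset[3]
```

Any object with this shape can be indexed and measured with len(), which is what a data loader relies on.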

MNIST

class torchvision.datasets.MNIST(root, train=True, transform=None, target_transform=None, download=False)[source]

MNIST Dataset.

Parameters
  • root (string) – Root directory of dataset where MNIST/processed/training.pt and MNIST/processed/test.pt exist.

  • train (bool, optional) – If True, creates dataset from training.pt, otherwise from test.pt.

  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

Fashion-MNIST

class torchvision.datasets.FashionMNIST(root, train=True, transform=None, target_transform=None, download=False)[source]

Fashion-MNIST Dataset.

Parameters
  • root (string) – Root directory of dataset where Fashion-MNIST/processed/training.pt and Fashion-MNIST/processed/test.pt exist.

  • train (bool, optional) – If True, creates dataset from training.pt, otherwise from test.pt.

  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

KMNIST

class torchvision.datasets.KMNIST(root, train=True, transform=None, target_transform=None, download=False)[source]

Kuzushiji-MNIST Dataset.

Parameters
  • root (string) – Root directory of dataset where KMNIST/processed/training.pt and KMNIST/processed/test.pt exist.

  • train (bool, optional) – If True, creates dataset from training.pt, otherwise from test.pt.

  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

EMNIST

class torchvision.datasets.EMNIST(root, split, **kwargs)[source]

EMNIST Dataset.

Parameters
  • root (string) – Root directory of dataset where EMNIST/processed/training.pt and EMNIST/processed/test.pt exist.

  • split (string) – The dataset has 6 different splits: byclass, bymerge, balanced, letters, digits and mnist. This argument specifies which one to use.

  • train (bool, optional) – If True, creates dataset from training.pt, otherwise from test.pt.

  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

FakeData

class torchvision.datasets.FakeData(size=1000, image_size=(3, 224, 224), num_classes=10, transform=None, target_transform=None, random_offset=0)[source]

A fake dataset that returns randomly generated images as PIL images.

Parameters
  • size (int, optional) – Size of the dataset. Default: 1000 images

  • image_size (tuple, optional) – Size of the returned images. Default: (3, 224, 224)

  • num_classes (int, optional) – Number of classes in the dataset. Default: 10

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

  • random_offset (int) – Offsets the index-based random seed used to generate each image. Default: 0
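To illustrate what an index-based random seed with an offset means, here is a minimal sketch using Python's standard random module. It is an illustration of the idea only, not torchvision's implementation: fake_pixels and num_pixels are hypothetical names, and the sketch assumes the seed is simply index + random_offset.

```python
import random


def fake_pixels(index, random_offset=0, num_pixels=4):
    """Generate reproducible pseudo-random pixel values for a given index."""
    # Assumption: the seed is derived as index + random_offset.
    rng = random.Random(index + random_offset)
    return [rng.randrange(256) for _ in range(num_pixels)]


same_twice = fake_pixels(2) == fake_pixels(2)
shifted = fake_pixels(2, random_offset=1) == fake_pixels(3)
```

Under this assumption, the same index always yields the same data, and an offset of 1 makes index 2 produce what index 3 would produce without an offset.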

COCO

Note

These require the COCO API to be installed

Captions

class torchvision.datasets.CocoCaptions(root, annFile, transform=None, target_transform=None, transforms=None)[source]

MS Coco Captions Dataset.

Parameters
  • root (string) – Root directory where images are downloaded to.

  • annFile (string) – Path to json annotation file.

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.ToTensor

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

Example

import torchvision.datasets as dset
import torchvision.transforms as transforms
cap = dset.CocoCaptions(root = 'dir where images are',
                        annFile = 'json annotation file',
                        transform=transforms.ToTensor())

print('Number of samples: ', len(cap))
img, target = cap[3] # load 4th sample

print("Image Size: ", img.size())
print(target)

Output:

Number of samples: 82783
Image Size: (3L, 427L, 640L)
[u'A plane emitting smoke stream flying over a mountain.',
u'A plane darts across a bright blue sky behind a mountain covered in snow',
u'A plane leaves a contrail above the snowy mountain top.',
u'A mountain that has a plane flying overheard in the distance.',
u'A mountain view with a plume of smoke in the background']
__getitem__(index)[source]
Parameters

index (int) – Index

Returns

Tuple (image, target). target is a list of captions for the image.

Return type

tuple

Detection

class torchvision.datasets.CocoDetection(root, annFile, transform=None, target_transform=None, transforms=None)[source]

MS Coco Detection Dataset.

Parameters
  • root (string) – Root directory where images are downloaded to.

  • annFile (string) – Path to json annotation file.

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.ToTensor

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

__getitem__(index)[source]
Parameters

index (int) – Index

Returns

Tuple (image, target). target is the object returned by coco.loadAnns.

Return type

tuple

LSUN

class torchvision.datasets.LSUN(root, classes='train', transform=None, target_transform=None)[source]

LSUN dataset.

Parameters
  • root (string) – Root directory for the database files.

  • classes (string or list) – One of {‘train’, ‘val’, ‘test’} or a list of categories to load, e.g. [‘bedroom_train’, ‘church_train’].

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

__getitem__(index)[source]
Parameters

index (int) – Index

Returns

Tuple (image, target) where target is the index of the target category.

Return type

tuple

ImageFolder

class torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)[source]

A generic data loader where the images are arranged in this way:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
Parameters
  • root (string) – Root directory path.

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

  • loader (callable, optional) – A function to load an image given its path.

  • is_valid_file (callable, optional) – A function that takes the path of an image file and checks whether it is a valid file (used to check for corrupt files).

__getitem__(index)
Parameters

index (int) – Index

Returns

(sample, target) where target is class_index of the target class.

Return type

tuple
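The directory layout above determines both the samples and their class indices. The following is a minimal sketch of that mapping using only the standard library. It assumes class indices are assigned in sorted order of the subdirectory names, and find_classes and list_samples are illustrative helpers written for this example, not the library's API.

```python
import os
import tempfile


def find_classes(root):
    """Map each subdirectory name of root to an index, in sorted order."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return {name: index for index, name in enumerate(classes)}


def list_samples(root, class_to_idx):
    """Return (path, class_index) pairs for every file under each class folder."""
    samples = []
    for name in sorted(class_to_idx):
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            samples.append((os.path.join(class_dir, fname), class_to_idx[name]))
    return samples


with tempfile.TemporaryDirectory() as root:
    # Recreate a small version of the layout with empty placeholder files.
    for class_name, fname in [("dog", "xxx.png"), ("dog", "xxy.png"), ("cat", "123.png")]:
        os.makedirs(os.path.join(root, class_name), exist_ok=True)
        open(os.path.join(root, class_name, fname), "wb").close()
    class_to_idx = find_classes(root)
    samples = list_samples(root, class_to_idx)
    targets = [target for _, target in samples]
```

With this layout, cat is assigned index 0 and dog index 1, so the target returned for each sample is the index of the folder it lives in.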

DatasetFolder

class torchvision.datasets.DatasetFolder(root, loader, extensions=None, transform=None, target_transform=None, is_valid_file=None)[source]

A generic data loader where the samples are arranged in this way:

root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/xxz.ext

root/class_y/123.ext
root/class_y/nsdf3.ext
root/class_y/asd932_.ext
Parameters
  • root (string) – Root directory path.

  • loader (callable) – A function to load a sample given its path.

  • extensions (tuple[string]) – A tuple of allowed extensions. extensions and is_valid_file should not both be passed.

  • transform (callable, optional) – A function/transform that takes in a sample and returns a transformed version. E.g., transforms.RandomCrop for images.

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

  • is_valid_file (callable, optional) – A function that takes the path of a file and checks whether it is a valid file (used to check for corrupt files). extensions and is_valid_file should not both be passed.

__getitem__(index)[source]
Parameters

index (int) – Index

Returns

(sample, target) where target is class_index of the target class.

Return type

tuple
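The rule that extensions and is_valid_file are alternatives can be sketched as a small helper. This is an illustration under stated assumptions, not the library's code: make_file_filter is a hypothetical name, the extension match is assumed to be case-insensitive, and treating "neither argument given" as an error is an assumption of this sketch.

```python
def make_file_filter(extensions=None, is_valid_file=None):
    """Return a predicate deciding whether a path should be kept as a sample."""
    if extensions is not None and is_valid_file is not None:
        raise ValueError("pass either extensions or is_valid_file, not both")
    if extensions is None and is_valid_file is None:
        # Assumption of this sketch: without any rule, files cannot be filtered.
        raise ValueError("one of extensions or is_valid_file is required")
    if is_valid_file is not None:
        return is_valid_file
    allowed = tuple(ext.lower() for ext in extensions)
    return lambda path: path.lower().endswith(allowed)


keep = make_file_filter(extensions=(".png", ".jpg"))
paths = ["root/class_x/xxx.png", "root/class_x/notes.txt", "root/class_y/123.JPG"]
kept = [p for p in paths if keep(p)]
```

A custom is_valid_file callable is useful when the extension alone is not enough, for example to skip files that fail to open.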

ImageNet

class torchvision.datasets.ImageNet(root, split='train', download=False, **kwargs)[source]

ImageNet 2012 Classification Dataset.

Parameters
  • root (string) – Root directory of the ImageNet Dataset.

  • split (string, optional) – The dataset split; supports train or val.

  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

  • loader – A function to load an image given its path.

Note

This requires scipy to be installed

CIFAR

class torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)[source]

CIFAR10 Dataset.

Parameters
  • root (string) – Root directory of dataset where directory cifar-10-batches-py exists or will be saved to if download is set to True.

  • train (bool, optional) – If True, creates dataset from training set, otherwise creates from test set.

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.

__getitem__(index)[source]
Parameters

index (int) – Index

Returns

(image, target) where target is index of the target class.

Return type

tuple

class torchvision.datasets.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)[source]

CIFAR100 Dataset.

This is a subclass of the CIFAR10 Dataset.

STL10

class torchvision.datasets.STL10(root, split='train', transform=None, target_transform=None, download=False)[source]

STL10 Dataset.

Parameters
  • root (string) – Root directory of dataset where directory stl10_binary exists.

  • split (string) – One of {‘train’, ‘test’, ‘unlabeled’, ‘train+unlabeled’}. Accordingly dataset is selected.

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.

__getitem__(index)[source]
Parameters

index (int) – Index

Returns

(image, target) where target is index of the target class.

Return type

tuple

SVHN

class torchvision.datasets.SVHN(root, split='train', transform=None, target_transform=None, download=False)[source]

SVHN Dataset. Note: The SVHN dataset assigns the label 10 to the digit 0. However, in this Dataset, we assign the label 0 to the digit 0 to be compatible with PyTorch loss functions which expect the class labels to be in the range [0, C-1]

Parameters
  • root (string) – Root directory of dataset where directory SVHN exists.

  • split (string) – One of {‘train’, ‘test’, ‘extra’}. Accordingly dataset is selected. ‘extra’ is Extra training set.

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.

__getitem__(index)[source]
Parameters

index (int) – Index

Returns

(image, target) where target is index of the target class.

Return type

tuple
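The label convention described in the SVHN note (raw label 10 stands for the digit 0 and is remapped to 0) can be shown with a short sketch. The helper name remap_svhn_labels and the sample labels are made up for illustration, and this is not the library's own code.

```python
def remap_svhn_labels(raw_labels):
    """Map the raw SVHN label 10 (used for digit 0) to 0; labels 1..9 are unchanged."""
    return [0 if label == 10 else label for label in raw_labels]


raw = [10, 1, 9, 10, 5]  # made-up raw labels in the original 1..10 scheme
labels = remap_svhn_labels(raw)
num_classes = 10
in_range = all(0 <= label <= num_classes - 1 for label in labels)
```

After the remapping, every label lies in [0, C-1] with C = 10, which is the range the note says PyTorch loss functions expect.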

PhotoTour

class torchvision.datasets.PhotoTour(root, name, train=True, transform=None, download=False)[source]

Learning Local Image Descriptors Data Dataset.

Parameters
  • root (string) – Root directory where images are.

  • name (string) – Name of the dataset to load.

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version.

  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.

__getitem__(index)[source]
Parameters

index (int) – Index

Returns

(data1, data2, matches)

Return type

tuple

SBU

class torchvision.datasets.SBU(root, transform=None, target_transform=None, download=True)[source]

SBU Captioned Photo Dataset.

Parameters
  • root (string) – Root directory of dataset where tarball SBUCaptionedPhotoDataset.tar.gz exists.

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.

__getitem__(index)[source]
Parameters

index (int) – Index

Returns

(image, target) where target is a caption for the photo.

Return type

tuple

Flickr

class torchvision.datasets.Flickr8k(root, ann_file, transform=None, target_transform=None)[source]

Flickr8k Entities Dataset.

Parameters
  • root (string) – Root directory where images are downloaded to.

  • ann_file (string) – Path to annotation file.

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.ToTensor

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

__getitem__(index)[source]
Parameters

index (int) – Index

Returns

Tuple (image, target). target is a list of captions for the image.

Return type

tuple

class torchvision.datasets.Flickr30k(root, ann_file, transform=None, target_transform=None)[source]

Flickr30k Entities Dataset.

Parameters
  • root (string) – Root directory where images are downloaded to.

  • ann_file (string) – Path to annotation file.

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.ToTensor

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

__getitem__(index)[source]
Parameters

index (int) – Index

Returns

Tuple (image, target). target is a list of captions for the image.

Return type

tuple

VOC

class torchvision.datasets.VOCSegmentation(root, year='2012', image_set='train', download=False, transform=None, target_transform=None, transforms=None)[source]

Pascal VOC Segmentation Dataset.

Parameters
  • root (string) – Root directory of the VOC Dataset.

  • year (string, optional) – The dataset year, supports years 2007 to 2012.

  • image_set (string, optional) – Select the image_set to use: train, trainval or val.

  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

__getitem__(index)[source]
Parameters

index (int) – Index

Returns

(image, target) where target is the image segmentation.

Return type

tuple

class torchvision.datasets.VOCDetection(root, year='2012', image_set='train', download=False, transform=None, target_transform=None, transforms=None)[source]

Pascal VOC Detection Dataset.

Parameters
  • root (string) – Root directory of the VOC Dataset.

  • year (string, optional) – The dataset year, supports years 2007 to 2012.

  • image_set (string, optional) – Select the image_set to use: train, trainval or val.

  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

__getitem__(index)[source]
Parameters

index (int) – Index

Returns

(image, target) where target is a dictionary of the XML tree.

Return type

tuple
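Since the detection target is described as a dictionary of the XML tree, it may help to see how an annotation file can turn into nested dictionaries. The sketch below uses the standard xml.etree.ElementTree module on a made-up annotation. xml_to_dict is an illustrative helper, and the exact dictionary layout produced by the library may differ.

```python
import xml.etree.ElementTree as ET

# Made-up annotation in the general style of a VOC XML file.
ANNOTATION = """
<annotation>
  <filename>000001.jpg</filename>
  <object>
    <name>dog</name>
    <bndbox><xmin>48</xmin><ymin>240</ymin><xmax>195</xmax><ymax>371</ymax></bndbox>
  </object>
  <object>
    <name>person</name>
    <bndbox><xmin>8</xmin><ymin>12</ymin><xmax>352</xmax><ymax>498</ymax></bndbox>
  </object>
</annotation>
"""


def xml_to_dict(node):
    """Convert an element into nested dicts; repeated child tags become lists."""
    children = list(node)
    if not children:
        return (node.text or "").strip()
    result = {}
    for child in children:
        value = xml_to_dict(child)
        if child.tag in result:
            if not isinstance(result[child.tag], list):
                result[child.tag] = [result[child.tag]]
            result[child.tag].append(value)
        else:
            result[child.tag] = value
    return result


root = ET.fromstring(ANNOTATION)
target = {root.tag: xml_to_dict(root)}
boxes = [
    (obj["name"], tuple(int(obj["bndbox"][key]) for key in ("xmin", "ymin", "xmax", "ymax")))
    for obj in target["annotation"]["object"]
]
```

Note that in this sketch an annotation with a single object yields a dict rather than a list under "object", so code consuming such a target should handle both shapes.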

Cityscapes

Note

Requires the Cityscapes dataset to be downloaded.

class torchvision.datasets.Cityscapes(root, split='train', mode='fine', target_type='instance', transform=None, target_transform=None)[source]

Cityscapes Dataset.

Parameters
  • root (string) – Root directory of dataset where directory leftImg8bit and gtFine or gtCoarse are located.

  • split (string, optional) – The image split to use: train, test or val if mode='fine', otherwise train, train_extra or val.

  • mode (string, optional) – The quality mode to use: fine or coarse.

  • target_type (string or list, optional) – Type of target to use, instance, semantic, polygon or color. Can also be a list to output a tuple with all specified target types.

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

Examples

Get semantic segmentation target

from torchvision.datasets import Cityscapes

dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                     target_type='semantic')

img, smnt = dataset[0]

Get multiple targets

dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                     target_type=['instance', 'color', 'polygon'])

img, (inst, col, poly) = dataset[0]

Validate on the “coarse” set

dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
                     target_type='semantic')

img, smnt = dataset[0]
__getitem__(index)[source]
Parameters

index (int) – Index

Returns

(image, target) where target is a tuple of all target types if target_type is a list with more than one item. Otherwise target is a json object if target_type=”polygon”, else the image segmentation.

Return type

tuple
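The return rule above (a tuple when several target types are requested, a single object otherwise) can be sketched as follows. select_targets and the available mapping are hypothetical stand-ins for the loaded annotation data, and this is not the library's implementation.

```python
def select_targets(target_type, available):
    """Pick one target, or a tuple of targets, from a mapping of loaded targets."""
    # Accept either a single name or a list of names.
    types = [target_type] if isinstance(target_type, str) else list(target_type)
    targets = [available[name] for name in types]
    # A tuple is returned only when more than one type was requested.
    return tuple(targets) if len(targets) > 1 else targets[0]


# Placeholder strings stand in for images and JSON objects.
available = {
    "instance": "instance-mask",
    "semantic": "semantic-mask",
    "color": "color-mask",
    "polygon": {"objects": []},
}

single = select_targets("semantic", available)
several = select_targets(["instance", "color", "polygon"], available)
one_item_list = select_targets(["polygon"], available)
```

This mirrors the unpacking used in the examples: img, smnt = dataset[0] for a single type, and img, (inst, col, poly) = dataset[0] for a list of three.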

SBD

class torchvision.datasets.SBDataset(root, image_set='train', mode='boundaries', download=False, transforms=None)[source]

Semantic Boundaries Dataset

The SBD currently contains annotations from 11355 images taken from the PASCAL VOC 2011 dataset.

Note

Please note that the train and val splits included with this dataset are different from the splits in the PASCAL VOC dataset. In particular some “train” images might be part of VOC2012 val. If you are interested in testing on VOC 2012 val, then use image_set=’train_noval’, which excludes all val images.

Warning

This class needs scipy to load target files from .mat format.

Parameters
  • root (string) – Root directory of the Semantic Boundaries Dataset

  • image_set (string, optional) – Select the image_set to use, train, val or train_noval. Image set train_noval excludes VOC 2012 val images.

  • mode (string, optional) – Select target type. Possible values ‘boundaries’ or ‘segmentation’. In case of ‘boundaries’, the target is an array of shape [num_classes, H, W], where num_classes=20.

  • download (bool, optional) – If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.

  • transforms (callable, optional) – A function/transform that takes the input sample and its target as entry and returns a transformed version. The input sample is a PIL image, and the target is a numpy array if mode=’boundaries’ or a PIL image if mode=’segmentation’.
