torch.utils.data

class torch.utils.data.Dataset[source]

An abstract class representing a Dataset.

All other datasets should subclass it. All subclasses should override __len__, which provides the size of the dataset, and __getitem__, which supports integer indexing in the range 0 to len(self) exclusive.
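
Example (a minimal sketch of a map-style dataset; the class and field names are illustrative, not part of the API):

import torch
from torch.utils.data import Dataset

class SquaresDataset(Dataset):
    """Toy dataset returning (x, x**2) pairs."""

    def __init__(self, n):
        self.values = torch.arange(n, dtype=torch.float32)

    def __len__(self):
        # Size of the dataset, used by samplers and DataLoader.
        return len(self.values)

    def __getitem__(self, idx):
        # Integer indexing in range 0 to len(self) exclusive.
        x = self.values[idx]
        return x, x ** 2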

class torch.utils.data.TensorDataset(*tensors)[source]

Dataset wrapping tensors.

Each sample will be retrieved by indexing tensors along the first dimension.

Parameters

*tensors (Tensor) – tensors that have the same size in the first dimension.
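
Example (a small sketch; the tensor names are illustrative):

import torch
from torch.utils.data import TensorDataset

features = torch.randn(100, 4)            # 100 samples, 4 features each
labels = torch.randint(0, 2, (100,))      # 100 matching labels
dataset = TensorDataset(features, labels)
x, y = dataset[0]                         # indexes both tensors along dim 0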

class torch.utils.data.ConcatDataset(datasets)[source]

Dataset to concatenate multiple datasets. This class is useful for assembling different existing datasets, possibly large-scale ones, since the concatenation is performed on the fly.

Parameters

datasets (sequence) – List of datasets to be concatenated
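
Example (a small sketch; the dataset names are illustrative):

import torch
from torch.utils.data import TensorDataset, ConcatDataset

ds_a = TensorDataset(torch.zeros(10, 3))
ds_b = TensorDataset(torch.ones(5, 3))
combined = ConcatDataset([ds_a, ds_b])
len(combined)   # 15
combined[12]    # resolved to ds_b[2] on the fly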

class torch.utils.data.Subset(dataset, indices)[source]

Subset of a dataset at specified indices.

Parameters
  • dataset (Dataset) – The whole Dataset

  • indices (sequence) – Indices in the whole set selected for subset
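
Example (a small sketch; the variable names are illustrative):

import torch
from torch.utils.data import TensorDataset, Subset

dataset = TensorDataset(torch.arange(10))
evens = Subset(dataset, [0, 2, 4, 6, 8])
len(evens)   # 5
evens[1]     # same sample as dataset[2]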

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)[source]

Data loader. Combines a dataset and a sampler, and provides single- or multi-process iterators over the dataset.

Parameters
  • dataset (Dataset) – dataset from which to load the data.

  • batch_size (int, optional) – how many samples per batch to load (default: 1).

  • shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).

  • sampler (Sampler, optional) – defines the strategy to draw samples from the dataset. If specified, shuffle must be False.

  • batch_sampler (Sampler, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.

  • num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)

  • collate_fn (callable, optional) – merges a list of samples to form a mini-batch.

  • pin_memory (bool, optional) – If True, the data loader will copy tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.

  • drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)

  • timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)

  • worker_init_fn (callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
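
Example (a minimal sketch of typical usage; the toy dataset here is a stand-in for your own):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(100, 4), torch.randint(0, 2, (100,)))
loader = DataLoader(dataset, batch_size=8, shuffle=True,
                    num_workers=2)   # see the note below about worker processes

for inputs, targets in loader:
    pass   # inputs has shape (8, 4), except possibly for the last batch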

Note

When num_workers != 0, the corresponding worker processes are created each time an iterator for the DataLoader is obtained (e.g., when you call enumerate(dataloader)). At this point, the dataset, collate_fn and worker_init_fn are passed to each worker, where they are used to access and initialize data based on the indices queued up from the main process. This means that dataset access together with its internal IO, transforms and collation runs in the worker, while any shuffle randomization is done in the main process, which guides loading by assigning indices to load. Workers are shut down once the end of the iteration is reached.

Since workers rely on Python multiprocessing, worker launch behavior is different on Windows compared to Unix. On Unix, fork() is used as the default multiprocessing start method, so child workers typically can access the dataset and Python argument functions directly through the cloned address space. On Windows, another interpreter is launched which runs your main script, followed by the internal worker function that receives the dataset, collate_fn and other arguments through pickle serialization.

This separate serialization means that you should take two steps to ensure you are compatible with Windows while using workers (this also works equally well on Unix):

  • Wrap most of your main script’s code within an if __name__ == '__main__': block, to make sure it doesn’t run again (most likely generating errors) when each worker process is launched. You can place your dataset and DataLoader instance creation logic here, as it doesn’t need to be re-executed in workers.

  • Make sure that collate_fn, worker_init_fn or any custom dataset code is declared as a top-level definition, outside of that __main__ check. This ensures they are available in workers as well (this is needed since functions are pickled as references only, not bytecode).

By default, each worker will have its PyTorch seed set to base_seed + worker_id, where base_seed is a long generated by the main process using its RNG. However, seeds for other libraries (e.g., NumPy) may be duplicated upon initializing workers, causing each worker to return identical random numbers. (See the My data loader workers return identical random numbers section in the FAQ.) You may use torch.initial_seed() to access the PyTorch seed for each worker in worker_init_fn, and use it to set other seeds before data loading.
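
For instance, a sketch of a worker_init_fn that derives a NumPy seed from the per-worker PyTorch seed (assuming NumPy is the other library whose state needs seeding):

import numpy as np
import torch

def worker_init_fn(worker_id):
    # torch.initial_seed() already differs per worker (base_seed + worker_id);
    # reduce it modulo 2**32 since NumPy seeds must fit in 32 bits.
    np.random.seed(torch.initial_seed() % 2 ** 32)

# loader = DataLoader(dataset, num_workers=4, worker_init_fn=worker_init_fn)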

Warning

If spawn start method is used, worker_init_fn cannot be an unpicklable object, e.g., a lambda function.

The default memory pinning logic only recognizes Tensors and maps and iterables containing Tensors. By default, if the pinning logic sees a batch that is a custom type (which will occur if you have a collate_fn that returns a custom batch type), or if each element of your batch is a custom type, the pinning logic will not recognize them, and it will return that batch (or those elements) without pinning the memory. To enable memory pinning for custom batch or data types, define a pin_memory method on your custom type(s).

Example:

import torch
from torch.utils.data import DataLoader, TensorDataset

class SimpleCustomBatch:
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    return SimpleCustomBatch(batch)

inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)

for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())

torch.utils.data.random_split(dataset, lengths)[source]

Randomly split a dataset into non-overlapping new datasets of given lengths.

Parameters
  • dataset (Dataset) – Dataset to be split

  • lengths (sequence) – lengths of splits to be produced
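
Example (a small sketch; an 80/20 split of a toy dataset):

import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.randn(100, 4))
train_set, val_set = random_split(dataset, [80, 20])   # lengths must sum to len(dataset)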

class torch.utils.data.Sampler(data_source)[source]

Base class for all Samplers.

Every Sampler subclass has to provide an __iter__ method, providing a way to iterate over indices of dataset elements, and a __len__ method that returns the length of the returned iterator.
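
Example (a sketch of a custom sampler that yields indices in reverse order; the class name is illustrative):

from torch.utils.data import Sampler

class ReverseSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # Yield indices from the last element down to the first.
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self):
        return len(self.data_source)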

class torch.utils.data.SequentialSampler(data_source)[source]

Samples elements sequentially, always in the same order.

Parameters

data_source (Dataset) – dataset to sample from

class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)[source]

Samples elements randomly. If without replacement, then sample from a shuffled dataset. If with replacement, then the user can specify num_samples to draw.

Parameters
  • data_source (Dataset) – dataset to sample from

  • replacement (bool) – samples are drawn with replacement if True (default: False)

  • num_samples (int) – number of samples to draw (default: len(dataset)). This argument should be specified only when replacement is True.

class torch.utils.data.SubsetRandomSampler(indices)[source]

Samples elements randomly from a given list of indices, without replacement.

Parameters

indices (sequence) – a sequence of indices
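
Example (a sketch of a common train/validation split over index ranges; the split point is arbitrary):

import torch
from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler

dataset = TensorDataset(torch.randn(100, 4))
indices = list(range(len(dataset)))
train_loader = DataLoader(dataset, batch_size=32,
                          sampler=SubsetRandomSampler(indices[:80]))
val_loader = DataLoader(dataset, batch_size=32,
                        sampler=SubsetRandomSampler(indices[80:]))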

class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)[source]

Samples elements from [0,..,len(weights)-1] with given probabilities (weights).

Parameters
  • weights (sequence) – a sequence of weights, not necessarily summing up to one

  • num_samples (int) – number of samples to draw

  • replacement (bool) – if True, samples are drawn with replacement. If not, they are drawn without replacement, which means that when a sample index is drawn for a row, it cannot be drawn again for that row.

Example

>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[0, 0, 0, 1, 0]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]

class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)[source]

Wraps another sampler to yield a mini-batch of indices.

Parameters
  • sampler (Sampler) – Base sampler.

  • batch_size (int) – Size of mini-batch.

  • drop_last (bool) – If True, the sampler will drop the last batch if its size would be less than batch_size

Example

>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]

class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None)[source]

Sampler that restricts data loading to a subset of the dataset.

It is especially useful in conjunction with torch.nn.parallel.DistributedDataParallel. In such a case, each process can pass a DistributedSampler instance as a DataLoader sampler, and load a subset of the original dataset that is exclusive to it.

Note

Dataset is assumed to be of constant size.

Parameters
  • dataset – Dataset used for sampling.

  • num_replicas (optional) – Number of processes participating in distributed training.

  • rank (optional) – Rank of the current process within num_replicas.
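
Example (a sketch of per-process usage; assumes torch.distributed.init_process_group has already been called in this process, and uses a toy dataset as a stand-in for your own):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.randn(100, 4))
sampler = DistributedSampler(dataset)   # num_replicas and rank inferred from the process group
loader = DataLoader(dataset, batch_size=8, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)   # vary the shuffling across epochs
    for (batch,) in loader:
        pass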
