Datasets#
The dataset abstraction defines how to access individual samples in your data.
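Random-access datasets inherit from Dataset and must implement get() and size(); stateful datasets (see StatefulDataset below) derive from BatchDataset instead and produce whole batches.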
Dataset Base Class#
-
template<typename Self, typename SingleExample = Example<>>
class Dataset : public torch::data::datasets::BatchDataset<Self, std::vector<Example<>>># A dataset that can yield data in batches, or as individual examples.
A
Dataset is a BatchDataset, because it supports random access and therefore batched access is implemented (by default) by calling the random-access indexing function for each index in the requested batch of indices. This can be customized by overriding get_batch().
Public Types
-
using ExampleType = SingleExample#
Public Functions
-
virtual ExampleType get(size_t index) = 0#
Returns the example at the given index.
-
inline virtual std::vector<ExampleType> get_batch(ArrayRef<size_t> indices) override#
Returns a batch of data.
The default implementation calls
get() for every requested index in the batch.
-
using ExampleType = SingleExample#
-
template<typename Self, typename Batch = std::vector<Example<>>, typename BatchRequest = ArrayRef<size_t>>
class BatchDataset# A dataset that can yield data only in batches.
Subclassed by torch::data::datasets::Dataset< MNIST >, torch::data::datasets::Dataset< TensorDataset, TensorExample >, torch::data::datasets::StatefulDataset< ChunkDataset< ChunkReader, samplers::RandomSampler, samplers::RandomSampler >, ChunkReader::BatchType, size_t >
Public Functions
-
virtual ~BatchDataset() = default#
-
virtual Batch get_batch(BatchRequest request) = 0#
Returns a batch of data for the given batch request (by default, a list of indices).
-
virtual std::optional<size_t> size() const = 0#
Returns the size of the dataset, or an empty std::optional if it is unsized.
-
template<typename TransformType>
inline MapDataset<Self, TransformType> map(TransformType transform) &# Creates a
MapDataset that applies the given transform to this dataset (overload selected when called on an lvalue dataset).
-
template<typename TransformType>
inline MapDataset<Self, TransformType> map(TransformType transform) &&# Creates a
MapDataset that applies the given transform to this dataset (overload selected when called on an rvalue dataset, e.g. a temporary).
-
virtual ~BatchDataset() = default#
StatefulDataset#
A dataset that manages its own state across batches (e.g., position in a stream).
Unlike Dataset, it produces batches directly without external samplers.
-
template<typename Self, typename Batch = std::vector<Example<>>, typename BatchRequest = size_t>
class StatefulDataset : public BatchDataset<Self, std::optional<std::vector<Example<>>>, size_t># A stateful dataset is a dataset that maintains some internal state, which will be
reset() at the beginning of each epoch. Subclasses must implement the
reset() method (it is pure virtual) to configure this behavior. Further, the return type of a stateful dataset’s get_batch() method is always an optional. When the stateful dataset wants to indicate to the dataloader that its epoch has ended, it should return an empty optional. The dataloader knows to modify its implementation based on whether the dataset is stateless or stateful. Note that when subclassing from
StatefulDataset<Self, T>, the return type of get_batch(), which the subclass must override, will be optional<T> (i.e. the type specified in the StatefulDataset specialization is automatically boxed into an optional for the dataset’s BatchType).
Public Functions
-
virtual void reset() = 0#
Resets internal state of the dataset.
-
virtual void save(serialize::OutputArchive &archive) const = 0#
Saves the stateful dataset’s state to the given OutputArchive.
-
virtual void load(serialize::InputArchive &archive) = 0#
Deserializes the stateful dataset’s state from the
given archive.
-
virtual void reset() = 0#
ChunkDataReader#
Interface for reading chunks of data from a data source. Used with
ChunkDataset for large-scale data loading.
-
template<typename ExampleType_, typename ChunkType_ = std::vector<ExampleType_>>
class ChunkDataReader# Interface for chunk reader, which performs data chunking and reading of entire chunks.
A chunk could be an entire file, such as an audio data file or an image, or part of a file in the case of a large text-file split based on seek positions.
Custom Dataset Example#
/// Minimal example of a user-defined random-access dataset.
///
/// The class passes itself as the `Self` template argument of
/// torch::data::datasets::Dataset and implements the two members every
/// such dataset must provide: get() and size().
class CustomDataset : public torch::data::datasets::Dataset<CustomDataset> {
 public:
  /// @param root Directory the samples are loaded from.
  explicit CustomDataset(const std::string& root) {
    // Load data from root directory into images_ and labels_.
    // The loading logic itself is left to the reader of this example.
  }

  /// Returns the example at `index`, built from images_[index] (data)
  /// and labels_[index] (target).
  torch::data::Example<> get(size_t index) override {
    return {images_[index], labels_[index]};
  }

  /// Returns the number of samples, i.e. the extent of images_ along dim 0.
  /// The return type must match BatchDataset::size(), which is declared as
  /// std::optional<size_t>; Tensor::size() returns a signed int64_t, so the
  /// conversion is made explicit.
  std::optional<size_t> size() const override {
    return static_cast<size_t>(images_.size(0));
  }

 private:
  torch::Tensor images_, labels_;
};
MapDataset#
-
template<typename SourceDataset, typename AppliedTransform>
class MapDataset : public torch::data::datasets::BatchDataset<MapDataset<SourceDataset, AppliedTransform>, detail::optional_if_t<SourceDataset::is_stateful, AppliedTransform::OutputBatchType>, SourceDataset::BatchRequestType># A
MapDataset is a dataset that applies a transform to a source dataset.
Public Types
-
using DatasetType = SourceDataset#
-
using TransformType = AppliedTransform#
-
using BatchRequestType = typename SourceDataset::BatchRequestType#
-
using OutputBatchType = detail::optional_if_t<SourceDataset::is_stateful, typename AppliedTransform::OutputBatchType>#
Public Functions
-
inline MapDataset(DatasetType dataset, TransformType transform)#
-
inline virtual OutputBatchType get_batch(BatchRequestType indices) override#
Gets a batch from the source dataset and applies the transform to it, returning the result.
-
inline virtual std::optional<size_t> size() const noexcept override#
Returns the size of the source dataset.
-
inline void reset()#
Calls
reset() on the underlying dataset.
NOTE: Stateless datasets do not have a reset() method, so a call to this method will only compile for stateful datasets (which have a reset() method).
-
inline const SourceDataset &dataset() noexcept#
Returns the underlying dataset.
-
inline const AppliedTransform &transform() noexcept#
Returns the transform being applied.
-
using DatasetType = SourceDataset#
ChunkDataset#
-
template<typename ChunkReader, typename ChunkSampler = samplers::RandomSampler, typename ExampleSampler = samplers::RandomSampler>
class ChunkDataset : public torch::data::datasets::StatefulDataset<ChunkDataset<ChunkReader, samplers::RandomSampler, samplers::RandomSampler>, ChunkReader::BatchType, size_t># A stateful dataset that supports hierarchical sampling and prefetching of entire chunks.
Unlike a regular dataset, a chunk dataset requires two samplers to operate and keeps an internal state.
The ChunkSampler selects which chunk to load next, while the ExampleSampler determines the order of Examples that are returned in each get_batch call. The hierarchical sampling approach used here is inspired by this paper: http://martin.zinkevich.org/publications/nips2010.pdf
Public Types
-
using BatchType = std::optional<typename ChunkReader::BatchType>#
-
using UnwrappedBatchType = typename ChunkReader::BatchType#
-
using BatchRequestType = size_t#
-
using ChunkSamplerType = ChunkSampler#
-
using ExampleSamplerType = ExampleSampler#
Public Functions
-
inline ChunkDataset(ChunkReader chunk_reader, ChunkSampler chunk_sampler, ExampleSampler example_sampler, ChunkDatasetOptions options, std::function<void(UnwrappedBatchType&)> preprocessing_policy = std::function<void(UnwrappedBatchType&)>())#
-
inline ~ChunkDataset() override#
-
inline BatchType get_batch(size_t batch_size) override#
Default get_batch method of BatchDataset.
This method returns Example batches created from the preloaded chunks. The implementation is dataset agnostic and does not need overriding in different chunk datasets.
-
inline BatchType get_batch()#
Helper method around get_batch as
batch_size is not strictly necessary.
-
inline virtual void reset() override#
Clears any internal state and starts the internal prefetching mechanism for the chunk dataset.
-
inline virtual std::optional<size_t> size() const override#
The size is not used for chunk datasets.
-
inline ChunkSamplerType &chunk_sampler()#
-
inline virtual void save(serialize::OutputArchive &archive) const override#
Saves the stateful dataset’s state to the given OutputArchive.
-
inline virtual void load(serialize::InputArchive &archive) override#
Deserializes the stateful dataset’s state from the
given archive.
-
using BatchType = std::optional<typename ChunkReader::BatchType>#
Built-in Datasets#
MNIST#
-
class MNIST : public torch::data::datasets::Dataset<MNIST>#
The MNIST dataset.
Public Types
Public Functions
-
explicit MNIST(const std::string &root, Mode mode = Mode::kTrain)#
Loads the MNIST dataset from the
root path. The supplied
root path should contain the content of the unzipped MNIST dataset, available from http://yann.lecun.com/exdb/mnist.
-
virtual std::optional<size_t> size() const override#
Returns the size of the dataset.
-
explicit MNIST(const std::string &root, Mode mode = Mode::kTrain)#
Example:
// Load MNIST from "./data" (mode defaults to Mode::kTrain), then chain two
// transforms with map(). Each map() is called on a temporary, so the rvalue
// (&&) overload is selected.
//  - Normalize<>(0.1307, 0.3081): presumably the MNIST pixel mean and
//    standard deviation — confirm against the Normalize docs.
//  - Stack<>(): presumably collates a batch of examples into stacked
//    tensors — confirm.
auto dataset = torch::data::datasets::MNIST("./data")
.map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
.map(torch::data::transforms::Stack<>());