Samplers#
Samplers control the order in which samples are accessed from a dataset. They determine the indices that the DataLoader uses to fetch data.
Sampler Base Class#
-
template<typename BatchRequest = std::vector<size_t>>
class Sampler#
A Sampler is an object that yields an index with which to access a dataset.
Subclassed by torch::data::samplers::RandomSampler, torch::data::samplers::SequentialSampler
Public Types
-
using BatchRequestType = BatchRequest#
Public Functions
-
virtual ~Sampler() = default#
-
virtual void reset(std::optional<size_t> new_size) = 0#
Resets the Sampler’s internal state.
Typically called before a new epoch. Optionally, accepts a new size when resetting the sampler.
-
virtual std::optional<BatchRequest> next(size_t batch_size) = 0#
Returns the next index if possible, or an empty optional if the sampler is exhausted for this epoch.
-
virtual void save(serialize::OutputArchive &archive) const = 0#
Serializes the Sampler to the archive.
-
virtual void load(serialize::InputArchive &archive) = 0#
Deserializes the Sampler from the archive.
-
using BatchRequestType = BatchRequest#
Sequential Sampler#
Accesses samples in order from 0 to N-1. Use this for evaluation or when order matters.
-
class SequentialSampler : public torch::data::samplers::Sampler<>#
A Sampler that returns indices sequentially.
Public Functions
-
explicit SequentialSampler(size_t size)#
Creates a SequentialSampler that will return indices in the range 0...size - 1.
-
virtual void reset(std::optional<size_t> new_size = std::nullopt) override#
Resets the SequentialSampler to zero.
-
virtual std::optional<std::vector<size_t>> next(size_t batch_size) override#
Returns the next batch of indices.
-
virtual void save(serialize::OutputArchive &archive) const override#
Serializes the SequentialSampler to the archive.
-
virtual void load(serialize::InputArchive &archive) override#
Deserializes the SequentialSampler from the archive.
-
size_t index() const noexcept#
Returns the current index of the SequentialSampler.
-
explicit SequentialSampler(size_t size)#
Random Sampler#
Accesses samples in random order. Use this for training to ensure the model sees samples in different orders each epoch.
-
class RandomSampler : public torch::data::samplers::Sampler<>#
A Sampler that returns random indices.
Public Functions
-
explicit RandomSampler(int64_t size, Dtype index_dtype = torch::kInt64)#
Constructs a RandomSampler with a size and dtype for the stored indices.
The constructor will eagerly allocate all required indices, which is the sequence 0 ... size - 1. index_dtype is the data type of the stored indices. You can change it to influence memory usage.
-
~RandomSampler() override#
-
virtual void reset(std::optional<size_t> new_size = std::nullopt) override#
Resets the RandomSampler to a new set of indices.
-
virtual std::optional<std::vector<size_t>> next(size_t batch_size) override#
Returns the next batch of indices.
-
virtual void save(serialize::OutputArchive &archive) const override#
Serializes the RandomSampler to the archive.
-
virtual void load(serialize::InputArchive &archive) override#
Deserializes the RandomSampler from the archive.
-
size_t index() const noexcept#
Returns the current index of the RandomSampler.
-
explicit RandomSampler(int64_t size, Dtype index_dtype = torch::kInt64)#
Distributed Random Sampler#
For distributed training, ensures each process gets a different subset of the data without overlap.
-
class DistributedRandomSampler : public torch::data::samplers::DistributedSampler<>#
Select samples randomly.
The sampling order is shuffled at each reset() call.
Public Functions
-
DistributedRandomSampler(size_t size, size_t num_replicas = 1, size_t rank = 0, bool allow_duplicates = true)#
-
virtual void reset(std::optional<size_t> new_size = std::nullopt) override#
Resets the DistributedRandomSampler to a new set of indices.
-
virtual std::optional<std::vector<size_t>> next(size_t batch_size) override#
Returns the next batch of indices.
-
virtual void save(serialize::OutputArchive &archive) const override#
Serializes the DistributedRandomSampler to the archive.
-
virtual void load(serialize::InputArchive &archive) override#
Deserializes the DistributedRandomSampler from the archive.
-
size_t index() const noexcept#
Returns the current index of the DistributedRandomSampler.
-
DistributedRandomSampler(size_t size, size_t num_replicas = 1, size_t rank = 0, bool allow_duplicates = true)#
Distributed Sampler (Base)#
-
template<typename BatchRequest = std::vector<size_t>>
class DistributedSampler : public torch::data::samplers::Sampler<std::vector<size_t>>#
A Sampler that selects a subset of indices to sample from and defines a sampling behavior.
In a distributed setting, this selects a subset of the indices depending on the provided num_replicas and rank parameters. The Sampler performs a rounding operation based on the allow_duplicates parameter to decide the local sample count.
Subclassed by torch::data::samplers::DistributedRandomSampler, torch::data::samplers::DistributedSequentialSampler
Public Functions
-
inline DistributedSampler(size_t size, size_t num_replicas = 1, size_t rank = 0, bool allow_duplicates = true)#
-
inline void set_epoch(size_t epoch)#
Set the epoch for the current enumeration.
This can be used to alter the sample selection and shuffling behavior.
-
inline size_t epoch() const#
-
inline DistributedSampler(size_t size, size_t num_replicas = 1, size_t rank = 0, bool allow_duplicates = true)#
Distributed Sequential Sampler#
-
class DistributedSequentialSampler : public torch::data::samplers::DistributedSampler<>#
Select samples sequentially.
Public Functions
-
DistributedSequentialSampler(size_t size, size_t num_replicas = 1, size_t rank = 0, bool allow_duplicates = true)#
-
virtual void reset(std::optional<size_t> new_size = std::nullopt) override#
Resets the DistributedSequentialSampler to a new set of indices.
-
virtual std::optional<std::vector<size_t>> next(size_t batch_size) override#
Returns the next batch of indices.
-
virtual void save(serialize::OutputArchive &archive) const override#
Serializes the DistributedSequentialSampler to the archive.
-
virtual void load(serialize::InputArchive &archive) override#
Deserializes the DistributedSequentialSampler from the archive.
-
size_t index() const noexcept#
Returns the current index of the DistributedSequentialSampler.
-
DistributedSequentialSampler(size_t size, size_t num_replicas = 1, size_t rank = 0, bool allow_duplicates = true)#
Stream Sampler#
-
class StreamSampler : public torch::data::samplers::Sampler<BatchSize>#
A sampler for (potentially infinite) streams of data.
The major feature of the StreamSampler is that it does not return particular indices, but instead only the number of elements to fetch from the dataset. The dataset has to decide how to produce those elements.
Public Functions
-
explicit StreamSampler(size_t epoch_size)#
Constructs the StreamSampler with the number of individual examples that should be fetched until the sampler is exhausted.
-
virtual void reset(std::optional<size_t> new_size = std::nullopt) override#
Resets the internal state of the sampler.
-
virtual std::optional<BatchSize> next(size_t batch_size) override#
Returns a BatchSize object with the number of elements to fetch in the next batch.
This number is the minimum of the supplied batch_size and the difference between the epoch_size and the current index. If the epoch_size has been reached, returns an empty optional.
-
virtual void save(serialize::OutputArchive &archive) const override#
Serializes the StreamSampler to the archive.
-
virtual void load(serialize::InputArchive &archive) override#
Deserializes the StreamSampler from the archive.
-
explicit StreamSampler(size_t epoch_size)#