Program Listing for File chunk.h
↰ Return to documentation for file (torch/csrc/api/include/torch/data/datasets/chunk.h)
#pragma once
#include <c10/util/irange.h>
#include <torch/arg.h>
#include <torch/data/datasets/stateful.h>
#include <torch/data/samplers.h>
#include <queue>
#include <thread>
#include <utility>
#include <torch/serialize.h>
namespace torch::data::datasets {
/// Interface for reading a dataset that is stored as a sequence of chunks.
///
/// Implementations return all examples of one chunk at a time through
/// `read_chunk()`.
///
/// @tparam ExampleType_ type of a single example.
/// @tparam ChunkType_   container returned by `read_chunk()`; defaults to a
///                      `std::vector` of examples.
template <
    typename ExampleType_,
    typename ChunkType_ = std::vector<ExampleType_>>
class ChunkDataReader {
 public:
  virtual ~ChunkDataReader() = default;

  using ChunkType = ChunkType_;
  using ExampleType = ExampleType_;
  // ChunkDataset looks up `ChunkReader::BatchType` and stores the result of
  // read_chunk() as that type, so default it to the chunk container.
  // Derived readers that declare their own BatchType hide this alias, so
  // existing readers keep working unchanged.
  using BatchType = ChunkType_;

  /// Reads and returns the examples of the chunk at `chunk_index`.
  /// (Presumably `chunk_index < chunk_count()`; confirm per implementation.)
  virtual ChunkType read_chunk(size_t chunk_index) = 0;

  /// Returns the number of chunks this reader can serve.
  virtual size_t chunk_count() = 0;

  /// Resets the reader's internal state. ChunkDataset::reset() calls this
  /// before re-arming its chunk sampler for a fresh epoch.
  virtual void reset() = 0;
};
namespace detail {
/// Bounded, mutex-protected buffer that turns chunks of examples into batches.
///
/// Producer threads (the preloaders of ChunkDataset) push whole chunks with
/// `add_chunk_data()`, or report a failure with the `std::exception_ptr`
/// overload. Consumer threads pop batches with `get_batch()`.
///
/// The example sampler is reset to the chunk size for every chunk and decides
/// the order in which the chunk's examples are placed into batches.
/// Producers block while `queue_capacity` or more examples are buffered.
/// `stop()` releases every blocked producer and consumer.
///
/// All queue and counter accesses happen under `queue_mutex_`. `cv_read_`
/// wakes consumers and `cv_write_` wakes producers.
template <
    typename UnwrappedBatch,
    typename ExampleSampler = samplers::RandomSampler>
class BatchDataBuffer {
 public:
  using UnwrappedBatchType = UnwrappedBatch;
  using BatchType = std::optional<UnwrappedBatchType>;
  using BatchRequestType = typename ExampleSampler::BatchRequestType;

  /// @param batch_size      number of examples per emitted batch.
  /// @param example_sampler sampler that orders the examples of each chunk.
  ///                        It is stored by reference, so it must outlive
  ///                        this buffer.
  /// @param queue_capacity  number of buffered examples at which producers
  ///                        start blocking.
  BatchDataBuffer(
      size_t batch_size,
      ExampleSampler& example_sampler,
      size_t queue_capacity)
      : batch_size_(batch_size),
        example_sampler_(example_sampler),
        queue_capacity_(queue_capacity) {}

  /// Pops the next batch.
  ///
  /// Blocks until one of the following holds:
  ///   - at least `batch_size_` examples are buffered;
  ///   - the next queued entry is a worker exception;
  ///   - `stop()` has been called.
  ///
  /// @return the next batch, or `std::nullopt` once the buffer is stopped and
  ///         fully drained. Batches returned after `stop()` may hold fewer
  ///         than `batch_size_` examples.
  /// @throws WorkerException wrapping an exception a producer reported via
  ///         `add_chunk_data(std::exception_ptr)`.
  BatchType get_batch() {
    std::unique_lock<std::mutex> lock(queue_mutex_);
    cv_read_.wait(lock, [this] {
      // Wake up in any of three cases:
      //   - a full batch is buffered;
      //   - the next entry is a worker exception, so report it now rather
      //     than only after more data arrives or loading finishes;
      //   - all chunks are loaded (the dataset is exhausted for this epoch).
      return this->total_example_count_in_queue_ >= batch_size_ ||
          (!this->batch_queue_.empty() &&
           this->batch_queue_.front().exception != nullptr) ||
          this->stop_;
    });
    if (batch_queue_.empty()) {
      AT_ASSERT(stop_);
      // All batches have been retrieved. Return an empty batch.
      return std::nullopt;
    }

    UnwrappedBatchData batch = std::move(batch_queue_.front());
    batch_queue_.pop();
    if (batch.exception) {
      // Exception entries never carry examples (see add_chunk_data), so the
      // example counter needs no adjustment here.
      throw WorkerException(batch.exception);
    }

    total_example_count_in_queue_ -= batch.batch_data.size();
    lock.unlock();
    cv_write_.notify_all();

    // `batch_data` is a member of a local, so returning it would not be an
    // implicit move. Move explicitly instead of copying every example.
    return std::move(batch.batch_data);
  }

  /// Splits `data` into batches and appends them to the queue.
  ///
  /// Blocks while `queue_capacity_` or more examples are already buffered.
  /// Examples first top up the last queued batch if it is a partially filled
  /// data batch; the rest are packed into new batches of at most
  /// `batch_size_` examples. Returns without enqueuing anything once `stop()`
  /// has been called.
  ///
  /// @throws (via TORCH_CHECK) if the example sampler yields an index outside
  ///         `data`.
  void add_chunk_data(UnwrappedBatchType data) {
    std::unique_lock<std::mutex> lock(queue_mutex_);
    cv_write_.wait(lock, [this] {
      // stop loading if we have preloaded enough data.
      return this->total_example_count_in_queue_ < this->queue_capacity_ ||
          this->stop_;
    });
    if (stop_) {
      // When stop_ is true, no further chunk loading is necessary.
      // Return without any further processing.
      return;
    }

    auto data_size = data.size();
    auto remaining_size = data_size;
    example_sampler_.reset(data_size);

    // Moves `example_count` examples from `data` into `batch`, in the order
    // chosen by the example sampler.
    auto fill_batch = [&](size_t example_count, UnwrappedBatchType& batch) {
      auto batch_example_indices = this->example_sampler_.next(example_count);
      AT_ASSERT(
          batch_example_indices &&
          batch_example_indices.value().size() == example_count);
      BatchRequestType& indices = batch_example_indices.value();
      for (size_t i : indices) {
        TORCH_CHECK(i < data_size, "Index out of range");
        batch.emplace_back(std::move(data[i]));
      }
      remaining_size -= example_count;
    };

    if (!batch_queue_.empty()) {
      auto& batch = batch_queue_.back();
      // If the last queued batch is short, fill it first. Entries that carry
      // a worker exception are skipped: get_batch() throws such an entry
      // without returning or un-counting its batch_data. Examples appended
      // there would be dropped and would leave total_example_count_in_queue_
      // permanently inflated.
      if (!batch.exception) {
        size_t current_count = batch.batch_data.size();
        if (current_count < batch_size_) {
          auto example_count =
              std::min(remaining_size, batch_size_ - current_count);
          fill_batch(example_count, batch.batch_data);
        }
      }
    }

    // Queue whatever is left as new batches of up to batch_size_ examples.
    while (remaining_size > 0) {
      UnwrappedBatchType current_batch;
      // Allocate the batch memory ahead of time.
      current_batch.reserve(batch_size_);
      auto example_count = std::min(remaining_size, batch_size_);
      fill_batch(example_count, current_batch);
      batch_queue_.emplace(std::move(current_batch));
    }
    total_example_count_in_queue_ += data_size;
    lock.unlock();
    cv_read_.notify_all();
  }

  /// Enqueues an exception captured by a producer.
  ///
  /// get_batch() rethrows it as WorkerException when the entry reaches the
  /// front of the queue. Blocks under the same capacity rule as the data
  /// overload, and drops the exception if `stop()` has been called.
  void add_chunk_data(std::exception_ptr e_ptr) {
    std::unique_lock<std::mutex> lock(queue_mutex_);
    cv_write_.wait(lock, [this] {
      // stop loading if we have preloaded enough data.
      return (
          this->total_example_count_in_queue_ < this->queue_capacity_ ||
          this->stop_);
    });
    if (stop_) {
      // When stop_ is true, the current thread is being torn down and the
      // batch buffer will be discarded, so there is no need to enqueue any
      // new exceptions.
      return;
    }

    batch_queue_.emplace(std::move(e_ptr));
    lock.unlock();
    cv_read_.notify_all();
  }

  /// Marks the buffer as stopped and wakes every waiting producer and
  /// consumer. Producers then return without enqueuing; consumers drain the
  /// remaining entries and afterwards receive `std::nullopt`.
  void stop() {
    {
      // Hold the lock while changing stop_ to avoid a lost wake-up.
      // cv_write_ waits on the predicate stop_ in add_chunk_data(), and the
      // wait happens in two steps:
      //   1) while still holding the lock, check whether the predicate is
      //      true;
      //   2) if it is not, release the lock and wait until notified.
      // Without the lock, the notification could fire between steps 1) and
      // 2). cv_write_ would not be waiting yet, so the notification would be
      // lost and the waiter could sleep forever. Taking the lock before
      // changing stop_ keeps updating and evaluating stop_ synchronized.
      std::lock_guard<std::mutex> lock(queue_mutex_);
      stop_ = true;
    }

    // notify all writers, wake them from wait to exit current method.
    cv_write_.notify_all();
    // notify all readers too.
    cv_read_.notify_all();
  }

  /// Number of examples per batch emitted by get_batch().
  size_t batch_size_ = 0;

  /// Number of examples currently stored in batch_queue_. Exception entries
  /// contribute nothing to it.
  size_t total_example_count_in_queue_ = 0;

  /// One queue entry: either a batch of examples or an exception reported by
  /// a producer.
  struct UnwrappedBatchData {
    explicit UnwrappedBatchData(UnwrappedBatchType data)
        : batch_data(std::move(data)) {}

    explicit UnwrappedBatchData(std::exception_ptr e)
        : exception(std::move(e)) {}

    UnwrappedBatchType batch_data;
    std::exception_ptr exception;
  };

  /// FIFO of batches and exceptions waiting to be consumed.
  std::queue<UnwrappedBatchData> batch_queue_;

  // sync batch_queue_ update.
  std::mutex queue_mutex_;

  /// Signalled when data (or the stop flag) becomes available to readers.
  std::condition_variable cv_read_;
  /// Signalled when queue space (or the stop flag) becomes available to
  /// writers.
  std::condition_variable cv_write_;

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  ExampleSampler& example_sampler_;

  // configurable maximum number of elements the queue can hold at one time.
  size_t queue_capacity_;

  // When set to true, this flag wakes the writer threads from their wait and
  // makes them exit the current function call. This is needed when
  // ChunkDataset::reset() is called before the previous epoch is exhausted.
  // While ChunkDataset waits for its preloaders to finish before tearing the
  // threads down, a preloader could still be blocked on the condition
  // variable, which would hang the program. This flag breaks that wait.
  bool stop_ = false;
};
} // namespace detail
/// Configuration for ChunkDataset. All values are validated on construction.
struct ChunkDatasetOptions {
// Default construction is disabled: preloader_count and batch_size have no
// defaults and must always be supplied explicitly.
ChunkDatasetOptions() = delete;
/// Builds an option set and validates it.
///
/// @param preloader_count number of preloader threads that
///   ChunkDataset::reset() starts.
/// @param batch_size number of examples per batch.
/// @param cache_size number of buffered examples at which preloaders pause
///   (default 2048); must be at least batch_size.
/// @param cross_chunk_shuffle_count number of chunks loaded and shuffled
///   together (default 1).
/// @throws (via TORCH_CHECK) if any count is 0 or if cache_size is smaller
///   than batch_size.
ChunkDatasetOptions(
size_t preloader_count,
size_t batch_size,
size_t cache_size = 2048,
size_t cross_chunk_shuffle_count = 1)
: preloader_count_(preloader_count),
batch_size_(batch_size),
cache_size_(cache_size),
cross_chunk_shuffle_count_(cross_chunk_shuffle_count) {
// Validate the stored values; the checks run in declaration order, so the
// first failing check determines the reported message.
TORCH_CHECK(
preloader_count_ > 0,
"Preloader count is 0. At least one preloader needs to be specified.");
TORCH_CHECK(
batch_size_ > 0,
"Batch size is 0. A positive batch size needs to be specified.");
TORCH_CHECK(
cache_size_ > 0,
"Cache size is 0. A positive cache size needs to be specified.");
// The buffer must be able to hold at least one complete batch.
TORCH_CHECK(
cache_size_ >= batch_size_,
"Cache size is less than batch size. Cache needs to be large enough to "
"hold at least one batch.");
TORCH_CHECK(
cross_chunk_shuffle_count_ > 0,
"cross_chunk_shuffle_count needs to be greater than 0.");
}
// Number of preloader threads started on every ChunkDataset::reset().
TORCH_ARG(size_t, preloader_count);
// Number of examples per batch; ChunkDataset::get_batch() requests must
// match this value.
TORCH_ARG(size_t, batch_size);
// Number of buffered examples at which preloaders pause (the batch buffer's
// queue capacity).
TORCH_ARG(size_t, cache_size) = 2048;
// The number of chunks to perform cross-chunk shuffling. Default to 1 meaning
// no cross-chunk shuffling. When it is equal to n (n > 1), n random
// chunks will be loaded at once and example shuffling will be performed
// across all those n chunks.
// Note: Usually the default config (1 chunk shuffle + example shuffle) is
// good enough to generate random distributed data. Use this parameter only if
// you know cross-shuffle is needed in your case. Also there is a performance
// penalty when this value is greater than 1, as we need to do extra merge
// between multiple chunks before performing example sampling.
TORCH_ARG(size_t, cross_chunk_shuffle_count) = 1;
};
/// A stateful dataset that yields batches assembled from chunks of data.
///
/// Examples are not addressed individually. Each `reset()` starts
/// `options.preloader_count()` background threads, and each thread repeats
/// the following steps:
///   1. take the next `options.cross_chunk_shuffle_count()` chunk indices
///      from `ChunkSampler`;
///   2. read those chunks with `ChunkReader` and concatenate them;
///   3. apply the optional `preprocessing_policy`;
///   4. hand the result to an internal batch buffer, which uses
///      `ExampleSampler` to split it into batches of `options.batch_size()`
///      examples.
/// `get_batch()` pops those batches and returns `std::nullopt` once every
/// chunk of the epoch has been consumed.
///
/// `ChunkReader` must provide a `BatchType` alias; `read_chunk()` results are
/// stored as that type and concatenated with `std::back_inserter`.
template <
    typename ChunkReader,
    typename ChunkSampler = samplers::RandomSampler,
    typename ExampleSampler = samplers::RandomSampler>
class ChunkDataset final
    : public StatefulDataset<
          ChunkDataset<ChunkReader, ChunkSampler, ExampleSampler>,
          typename ChunkReader::BatchType,
          size_t> {
 public:
  using BatchType = std::optional<typename ChunkReader::BatchType>;
  using UnwrappedBatchType = typename ChunkReader::BatchType;
  using BatchRequestType = size_t;
  using ChunkSamplerType = ChunkSampler;
  using ExampleSamplerType = ExampleSampler;

  /// @param chunk_reader         reads the examples of a chunk by index.
  /// @param chunk_sampler        decides the order in which chunks are read.
  /// @param example_sampler      decides the order of examples inside the
  ///                             loaded chunk data.
  /// @param options              preloader count, batch size, cache size and
  ///                             cross-chunk shuffle count.
  /// @param preprocessing_policy optional hook run on each loaded (and, with
  ///                             cross-chunk shuffling, merged) chunk before
  ///                             it is split into batches.
  ChunkDataset(
      ChunkReader chunk_reader,
      ChunkSampler chunk_sampler,
      ExampleSampler example_sampler,
      ChunkDatasetOptions options,
      std::function<void(UnwrappedBatchType&)> preprocessing_policy =
          std::function<void(UnwrappedBatchType&)>())
      : chunk_reader_(std::move(chunk_reader)),
        chunk_sampler_(std::move(chunk_sampler)),
        example_sampler_(std::move(example_sampler)),
        options_(options),
        preprocessing_policy_(std::move(preprocessing_policy)),
        quit_worker_(false),
        running_preloaders_(0) {}

  /// Stops the batch buffer so blocked preloaders can return, then joins
  /// them.
  ~ChunkDataset() override {
    // stop batch buffer first.
    if (batch_buffer_) {
      batch_buffer_->stop();
    }
    free_workers();
  }

  /// Returns the next batch, or `std::nullopt` at the end of the epoch.
  ///
  /// @throws (via TORCH_CHECK) if reset() has not been called yet, or if
  ///         `batch_size` differs from `options.batch_size()`.
  /// @throws WorkerException for errors raised inside a preloader.
  BatchType get_batch(size_t batch_size) override {
    TORCH_CHECK(
        batch_buffer_ != nullptr,
        "Dataset needs to call reset() before calling get_batch().");

    TORCH_CHECK(
        batch_size == options_.batch_size(),
        "The requested batch size does not match with the initialized batch size.\n"
        " The requested batch size is ",
        batch_size,
        ", while the dataset is created with batch size equal to ",
        options_.batch_size());
    return batch_buffer_->get_batch();
  }

  /// Convenience overload that uses the configured batch size.
  BatchType get_batch() {
    return get_batch(options_.batch_size());
  }

  /// Starts a new epoch.
  ///
  /// The previous epoch may have been only partially consumed. This method:
  ///   1. stops and joins the previous epoch's preloaders;
  ///   2. discards any buffered batches;
  ///   3. resets the chunk reader and chunk sampler, unless state was just
  ///      restored by load();
  ///   4. launches fresh preloader threads.
  void reset() override {
    // We need this to support partial data reads via dataloader iterator.
    if (batch_buffer_) {
      batch_buffer_->stop();
    }
    // free workers from previous reset if there is any.
    free_workers();
    preload_threads_.clear();

    if (load_checkpoint_) {
      // load() just restored the chunk sampler. Keep that position for this
      // epoch, but only once: clearing the flag lets every later reset()
      // start a fresh epoch. Previously the flag was never cleared here, so
      // after a single load() the chunk reader and chunk sampler were never
      // reset again.
      load_checkpoint_ = false;
    } else {
      chunk_reader_.reset();
      chunk_sampler_.reset(chunk_reader_.chunk_count());
    }

    // Throw out any existing cached batch in the buffer and re-create a new
    // chunk buffer.
    batch_buffer_ = std::make_unique<
        detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>(
        options_.batch_size(), example_sampler_, options_.cache_size());

    // create new workers for this new epoch.
    quit_worker_ = false;

    AT_ASSERT(running_preloaders_ == 0);
    running_preloaders_ = options_.preloader_count();
    for (const auto i : c10::irange(options_.preloader_count())) {
      preload_threads_.emplace_back([this, i]() { this->preloader(i); });
    }
  }

  /// The number of examples is not known up front, so this always returns
  /// `std::nullopt`.
  std::optional<size_t> size() const override {
    return std::nullopt;
  }

  // Provides a reference to the chunk sampler. Used mainly in distributed
  // data loading to set the epoch number for the sampler.
  ChunkSamplerType& chunk_sampler() {
    return chunk_sampler_;
  }

  /// Serializes the chunk sampler state.
  void save(serialize::OutputArchive& archive) const override {
    std::lock_guard<std::mutex> lock(chunk_index_guard_);
    chunk_sampler_.save(archive);
  }

  /// Restores the chunk sampler state. The next reset() keeps the restored
  /// position instead of resetting the chunk reader and chunk sampler.
  void load(serialize::InputArchive& archive) override {
    std::lock_guard<std::mutex> lock(chunk_index_guard_);
    chunk_sampler_.load(archive);
    load_checkpoint_ = true;
  }

 private:
  /// Body of each preloader thread.
  ///
  /// Loads chunks until the chunk sampler is exhausted or `quit_worker_` is
  /// set. Any exception is forwarded to the batch buffer so that get_batch()
  /// can rethrow it. The last preloader to finish stops the batch buffer,
  /// which lets readers drain it and then observe the end of the epoch.
  void preloader(size_t id) {
    while (!quit_worker_.load()) {
      try {
        std::vector<size_t> chunk_idx;
        {
          std::lock_guard<std::mutex> lock(chunk_index_guard_);
          if (auto chunk_sampler_result = chunk_sampler_.next(
                  this->options_.cross_chunk_shuffle_count())) {
            chunk_idx = chunk_sampler_result.value();
          } else {
            break;
          }
        }
        UnwrappedBatchType data = chunk_reader_.read_chunk(chunk_idx[0]);
        // For cross-chunk shuffling, append the remaining chunks to the first.
        for (const auto i : c10::irange(1, chunk_idx.size())) {
          auto chunk_data = chunk_reader_.read_chunk(chunk_idx[i]);
          std::move(
              chunk_data.begin(), chunk_data.end(), std::back_inserter(data));
        }
        if (preprocessing_policy_) {
          preprocessing_policy_(data);
        }
        if (!data.empty()) { // skip empty chunks.
          batch_buffer_->add_chunk_data(std::move(data));
        }
      } catch (...) {
        batch_buffer_->add_chunk_data(std::current_exception());
      }
    }
    AT_ASSERT(running_preloaders_.load() > 0);
    // Decide on the value returned by the atomic decrement itself, so exactly
    // one thread (the last to finish) sees 0 and notifies the batch buffer.
    if (--running_preloaders_ == 0) {
      // all preloaders are completed, so we can notify the batch_buffer.
      batch_buffer_->stop();
    }
  }

  /// Signals the preloaders to quit and joins them. Does nothing if they have
  /// already been asked to quit. Callers stop the batch buffer first, so that
  /// preloaders blocked inside add_chunk_data() can return.
  void free_workers() {
    if (!quit_worker_.load()) {
      quit_worker_ = true;
      for (auto& worker_thread : preload_threads_) {
        worker_thread.join();
      }
    }
  }

  // Templated class that defines what a chunk is and how to read chunk data.
  // When a chunk is returned by chunk_reader_, ChunkDataset splits it into
  // batches and caches them in batch_buffer_.
  ChunkReader chunk_reader_;

  // chunk sampler to shuffle different chunks
  ChunkSamplerType chunk_sampler_;

  // example sampler to shuffle examples in a specific chunk
  ExampleSamplerType example_sampler_;

  // batch data buffer which holds chunk data from preloading thread.
  std::shared_ptr<
      detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>
      batch_buffer_;

  // worker thread pool
  std::vector<std::thread> preload_threads_;

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const ChunkDatasetOptions options_;

  // Function wrapper that applies custom processing to chunk data. This is an
  // advanced parameter for developers who want to pre-process chunk data
  // before it is sampled into minibatches.
  // Unlike the collate function, this policy is applied at the chunk level
  // rather than the minibatch level. When a chunk of data is loaded (multiple
  // chunks if cross_chunk_shuffle_count_ is greater than 1), the policy is
  // applied to the full loaded data. It is useful for pre-processing such as
  // bucketing before the example sampler samples the data. By default it is
  // empty and no action is taken.
  std::function<void(UnwrappedBatchType&)> preprocessing_policy_;

  // indicates whether the worker threads should be torn down
  std::atomic<bool> quit_worker_;

  // Tracks running preloaders so the batch buffer can be notified. A value of
  // 0 indicates that chunk loading is complete.
  std::atomic<size_t> running_preloaders_;

  // mutex to synchronize chunk sampler next() call.
  mutable std::mutex chunk_index_guard_;

  // Whether the next reset() should keep the chunk sampler state restored by
  // load() instead of resetting it.
  bool load_checkpoint_{false};
};
} // namespace torch::data::datasets