Program Listing for File sequencers.h
↰ Return to documentation for file (torch/csrc/api/include/torch/data/detail/sequencers.h)
#pragma once
#include <torch/types.h>
#include <algorithm>
#include <cstddef>
#include <vector>
namespace torch::data::detail::sequencers {
namespace detail {
/// Reports whether any slot of a sequencer's reorder buffer is occupied.
///
/// @param buffer Slots that may each hold a stashed result.
/// @return true if at least one slot holds a value; false otherwise,
///         including when `buffer` has no slots at all.
template <typename Result>
bool buffer_contains_result(const std::vector<std::optional<Result>>& buffer) {
  const auto first_filled = std::find_if(
      buffer.cbegin(), buffer.cend(), [](const std::optional<Result>& slot) {
        // Test engagement, not the contained value.
        return static_cast<bool>(slot);
      });
  return first_filled != buffer.cend();
}
} // namespace detail
/// Abstract policy that decides in which order results obtained from a
/// `ResultProducer` are handed back to the caller.
template <typename Result>
struct Sequencer {
  /// Callable yielding the next available result. The sequencers in this
  /// file treat an empty optional as "no more results for this epoch".
  using ResultProducer = std::function<std::optional<Result>()>;

  /// Polymorphic base: allow deletion through a `Sequencer*`.
  virtual ~Sequencer() = default;

  /// Returns the next result according to this sequencer's ordering policy,
  /// invoking `next_result` as many times as the policy requires.
  virtual std::optional<Result> next(ResultProducer next_result) = 0;
};
/// A `Sequencer` that imposes no ordering: each call forwards straight to the
/// producer and returns whatever it yields, in arrival order.
template <typename Result>
struct NoSequencer final : public Sequencer<Result> {
  using ResultProducer = typename Sequencer<Result>::ResultProducer;

  /// @param next_result Producer that is invoked exactly once per call.
  /// @return The producer's result, unchanged (including an empty optional).
  std::optional<Result> next(ResultProducer next_result) override {
    return next_result();
  }
};
/// A `Sequencer` that emits results strictly in order of their
/// `sequence_number`, starting from 0.
///
/// Results that arrive early are stashed in a ring buffer of `max_jobs` slots
/// (slot = sequence_number % max_jobs) until every preceding result has been
/// emitted. This only works if every incoming sequence number lies in
/// [next_sequence_number_, next_sequence_number_ + max_jobs), i.e. at most
/// `max_jobs` results are outstanding at once. `next()` AT_ASSERTs this, so a
/// violation is reported instead of a result silently being emitted out of
/// order.
///
/// NOTE(review): assumes `Result` exposes an integral `sequence_number`
/// member -- confirm against the result type used by callers.
template <typename Result>
struct OrderedSequencer : public Sequencer<Result> {
  using typename Sequencer<Result>::ResultProducer;

  /// @param max_jobs Capacity of the reorder buffer. It must be non-zero for
  ///        `next()` / `buffer()` to be usable (both AT_ASSERT on it).
  explicit OrderedSequencer(size_t max_jobs) : buffer_(max_jobs) {}

  /// Returns the result whose sequence number is `next_sequence_number_`.
  /// It pulls results from `next_result`, stashing early ones, until that
  /// result shows up.
  ///
  /// @param next_result Producer of results in arbitrary order. An empty
  ///        optional marks the end of the epoch.
  /// @return The next in-order result, or std::nullopt once `next_result`
  ///         yields an empty optional.
  std::optional<Result> next(ResultProducer next_result) override {
    // If we already have the result for the next sqn, return it.
    if (auto& maybe_result = buffer(next_sequence_number_)) {
      std::optional<Result> result = std::move(maybe_result);
      // A moved-from optional stays engaged, so clear the slot explicitly.
      maybe_result.reset();
      ++next_sequence_number_;
      return result;
    }
    // Otherwise wait for the next result.
    while (true) {
      auto result = next_result();
      if (!result) {
        // The producer is exhausted. Anything still stashed would be lost.
        AT_ASSERT(!detail::buffer_contains_result(buffer_));
        // The result was an empty optional, so we are done with this epoch.
        return std::nullopt;
      }
      // If the sequence numbers match, return it directly and bump the
      // sequence number.
      if (result->sequence_number == next_sequence_number_) {
        ++next_sequence_number_;
        return result;
      }
      // The result arrived early and must fit in the window the buffer can
      // hold. A stale number, or one `buffer_.size()` or more ahead, would
      // alias another slot modulo the buffer size. It could even land in the
      // empty slot reserved for `next_sequence_number_` and later be emitted
      // in the wrong position.
      AT_ASSERT(
          result->sequence_number > next_sequence_number_ &&
          result->sequence_number - next_sequence_number_ < buffer_.size());
      // Stash the result for later.
      auto& slot = buffer(result->sequence_number);
      AT_ASSERT(!slot.has_value());
      slot = std::move(result);
    }
  }

  /// Returns the ring-buffer slot that holds sequence number `index`.
  std::optional<Result>& buffer(size_t index) {
    // With zero capacity the modulo below would divide by zero (UB).
    AT_ASSERT(!buffer_.empty());
    return buffer_.at(index % buffer_.size());
  }

  size_t next_sequence_number_ = 0;
  std::vector<std::optional<Result>> buffer_;
};
} // namespace torch::data::detail::sequencers