Rate this Page

Program Listing for File sequencers.h

Return to documentation for file (torch/csrc/api/include/torch/data/detail/sequencers.h)

#pragma once

#include <torch/types.h>

#include <algorithm>
#include <cstddef>
#include <vector>

namespace torch::data::detail::sequencers {
namespace detail {
/// Reports whether at least one slot of `buffer` currently holds a value.
/// Returns `false` for an empty buffer.
template <typename Result>
bool buffer_contains_result(const std::vector<std::optional<Result>>& buffer) {
  for (const auto& slot : buffer) {
    if (slot.has_value()) {
      return true;
    }
  }
  return false;
}
} // namespace detail

/// Abstract interface for objects that hand out results obtained from a
/// producer callback. Concrete sequencers decide how (and in which order) the
/// produced results are returned to the caller.
template <typename Result>
struct Sequencer {
  /// Callable that yields the next produced result as an
  /// `std::optional<Result>`.
  using ResultProducer = std::function<std::optional<Result>()>;

  virtual ~Sequencer() = default;

  /// Returns the next result according to the sequencer's policy, using
  /// `next_result` to obtain freshly produced results as needed.
  virtual std::optional<Result> next(ResultProducer next_result) = 0;
};

/// A pass-through `Sequencer`: it applies no ordering and simply returns
/// whatever the producer yields.
template <typename Result>
struct NoSequencer final : public Sequencer<Result> {
  using ResultProducer = typename Sequencer<Result>::ResultProducer;

  /// Invokes `next_result` once and forwards its return value unchanged.
  std::optional<Result> next(ResultProducer next_result) override {
    return next_result();
  }
};

/// A `Sequencer` that returns results in increasing `sequence_number` order,
/// starting at sequence number 0.
///
/// Results that arrive ahead of their turn are parked in a fixed-size buffer
/// of `max_jobs` slots (slot = `sequence_number % max_jobs`) until the result
/// carrying the expected sequence number has been returned. `Result` must
/// expose a `sequence_number` member.
template <typename Result>
struct OrderedSequencer : public Sequencer<Result> {
  using typename Sequencer<Result>::ResultProducer;

  /// Creates a sequencer with `max_jobs` buffer slots. A value of 0 is
  /// accepted here, but any later access to the buffer (and therefore any
  /// call to `next()`) will then fail an assertion.
  explicit OrderedSequencer(size_t max_jobs) : buffer_(max_jobs) {}

  /// Returns the result whose sequence number is `next_sequence_number_`,
  /// pulling from `next_result` and stashing out-of-order results until it
  /// shows up. Returns `std::nullopt` once `next_result` yields an empty
  /// optional.
  std::optional<Result> next(ResultProducer next_result) override {
    // If the result for the expected sequence number was stashed earlier,
    // hand it out, free its slot and advance.
    if (auto& maybe_result = buffer(next_sequence_number_)) {
      auto result = std::move(*maybe_result);
      maybe_result.reset();
      ++next_sequence_number_;
      return result;
    }
    // Otherwise wait for the next result.
    while (true) {
      auto result = next_result();
      if (!result) {
        // The producer is exhausted; nothing may be left behind in the buffer.
        AT_ASSERT(!detail::buffer_contains_result(buffer_));
        break;
      }
      // If it was not nullopt and the sequence numbers match, return it
      // directly and bump the sequence number.
      if (result->sequence_number == next_sequence_number_) {
        ++next_sequence_number_;
        return result;
      }
      // Stash the result for later. The target slot must be free, otherwise
      // a still-pending result would be overwritten.
      AT_ASSERT(!buffer(result->sequence_number).has_value());
      buffer(result->sequence_number) = std::move(result);
    }
    // The result was an empty optional, so we are done with this epoch.
    return std::nullopt;
  }

  /// Accesses the buffer slot for `index`, wrapping around the buffer size.
  std::optional<Result>& buffer(size_t index) {
    // `index % 0` is undefined behavior, so fail loudly when the sequencer
    // was constructed with `max_jobs == 0` instead of dividing by zero.
    AT_ASSERT(
        !buffer_.empty(),
        "OrderedSequencer requires max_jobs > 0 to buffer results");
    return buffer_.at(index % buffer_.size());
  }

  /// The sequence number of the result that `next()` must return next.
  size_t next_sequence_number_ = 0;

  /// Slots for results that arrived before their turn.
  std::vector<std::optional<Result>> buffer_;
};
} // namespace torch::data::detail::sequencers