Rate this Page

Program Listing for File data_shuttle.h

Return to documentation for file (torch/csrc/api/include/torch/data/detail/data_shuttle.h)

#pragma once

#include <torch/data/detail/queue.h>
#include <torch/types.h>

#include <c10/util/Exception.h>
#include <optional>

#include <chrono>
#include <utility>

namespace torch::data::detail {

template <typename Job, typename Result>
class DataShuttle {
 public:
  void push_job(Job job) {
    new_jobs_.push(std::move(job));
    ++in_flight_jobs_;
  }

  void push_result(Result result) {
    results_.push(std::move(result));
  }

  Job pop_job() {
    return new_jobs_.pop();
  }

  std::optional<Result> pop_result(
      std::optional<std::chrono::milliseconds> timeout = std::nullopt) {
    if (in_flight_jobs_ > 0) {
      auto result = results_.pop(timeout);
      --in_flight_jobs_;
      return result;
    }
    return std::nullopt;
  }

  void drain() {
    // Clear all inputs so that no further jobs are scheduled.
    auto number_cleared = new_jobs_.clear();
    in_flight_jobs_ -= number_cleared;
    // Remove any outstanding results.
    while (in_flight_jobs_ > 0) {
      pop_result();
    }
  }

  size_t in_flight_jobs() const noexcept {
    return in_flight_jobs_;
  }

 private:
  Queue<Job> new_jobs_;
  size_t in_flight_jobs_ = 0;
  Queue<Result> results_;
};

} // namespace torch::data::detail