Program Listing for File data_shuttle.h
↰ Return to documentation for file (torch/csrc/api/include/torch/data/detail/data_shuttle.h)
#pragma once
#include <torch/data/detail/queue.h>
#include <torch/types.h>
#include <c10/util/Exception.h>
#include <optional>
#include <chrono>
#include <utility>
namespace torch::data::detail {
/// Pairs a queue of pending jobs with a queue of finished results and keeps a
/// running count of jobs that have been pushed but whose result has not yet
/// been popped (or that have not been cleared from the job queue).
///
/// NOTE(review): `in_flight_jobs_` is a plain `size_t` with no synchronization
/// inside this class; presumably `push_job`, `pop_result` and `drain` are all
/// called from a single thread -- confirm against the callers.
template <typename Job, typename Result>
class DataShuttle {
 public:
  /// Enqueues `job` on the job queue and then counts it as in flight.
  /// The count is bumped only after the push has returned.
  void push_job(Job job) {
    new_jobs_.push(std::move(job));
    in_flight_jobs_ += 1;
  }

  /// Enqueues `result` on the result queue. The in-flight count is not
  /// modified here; it drops only in `pop_result` and `drain`.
  void push_result(Result result) {
    results_.push(std::move(result));
  }

  /// Removes and returns the next job from the job queue. Any waiting
  /// behavior is whatever `Queue::pop` provides (not visible in this file).
  Job pop_job() {
    return new_jobs_.pop();
  }

  /// Returns `std::nullopt` without touching the result queue when no jobs are
  /// in flight. Otherwise pops one result, forwarding `timeout` to
  /// `Queue::pop`, decrements the in-flight count and returns that result.
  ///
  /// The decrement happens only after the pop has returned, so if the pop
  /// does not complete normally the count is left unchanged.
  /// NOTE(review): what `Queue::pop` does when `timeout` elapses is defined in
  /// queue.h -- confirm there.
  std::optional<Result> pop_result(
      std::optional<std::chrono::milliseconds> timeout = std::nullopt) {
    if (in_flight_jobs_ == 0) {
      return std::nullopt;
    }
    auto result = results_.pop(timeout);
    --in_flight_jobs_;
    return result;
  }

  /// Clears the job queue, subtracts the number of cleared jobs from the
  /// in-flight count, and then pops (and discards) results until the count
  /// reaches zero.
  void drain() {
    // Jobs removed here will never produce a result, so they stop counting as
    // in flight right away.
    in_flight_jobs_ -= new_jobs_.clear();
    // Every remaining in-flight job still owes one result; consume them all.
    // `pop_result` is called with its default (`std::nullopt`) timeout.
    while (in_flight_jobs_ != 0) {
      pop_result();
    }
  }

  /// Current value of the in-flight job counter.
  size_t in_flight_jobs() const noexcept {
    return in_flight_jobs_;
  }

 private:
  /// Jobs that have been pushed and not yet popped or cleared.
  Queue<Job> new_jobs_;
  /// Incremented by `push_job`; decremented by `pop_result` and `drain`.
  size_t in_flight_jobs_{0};
  /// Results that have been pushed and not yet popped.
  Queue<Result> results_;
};
} // namespace torch::data::detail