// Program listing for file base.h
// (torch/csrc/api/include/torch/data/datasets/base.h)
#pragma once
#include <torch/data/example.h>
#include <torch/types.h>
#include <c10/util/ArrayRef.h>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <utility>
#include <vector>
namespace torch::data::datasets {

/// Forward declaration only; the class template is not defined in this
/// header. It is named here so it can be used as the return type of `map`.
template <typename Source, typename Transform>
class MapDataset;

/// Forward declaration of the free function `map`. Both arguments are taken
/// by value, and the result is a `MapDataset` parameterized on the two
/// argument types.
template <typename Source, typename Transform>
MapDataset<Source, Transform> map(Source, Transform); // NOLINT

} // namespace torch::data::datasets
namespace torch::data::datasets {
namespace detail {

/// Type trait: `value` is `false` for every type that does not match the
/// partial specialization below.
template <typename Candidate>
struct is_optional : std::integral_constant<bool, false> {};

/// Partial specialization: `value` is `true` when the type is exactly
/// `std::optional<Value>` for some `Value`.
template <typename Value>
struct is_optional<std::optional<Value>> : std::integral_constant<bool, true> {
};

} // namespace detail
/// Abstract base class for datasets that hand out whole batches.
///
/// \tparam Self The concrete dataset type. `map` casts `*this` to `Self`, so
///   `Self` is expected to be the class deriving from this base.
/// \tparam Batch The type returned by `get_batch`.
/// \tparam BatchRequest The type of the argument accepted by `get_batch`.
template <
    typename Self,
    typename Batch = std::vector<Example<>>,
    typename BatchRequest = ArrayRef<size_t>>
class BatchDataset {
 public:
  using SelfType = Self;
  using BatchType = Batch;
  using BatchRequestType = BatchRequest;

  /// `true` exactly when `BatchType` is a `std::optional<...>`, as reported
  /// by `detail::is_optional`.
  static constexpr bool is_stateful = detail::is_optional<BatchType>::value;

  virtual ~BatchDataset() = default;

  /// Produces the batch described by `request`. Must be implemented by the
  /// derived class.
  virtual Batch get_batch(BatchRequest request) = 0;

  /// Reports the dataset size as an optional. Must be implemented by the
  /// derived class.
  /// NOTE(review): an empty optional presumably means "size unknown" --
  /// confirm against the implementations.
  virtual std::optional<size_t> size() const = 0;

  /// Lvalue overload: hands `*this` (viewed as `Self`) to `datasets::map` as
  /// an lvalue, so the by-value dataset parameter of `datasets::map` is
  /// copy-initialized from it.
  template <typename TransformType>
  MapDataset<Self, TransformType> map(TransformType transform) & {
    Self& self = static_cast<Self&>(*this);
    return datasets::map(self, std::move(transform));
  }

  /// Rvalue overload: hands `*this` (viewed as `Self`) to `datasets::map` as
  /// an rvalue, so the by-value dataset parameter can be move-initialized.
  template <typename TransformType>
  MapDataset<Self, TransformType> map(TransformType transform) && {
    return datasets::map(static_cast<Self&&>(*this), std::move(transform));
  }
};
/// A `BatchDataset` whose batch type is `std::vector<SingleExample>` and
/// whose batches are assembled one example at a time through `get`.
///
/// \tparam Self The concrete dataset type, forwarded to `BatchDataset`.
/// \tparam SingleExample The type of one example returned by `get`.
template <typename Self, typename SingleExample = Example<>>
class Dataset : public BatchDataset<Self, std::vector<SingleExample>> {
 public:
  using ExampleType = SingleExample;

  /// Returns the example for `index`. Must be implemented by the derived
  /// class.
  virtual ExampleType get(size_t index) = 0;

  /// Builds a batch by calling `get` once per entry of `indices`, in the
  /// order given, and collecting the results into a vector.
  std::vector<ExampleType> get_batch(ArrayRef<size_t> indices) override {
    std::vector<ExampleType> examples;
    // The final element count is known up front, so allocate once.
    examples.reserve(indices.size());
    for (auto it = indices.begin(), last = indices.end(); it != last; ++it) {
      examples.push_back(this->get(*it));
    }
    return examples;
  }
};
/// Alias for a `BatchDataset` whose batch request type is a single `size_t`
/// rather than the default `ArrayRef<size_t>`; `Self` and `Batch` are
/// forwarded unchanged.
/// NOTE(review): the `size_t` request is presumably the number of examples to
/// produce -- confirm against the stream dataset implementations.
template <typename Self, typename Batch = std::vector<Example<>>>
using StreamDataset = BatchDataset<Self, Batch, /*BatchRequest=*/size_t>;
} // namespace torch::data::datasets