// Program Listing for File map.h
// Return to documentation for file (torch/csrc/api/include/torch/data/datasets/map.h)
#pragma once
#include <torch/data/datasets/base.h>
#include <torch/types.h>
#include <c10/util/ArrayRef.h>
#include <cstddef>
#include <optional>
#include <type_traits>
#include <utility>
namespace torch::data::datasets {
namespace detail {
/// Type selector: yields `std::optional<Value>` when `WrapInOptional` is
/// true, and plain `Value` when it is false.
template <bool WrapInOptional, typename Value>
using optional_if_t =
    typename std::conditional<WrapInOptional, std::optional<Value>, Value>::
        type;
} // namespace detail
/// A dataset adaptor that forwards batch requests to a wrapped source dataset
/// and passes every batch it receives through a transform.
///
/// When `SourceDataset::is_stateful` is true the batch type of this dataset
/// is `std::optional<AppliedTransform::OutputBatchType>`; otherwise it is the
/// plain `AppliedTransform::OutputBatchType`.
template <typename SourceDataset, typename AppliedTransform>
class MapDataset : public BatchDataset<
                       MapDataset<SourceDataset, AppliedTransform>,
                       detail::optional_if_t<
                           SourceDataset::is_stateful,
                           typename AppliedTransform::OutputBatchType>,
                       typename SourceDataset::BatchRequestType> {
 public:
  using DatasetType = SourceDataset;
  using TransformType = AppliedTransform;
  using BatchRequestType = typename SourceDataset::BatchRequestType;
  using OutputBatchType = detail::optional_if_t<
      SourceDataset::is_stateful,
      typename AppliedTransform::OutputBatchType>;

  /// Takes ownership of `dataset` and `transform` by moving them into the
  /// members of this object.
  MapDataset(DatasetType dataset, TransformType transform)
      : dataset_(std::move(dataset)), transform_(std::move(transform)) {}

  /// Requests the batch identified by `indices` from the source dataset and
  /// returns the result of applying the transform to it. Dispatches to the
  /// `get_batch_impl` overload matching `SourceDataset::is_stateful`.
  OutputBatchType get_batch(BatchRequestType indices) override {
    return get_batch_impl(std::move(indices));
  }

  /// Returns whatever `size()` of the source dataset reports.
  // NOLINTNEXTLINE(bugprone-exception-escape)
  std::optional<size_t> size() const noexcept override {
    return dataset_.size();
  }

  /// Forwards to `reset()` of the source dataset. Because this is a member
  /// of a class template, it only has to compile when it is actually used,
  /// i.e. `SourceDataset` must provide `reset()` only in that case.
  void reset() {
    dataset_.reset();
  }

  /// Read-only access to the wrapped source dataset. `const`-qualified so it
  /// is callable on both const and non-const `MapDataset` objects.
  const SourceDataset& dataset() const noexcept {
    return dataset_;
  }

  /// Read-only access to the transform being applied. `const`-qualified so it
  /// is callable on both const and non-const `MapDataset` objects.
  const AppliedTransform& transform() const noexcept {
    return transform_;
  }

 private:
  /// `get_batch()` for a source dataset with `is_stateful == false`: the
  /// batch is handed to the transform unconditionally.
  template <
      typename D = SourceDataset,
      typename = std::enable_if_t<!D::is_stateful>>
  OutputBatchType get_batch_impl(BatchRequestType indices) {
    return transform_.apply_batch(dataset_.get_batch(std::move(indices)));
  }

  /// `get_batch()` for a source dataset with `is_stateful == true`: the
  /// source result is tested for a value first. If it holds one, the
  /// contained batch is moved into the transform; otherwise `std::nullopt`
  /// is returned and the transform is not invoked.
  template <typename D = SourceDataset>
  std::enable_if_t<D::is_stateful, OutputBatchType> get_batch_impl(
      BatchRequestType indices) {
    if (auto batch = dataset_.get_batch(std::move(indices))) {
      return transform_.apply_batch(std::move(*batch));
    }
    return std::nullopt;
  }

  /// The source dataset that batches are fetched from.
  SourceDataset dataset_;
  // The transformation that is applied to batches received from the dataset.
  AppliedTransform transform_;
};
/// Creates a `MapDataset` that applies `transform` to the batches produced by
/// `dataset`. Both arguments are taken by value and moved into the result.
///
/// A compile-time check rejects a transform whose `InputBatchType` differs
/// from the type the dataset delivers:
///  - if `DatasetType::is_stateful` is true, the transform input must equal
///    `DatasetType::BatchType::value_type`;
///  - otherwise it must equal `DatasetType::BatchType` itself.
template <typename DatasetType, typename TransformType>
MapDataset<DatasetType, TransformType> map(
    DatasetType dataset,
    TransformType transform) {
  using TransformInput = typename TransformType::InputBatchType;
  // `if constexpr` instantiates only the branch matching the dataset kind, so
  // a non-stateful dataset is never required to expose a nested
  // `BatchType::value_type` just to get past this check.
  if constexpr (DatasetType::is_stateful) {
    static_assert(
        std::is_same_v<
            typename DatasetType::BatchType::value_type,
            TransformInput>,
        "BatchType type of dataset does not match input type of transform");
  } else {
    static_assert(
        std::is_same_v<typename DatasetType::BatchType, TransformInput>,
        "BatchType type of dataset does not match input type of transform");
  }
  return {std::move(dataset), std::move(transform)};
}
} // namespace torch::data::datasets