Rate this Page

Program Listing for File mnist.h

Return to documentation for file (torch/csrc/api/include/torch/data/datasets/mnist.h)

#pragma once

#include <torch/data/datasets/base.h>
#include <torch/data/example.h>
#include <torch/types.h>

#include <torch/csrc/Export.h>

#include <cstddef>
#include <optional>
#include <string>

namespace torch::data::datasets {
/// Dataset of MNIST images and their targets.
///
/// Derives from `Dataset<MNIST>` (CRTP), so it is consumed through the
/// generic `get()` / `size()` dataset interface. The loaded data is held in
/// two tensors, `images_` and `targets_`.
class TORCH_API MNIST : public Dataset<MNIST> {
 public:
  /// Selects which split of MNIST an instance represents.
  enum class Mode { kTrain, kTest };

  /// Constructs the dataset from files located under `root`.
  ///
  /// @param root Directory to read from. Presumably it must contain the
  ///             unpacked MNIST data files. NOTE(review): confirm the
  ///             expected file names/layout against the implementation.
  /// @param mode Split to load; defaults to `Mode::kTrain`.
  explicit MNIST(const std::string& root, Mode mode = Mode::kTrain);

  /// Returns the example at position `index`. It presumably pairs the image
  /// with its target. NOTE(review): out-of-range behavior is defined in the
  /// implementation file — confirm before relying on it.
  [[nodiscard]] Example<> get(size_t index) override;

  /// Returns the number of examples in this dataset.
  [[nodiscard]] std::optional<size_t> size() const override;

  /// Returns whether this instance represents the training split
  /// (presumably: constructed with `Mode::kTrain`).
  // NOLINTNEXTLINE(bugprone-exception-escape)
  [[nodiscard]] bool is_train() const noexcept;

  /// Returns a read-only reference to the stored image tensor.
  /// NOTE(review): shape/dtype are set by the loader — confirm in the .cpp.
  [[nodiscard]] const Tensor& images() const;

  /// Returns a read-only reference to the stored target tensor.
  /// NOTE(review): shape/dtype are set by the loader — confirm in the .cpp.
  [[nodiscard]] const Tensor& targets() const;

 private:
  // One declaration per line; the declaration order is unchanged from the
  // original, so the object layout is unchanged.
  Tensor images_;
  Tensor targets_;
};
} // namespace torch::data::datasets