Rate this Page

Program Listing for File distributed.h

Return to documentation for file (torch/csrc/api/include/torch/data/samplers/distributed.h)

#pragma once

#include <torch/csrc/Export.h>
#include <torch/data/samplers/base.h>

#include <cstddef>
#include <vector>

// Forward declarations of the archive types. In this header they are used
// only as reference parameters of the `save()` / `load()` declarations, for
// which an incomplete type is sufficient.
namespace torch::serialize {
class InputArchive;
class OutputArchive;
} // namespace torch::serialize

namespace torch::data::samplers {

/// Common state for samplers that divide `size` samples among
/// `num_replicas` replicas, of which this instance is replica number `rank`.
///
/// The class stores the partitioning parameters and the current epoch, and
/// offers `local_sample_count()` to subclasses. It does not itself declare
/// `reset()`, `next()`, `save()` or `load()`; the concrete samplers in this
/// file override those.
///
/// Precondition (not checked here): `num_replicas` must be non-zero, because
/// `local_sample_count()` divides by it.
template <typename BatchRequest = std::vector<size_t>>
class DistributedSampler : public Sampler<BatchRequest> {
 public:
  /// @param size             Total number of samples to divide among the
  ///                         replicas.
  /// @param num_replicas     Number of replicas; must be non-zero.
  /// @param rank             Identifier of this replica. NOTE(review):
  ///                         presumably expected to lie in
  ///                         `[0, num_replicas)` -- not validated here.
  /// @param allow_duplicates Rounding mode of `local_sample_count()`: round
  ///                         up when `true`, round down when `false`.
  DistributedSampler(
      size_t size,
      size_t num_replicas = 1,
      size_t rank = 0,
      bool allow_duplicates = true)
      : size_(size),
        num_replicas_(num_replicas),
        rank_(rank),
        allow_duplicates_(allow_duplicates) {}

  /// Stores `epoch` as the current epoch.
  void set_epoch(size_t epoch) {
    epoch_ = epoch;
  }

  /// Returns the epoch most recently stored by `set_epoch()`; 0 until then.
  size_t epoch() const {
    return epoch_;
  }

 protected:
  /// Returns `size_ / num_replicas_`, rounded up when `allow_duplicates_` is
  /// set and rounded down otherwise.
  ///
  /// The round-up case is computed as quotient plus a remainder check rather
  /// than `(size_ + num_replicas_ - 1) / num_replicas_`, because that sum
  /// wraps around for `size_` close to the maximum of `size_t` and would then
  /// yield a far too small count. The increment below cannot wrap: the
  /// quotient only reaches the maximum when `num_replicas_ == 1`, in which
  /// case the remainder is zero.
  size_t local_sample_count() const {
    const size_t per_replica = size_ / num_replicas_;
    if (allow_duplicates_ && size_ % num_replicas_ != 0) {
      return per_replica + 1;
    }
    return per_replica;
  }

  // Total number of samples divided among the replicas.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  size_t size_;
  // Number of replicas; divisor in local_sample_count().
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  size_t num_replicas_;
  // Identifier of this replica, as passed to the constructor.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  size_t rank_;
  // Current epoch; written by set_epoch(), read by epoch().
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  size_t epoch_{0};
  // Rounding mode for local_sample_count(): true rounds up, false rounds down.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  bool allow_duplicates_;
};

/// Concrete `DistributedSampler<>` (i.e. with the default batch request type
/// `std::vector<size_t>`), exported with `TORCH_API`.
///
/// Only declarations are visible in this header; every member function is
/// defined out of line. NOTE(review): judging by the class name, indices are
/// presumably handed out in a shuffled order -- confirm against the
/// implementation.
class TORCH_API DistributedRandomSampler : public DistributedSampler<> {
 public:
  /// Takes the same parameters, with the same defaults, as the
  /// `DistributedSampler` constructor.
  DistributedRandomSampler(
      size_t size,
      size_t num_replicas = 1,
      size_t rank = 0,
      bool allow_duplicates = true);

  /// Override of the base-class `reset()` hook. `new_size` is optional and
  /// defaults to `std::nullopt`. NOTE(review): presumably a supplied value
  /// replaces the stored sample count -- confirm against the implementation.
  void reset(std::optional<size_t> new_size = std::nullopt) override;

  /// Override of the base-class `next()` hook: takes the requested
  /// `batch_size` and returns an optional vector of indices.
  /// NOTE(review): an empty optional presumably signals that the sampler is
  /// exhausted -- confirm against the implementation.
  std::optional<std::vector<size_t>> next(size_t batch_size) override;

  /// Override of the base-class `save()` hook; `const`, so it does not modify
  /// the sampler. NOTE(review): presumably writes the sampler's progress into
  /// `archive` -- confirm against the implementation.
  void save(serialize::OutputArchive& archive) const override;

  /// Override of the base-class `load()` hook, the counterpart of `save()`.
  /// NOTE(review): presumably restores the state written by `save()` from
  /// `archive` -- confirm against the implementation.
  void load(serialize::InputArchive& archive) override;

  /// Non-throwing, `const` accessor returning an index.
  /// NOTE(review): presumably the current value of `sample_index_` --
  /// confirm against the implementation.
  size_t index() const noexcept;

 private:
  /// Private helper defined out of line. NOTE(review): presumably (re)builds
  /// `all_indices_` -- confirm against the implementation.
  void populate_indices();

  // NOTE(review): the roles of the fields below are not visible in this
  // header. By name, begin_index_ / end_index_ presumably delimit this
  // replica's range, sample_index_ the current position within it, and
  // all_indices_ the index sequence being sampled -- confirm against the
  // implementation.
  size_t begin_index_{0};
  size_t end_index_{0};
  size_t sample_index_{0};
  std::vector<size_t> all_indices_;
};

/// Concrete `DistributedSampler<>` (i.e. with the default batch request type
/// `std::vector<size_t>`), exported with `TORCH_API`.
///
/// Only declarations are visible in this header; every member function is
/// defined out of line. NOTE(review): judging by the class name, indices are
/// presumably handed out in ascending order -- confirm against the
/// implementation.
class TORCH_API DistributedSequentialSampler : public DistributedSampler<> {
 public:
  /// Takes the same parameters, with the same defaults, as the
  /// `DistributedSampler` constructor.
  DistributedSequentialSampler(
      size_t size,
      size_t num_replicas = 1,
      size_t rank = 0,
      bool allow_duplicates = true);

  /// Override of the base-class `reset()` hook. `new_size` is optional and
  /// defaults to `std::nullopt`. NOTE(review): presumably a supplied value
  /// replaces the stored sample count -- confirm against the implementation.
  void reset(std::optional<size_t> new_size = std::nullopt) override;

  /// Override of the base-class `next()` hook: takes the requested
  /// `batch_size` and returns an optional vector of indices.
  /// NOTE(review): an empty optional presumably signals that the sampler is
  /// exhausted -- confirm against the implementation.
  std::optional<std::vector<size_t>> next(size_t batch_size) override;

  /// Override of the base-class `save()` hook; `const`, so it does not modify
  /// the sampler. NOTE(review): presumably writes the sampler's progress into
  /// `archive` -- confirm against the implementation.
  void save(serialize::OutputArchive& archive) const override;

  /// Override of the base-class `load()` hook, the counterpart of `save()`.
  /// NOTE(review): presumably restores the state written by `save()` from
  /// `archive` -- confirm against the implementation.
  void load(serialize::InputArchive& archive) override;

  /// Non-throwing, `const` accessor returning an index.
  /// NOTE(review): presumably the current value of `sample_index_` --
  /// confirm against the implementation.
  size_t index() const noexcept;

 private:
  /// Private helper defined out of line. NOTE(review): presumably (re)builds
  /// `all_indices_` -- confirm against the implementation.
  void populate_indices();

  // NOTE(review): the roles of the fields below are not visible in this
  // header. By name, begin_index_ / end_index_ presumably delimit this
  // replica's range, sample_index_ the current position within it, and
  // all_indices_ the index sequence being sampled -- confirm against the
  // implementation.
  size_t begin_index_{0};
  size_t end_index_{0};
  size_t sample_index_{0};
  std::vector<size_t> all_indices_;
};

} // namespace torch::data::samplers