Rate this Page

Program Listing for File distance.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/distance.h)

#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/functional/distance.h>
#include <torch/nn/options/distance.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <torch/csrc/Export.h>

namespace torch::nn {

/// Module form of cosine similarity between two tensors.
///
/// NOTE(review): the computation is defined out of line and is not visible in
/// this header. Presumably `forward` delegates to
/// `torch::nn::functional::cosine_similarity`, whose header is included above.
/// Confirm in the corresponding .cpp.
class TORCH_API CosineSimilarityImpl : public Cloneable<CosineSimilarityImpl> {
 public:
  /// Constructs the module from `options_`. When the argument is omitted, a
  /// default-constructed `CosineSimilarityOptions` is used.
  explicit CosineSimilarityImpl(const CosineSimilarityOptions& options_ = {});

  /// Overrides the base-class `reset()` from the `Cloneable` hierarchy.
  /// The only data member declared here is `options`; no tensors are declared
  /// in this header. Presumably there is nothing to (re)initialize here.
  /// TODO confirm in the .cpp.
  void reset() override;

  /// Overrides the base-class `pretty_print`. Presumably writes a
  /// human-readable description of this module (e.g. its options) to
  /// `stream`. The exact format is defined in the .cpp.
  void pretty_print(std::ostream& stream) const override;

  /// Applies the module to the pair (`input1`, `input2`) and returns the
  /// resulting tensor. Accepted shapes and broadcasting rules are not stated
  /// in this header. See the functional API documentation.
  Tensor forward(const Tensor& input1, const Tensor& input2);

  /// Configuration for this module. This is presumably initialized from the
  /// constructor's `options_` argument. It is public so callers can inspect
  /// or adjust it.
  CosineSimilarityOptions options;
};

/// Declares `CosineSimilarity`, the module holder for `CosineSimilarityImpl`,
/// via the `TORCH_MODULE` macro (see torch/nn/pimpl.h for its expansion).
TORCH_MODULE(CosineSimilarity);

// ============================================================================

/// Module form of the pairwise distance between two tensors.
///
/// NOTE(review): the computation is defined out of line and is not visible in
/// this header. Presumably `forward` delegates to
/// `torch::nn::functional::pairwise_distance`, whose header is included above.
/// Confirm in the corresponding .cpp.
class TORCH_API PairwiseDistanceImpl : public Cloneable<PairwiseDistanceImpl> {
 public:
  /// Constructs the module from `options_`. When the argument is omitted, a
  /// default-constructed `PairwiseDistanceOptions` is used.
  explicit PairwiseDistanceImpl(const PairwiseDistanceOptions& options_ = {});

  /// Overrides the base-class `reset()` from the `Cloneable` hierarchy.
  /// The only data member declared here is `options`; no tensors are declared
  /// in this header. Presumably there is nothing to (re)initialize here.
  /// TODO confirm in the .cpp.
  void reset() override;

  /// Overrides the base-class `pretty_print`. Presumably writes a
  /// human-readable description of this module (e.g. its options) to
  /// `stream`. The exact format is defined in the .cpp.
  void pretty_print(std::ostream& stream) const override;

  /// Applies the module to the pair (`input1`, `input2`) and returns the
  /// resulting tensor. Accepted shapes and broadcasting rules are not stated
  /// in this header. See the functional API documentation.
  Tensor forward(const Tensor& input1, const Tensor& input2);

  /// Configuration for this module. This is presumably initialized from the
  /// constructor's `options_` argument. It is public so callers can inspect
  /// or adjust it.
  PairwiseDistanceOptions options;
};

/// Declares `PairwiseDistance`, the module holder for `PairwiseDistanceImpl`,
/// via the `TORCH_MODULE` macro (see torch/nn/pimpl.h for its expansion).
TORCH_MODULE(PairwiseDistance);

} // namespace torch::nn