Program Listing for File embedding.h
↰ Return to documentation for file (torch/csrc/api/include/torch/nn/options/embedding.h)
#pragma once
#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <torch/enum.h>
#include <torch/types.h>
namespace torch::nn {
struct TORCH_API EmbeddingOptions {
EmbeddingOptions(int64_t num_embeddings, int64_t embedding_dim);
TORCH_ARG(int64_t, num_embeddings);
TORCH_ARG(int64_t, embedding_dim);
TORCH_ARG(std::optional<int64_t>, padding_idx) = std::nullopt;
TORCH_ARG(std::optional<double>, max_norm) = std::nullopt;
TORCH_ARG(double, norm_type) = 2.;
TORCH_ARG(bool, scale_grad_by_freq) = false;
TORCH_ARG(bool, sparse) = false;
TORCH_ARG(torch::Tensor, _weight);
};
// ============================================================================
struct TORCH_API EmbeddingFromPretrainedOptions {
TORCH_ARG(bool, freeze) = true;
TORCH_ARG(std::optional<int64_t>, padding_idx) = std::nullopt;
TORCH_ARG(std::optional<double>, max_norm) = std::nullopt;
TORCH_ARG(double, norm_type) = 2.;
TORCH_ARG(bool, scale_grad_by_freq) = false;
TORCH_ARG(bool, sparse) = false;
};
// ============================================================================
namespace functional {
struct TORCH_API EmbeddingFuncOptions {
TORCH_ARG(std::optional<int64_t>, padding_idx) = std::nullopt;
TORCH_ARG(std::optional<double>, max_norm) = std::nullopt;
TORCH_ARG(double, norm_type) = 2.;
TORCH_ARG(bool, scale_grad_by_freq) = false;
TORCH_ARG(bool, sparse) = false;
};
} // namespace functional
// ============================================================================
/// Type of the `mode` option in the EmbeddingBag option structs: a variant
/// holding exactly one of the `kSum`, `kMean` or `kMax` enum tag types.
/// Written as an alias-declaration (`using`) rather than `typedef`; the
/// aliased type and its name are unchanged.
using EmbeddingBagMode =
    std::variant<enumtype::kSum, enumtype::kMean, enumtype::kMax>;
struct TORCH_API EmbeddingBagOptions {
EmbeddingBagOptions(int64_t num_embeddings, int64_t embedding_dim);
TORCH_ARG(int64_t, num_embeddings);
TORCH_ARG(int64_t, embedding_dim);
TORCH_ARG(std::optional<double>, max_norm) = std::nullopt;
TORCH_ARG(double, norm_type) = 2.;
TORCH_ARG(bool, scale_grad_by_freq) = false;
TORCH_ARG(EmbeddingBagMode, mode) = torch::kMean;
TORCH_ARG(bool, sparse) = false;
TORCH_ARG(torch::Tensor, _weight);
TORCH_ARG(bool, include_last_offset) = false;
TORCH_ARG(std::optional<int64_t>, padding_idx) = std::nullopt;
};
// ============================================================================
struct TORCH_API EmbeddingBagFromPretrainedOptions {
TORCH_ARG(bool, freeze) = true;
TORCH_ARG(std::optional<double>, max_norm) = std::nullopt;
TORCH_ARG(double, norm_type) = 2.;
TORCH_ARG(bool, scale_grad_by_freq) = false;
TORCH_ARG(EmbeddingBagMode, mode) = torch::kMean;
TORCH_ARG(bool, sparse) = false;
TORCH_ARG(bool, include_last_offset) = false;
TORCH_ARG(std::optional<int64_t>, padding_idx) = std::nullopt;
};
// ============================================================================
namespace functional {
struct TORCH_API EmbeddingBagFuncOptions {
TORCH_ARG(torch::Tensor, offsets);
TORCH_ARG(std::optional<double>, max_norm) = std::nullopt;
TORCH_ARG(double, norm_type) = 2.;
TORCH_ARG(bool, scale_grad_by_freq) = false;
TORCH_ARG(EmbeddingBagMode, mode) = torch::kMean;
TORCH_ARG(bool, sparse) = false;
TORCH_ARG(torch::Tensor, per_sample_weights);
TORCH_ARG(bool, include_last_offset) = false;
TORCH_ARG(std::optional<int64_t>, padding_idx) = std::nullopt;
};
} // namespace functional
} // namespace torch::nn