Rate this Page

Program Listing for File linear.h

Return to documentation for file (torch/csrc/api/include/torch/nn/options/linear.h)

#pragma once

#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <torch/types.h>

namespace torch::nn {

/// Option bundle for a `Linear` layer (named after the module it configures;
/// that module is not declared in this header).
///
/// The only user-declared constructor takes both feature counts; `bias` is
/// initialized to `true`.
///
/// NOTE(review): `TORCH_ARG` is defined in <torch/arg.h>, which is not shown
/// here. It presumably declares the stored field plus same-named accessors --
/// confirm there before relying on the exact generated interface.
struct TORCH_API LinearOptions {
  /// Declared here only; the definition lives outside this header.
  LinearOptions(int64_t in_features, int64_t out_features);
  /// Presumably the size of each input sample -- TODO confirm against the
  /// `Linear` module.
  TORCH_ARG(int64_t, in_features);

  /// Presumably the size of each output sample -- TODO confirm against the
  /// `Linear` module.
  TORCH_ARG(int64_t, out_features);

  /// Bias flag, initialized to `true`. Presumably controls whether the layer
  /// carries an additive bias term -- TODO confirm.
  TORCH_ARG(bool, bias) = true;
};

// ============================================================================

/// Option bundle for a `Flatten` module (named after the module it configures;
/// that module is not declared in this header).
///
/// No constructor is declared, so an instance starts with the initializers
/// below: `start_dim` = 1 and `end_dim` = -1.
///
/// NOTE(review): `TORCH_ARG` is defined in <torch/arg.h>, which is not shown
/// here; see that header for the interface it generates.
struct TORCH_API FlattenOptions {
  /// Initialized to 1. Presumably the first dimension to flatten -- TODO
  /// confirm against the `Flatten` module.
  TORCH_ARG(int64_t, start_dim) = 1;
  /// Initialized to -1. Presumably the last dimension to flatten, with -1
  /// meaning the final dimension -- TODO confirm.
  TORCH_ARG(int64_t, end_dim) = -1;
};

// ============================================================================

struct TORCH_API UnflattenOptions {
  typedef std::vector<std::pair<std::string, int64_t>> namedshape_t;

  UnflattenOptions(int64_t dim, std::vector<int64_t> sizes);
  UnflattenOptions(const char* dimname, namedshape_t namedshape);
  UnflattenOptions(std::string dimname, namedshape_t namedshape);

  TORCH_ARG(int64_t, dim);
  TORCH_ARG(std::string, dimname);
  TORCH_ARG(std::vector<int64_t>, sizes);
  TORCH_ARG(namedshape_t, namedshape);
};

// ============================================================================

/// Option bundle for a `Bilinear` module (named after the module it
/// configures; that module is not declared in this header).
///
/// The only user-declared constructor takes the two input feature counts and
/// the output feature count; `bias` is initialized to `true`.
///
/// NOTE(review): `TORCH_ARG` is defined in <torch/arg.h>, which is not shown
/// here; see that header for the interface it generates.
struct TORCH_API BilinearOptions {
  /// Declared here only; the definition lives outside this header.
  BilinearOptions(
      int64_t in1_features,
      int64_t in2_features,
      int64_t out_features);
  /// Presumably the number of features of the first input -- TODO confirm.
  TORCH_ARG(int64_t, in1_features);
  /// Presumably the number of features of the second input -- TODO confirm.
  TORCH_ARG(int64_t, in2_features);
  /// Presumably the number of output features -- TODO confirm.
  TORCH_ARG(int64_t, out_features);
  /// Bias flag, initialized to `true`. Presumably controls whether a bias
  /// term is learned and added -- TODO confirm.
  TORCH_ARG(bool, bias) = true;
};

} // namespace torch::nn