
Program Listing for File tensor_struct.h

Return to documentation for file (torch/csrc/stable/tensor_struct.h)

#pragma once

#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>
#include <torch/headeronly/util/shim_utils.h>
#include <climits>
#include <memory>

#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/device_struct.h>

HIDDEN_NAMESPACE_BEGIN(torch, stable)

using accelerator::DeviceIndex;
using torch::headeronly::IntHeaderOnlyArrayRef;
using torch::headeronly::ScalarType;

// The torch::stable::Tensor class is a high-level C++ wrapper around
// the C shim Tensor APIs. We've modeled this class after TensorBase, as custom
// op kernels only really need to interact with Tensor metadata (think sizes,
// strides, device, dtype). Other functions on Tensor (like empty_like) should
// live as the ATen ops they are and exist outside of this struct.
//
// This class has several goals over AtenTensorHandle and
// RAIIAtenTensorHandle:
// 1. torch::stable::Tensor provides a nicer UX, much closer to torch::Tensor,
//    than the C APIs with AtenTensorHandle. Under the hood we still call these
//    C shim APIs to preserve stability.
// 2. RAIIAtenTensorHandle boils down to a unique_ptr that forces the user to
//    pass around ownership. This makes it difficult to pass one input into 2
//    different functions, e.g., doing something like c = a(t) + b(t) for
//    stable::Tensor t. Thus, we use a shared_ptr here, as illustrated just
//    below.
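//
// Example (illustrative, not part of the class design above): because
// ownership is shared, a stable::Tensor can be passed by value or by const
// reference into several calls with no handle hand-off. Here a and b are
// hypothetical user functions taking a const torch::stable::Tensor&:
//
//   torch::stable::Tensor t = some_factory();  // hypothetical factory
//   auto c = a(t) + b(t);  // both calls observe the same underlying handle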

class Tensor {
 private:
  std::shared_ptr<AtenTensorOpaque> ath_;

 public:
  Tensor() {
    AtenTensorHandle ret;
    TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&ret));
    ath_ = std::shared_ptr<AtenTensorOpaque>(ret, [](AtenTensorHandle ath) {
      TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ath));
    });
  }

  explicit Tensor(AtenTensorHandle ath)
      : ath_(ath, [](AtenTensorHandle ath) {
          TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ath));
        }) {}

  // Copy and move constructors can be defaulted because the underlying handle
  // is a shared_ptr
  Tensor(const Tensor& other) = default;
  Tensor(Tensor&& other) noexcept = default;

  // Copy and move assignment operators can be defaulted because the underlying
  // handle is a shared_ptr
  Tensor& operator=(const Tensor& other) = default;
  Tensor& operator=(Tensor&& other) noexcept = default;

  // Destructor can be defaulted: the shared_ptr carries the custom deletion
  // logic
  ~Tensor() = default;

  AtenTensorHandle get() const {
    return ath_.get();
  }

  // =============================================================================
  // C-shimified TensorBase APIs: the below APIs have the same signatures and
  // semantics as their counterparts in TensorBase.h.
  // =============================================================================

  void* data_ptr() const {
    void* data_ptr;
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(ath_.get(), &data_ptr));
    return data_ptr;
  }

#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
  void* mutable_data_ptr() const {
    void* data_ptr{};
    TORCH_ERROR_CODE_CHECK(torch_get_mutable_data_ptr(ath_.get(), &data_ptr));
    return data_ptr;
  }

  const void* const_data_ptr() const {
    const void* data_ptr{};
    TORCH_ERROR_CODE_CHECK(torch_get_const_data_ptr(ath_.get(), &data_ptr));
    return data_ptr;
  }

  template <typename T>
  T* mutable_data_ptr() const;

  template <typename T, std::enable_if_t<!std::is_const_v<T>, int> = 0>
  const T* const_data_ptr() const;

  const Tensor& set_requires_grad(bool requires_grad) const {
    TORCH_ERROR_CODE_CHECK(torch_set_requires_grad(ath_.get(), requires_grad));
    return *this;
  }
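
  // Illustrative sketch (not part of this header): under this feature gate,
  // the typed accessors above can replace manual casts of data_ptr(). For a
  // stable::Tensor t known to hold floats:
  //
  //   const float* in = t.const_data_ptr<float>();
  //   float* out = t.mutable_data_ptr<float>();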
#endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0

  int64_t dim() const {
    int64_t dim;
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(ath_.get(), &dim));
    return dim;
  }

  int64_t numel() const {
    int64_t numel;
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(ath_.get(), &numel));
    return numel;
  }

  // note: sizes() and strides() are, for all intents and purposes, the same as
  // in TensorBase.h: each returns a borrowed reference to the dimension sizes
  // (or strides) of a Tensor.
  //
  // The only difference is that they return a header-only
  // IntHeaderOnlyArrayRef, which has slightly less functionality than a
  // regular IntArrayRef. See [HeaderOnlyArrayRef vs ArrayRef note] for more
  // details. An illustrative sketch follows strides() below.
  IntHeaderOnlyArrayRef sizes() const {
    int64_t* sizes;
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(ath_.get(), &sizes));
    return IntHeaderOnlyArrayRef(sizes, dim());
  }

  IntHeaderOnlyArrayRef strides() const {
    int64_t* strides;
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(ath_.get(), &strides));
    return IntHeaderOnlyArrayRef(strides, dim());
  }
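
  // Illustrative sketch (not part of this header): sizes() and strides() can
  // be iterated like any ArrayRef-style range, e.g. for a stable::Tensor t:
  //
  //   int64_t n = 1;
  //   for (const auto s : t.sizes()) {
  //     n *= s;
  //   }
  //   // n now equals t.numel()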

  bool is_contiguous() const {
    bool is_contiguous;
    TORCH_ERROR_CODE_CHECK(
        aoti_torch_is_contiguous(ath_.get(), &is_contiguous));
    return is_contiguous;
  }

  int64_t stride(int64_t dim) const {
    int64_t stride;
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_stride(ath_.get(), dim, &stride));
    return stride;
  }

  // This is almost the same API as the one in TensorBase.h, except
  // we add a check that the returned device_index is within the
  // range of int8_t.
  int8_t get_device() const {
    int32_t device_index;
    TORCH_ERROR_CODE_CHECK(
        aoti_torch_get_device_index(ath_.get(), &device_index));
    STD_TORCH_CHECK(
        device_index >= std::numeric_limits<int8_t>::min() &&
            device_index <= std::numeric_limits<int8_t>::max(),
        "Device index is out of range of return type int8_t, please use get_device_index() instead.");
    return static_cast<int8_t>(device_index);
  }

  // The same as get_device, but with two differences:
  // 1. it has a more suitable name
  // 2. it returns a DeviceIndex, which is int32_t in this world and should be
  //    more stable than the likely shifting DeviceIndex in libtorch (currently
  //    int8_t, which might become int16_t)
  DeviceIndex get_device_index() const {
    int32_t device_index;
    TORCH_ERROR_CODE_CHECK(
        aoti_torch_get_device_index(ath_.get(), &device_index));
    return device_index;
  }

  bool is_cuda() const {
    int32_t device_type;
    TORCH_ERROR_CODE_CHECK(
        aoti_torch_get_device_type(ath_.get(), &device_type));
    return device_type == aoti_torch_device_type_cuda();
  }

  bool is_cpu() const {
    int32_t device_type;
    TORCH_ERROR_CODE_CHECK(
        aoti_torch_get_device_type(ath_.get(), &device_type));
    return device_type == aoti_torch_device_type_cpu();
  }

  int64_t size(int64_t dim) const {
    int64_t size;
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(ath_.get(), dim, &size));
    return size;
  }

  bool defined() const {
    bool defined;
    TORCH_ERROR_CODE_CHECK(aoti_torch_is_defined(ath_.get(), &defined));
    return defined;
  }

  int64_t storage_offset() const {
    int64_t storage_offset;
    TORCH_ERROR_CODE_CHECK(
        aoti_torch_get_storage_offset(ath_.get(), &storage_offset));
    return storage_offset;
  }

  size_t element_size() const {
    int32_t dtype;
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(ath_.get(), &dtype));
    return aoti_torch_dtype_element_size(dtype);
  }

  // defined in tensor-inl.h to avoid circular dependencies
  ScalarType scalar_type() const;

  // defined in tensor-inl.h to avoid circular dependencies
  Device device() const;

  // =============================================================================
  // END of C-shimified TensorBase APIs
  // =============================================================================
};

HIDDEN_NAMESPACE_END(torch, stable)
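
The following is a minimal usage sketch, not part of the header above: the metadata-only surface of torch::stable::Tensor is typically all a custom op kernel needs. The helper my_fill_zero below is hypothetical (user-side code, not a library API) and assumes a contiguous CPU tensor; it relies only on accessors declared in this file.

#include <torch/csrc/stable/tensor_struct.h>

#include <cstring>

// Hypothetical user-side helper: zero out a contiguous CPU tensor in place.
void my_fill_zero(const torch::stable::Tensor& t) {
  STD_TORCH_CHECK(t.is_cpu(), "my_fill_zero: expected a CPU tensor");
  STD_TORCH_CHECK(
      t.is_contiguous(), "my_fill_zero: expected a contiguous tensor");
  // numel() * element_size() is the number of bytes backing the tensor data.
  std::memset(t.data_ptr(), 0, t.numel() * t.element_size());
}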