Torch Stable API

The PyTorch Stable C++ API provides a convenient, high-level interface for calling ABI-stable tensor operations and other utilities commonly used in custom operators. These functions are designed to maintain binary compatibility across PyTorch versions, making them suitable for use in ahead-of-time compiled code.

For more information on the stable ABI, see the Stable ABI notes.

Library Registration Macros

These macros provide stable ABI equivalents of the standard PyTorch operator registration macros (TORCH_LIBRARY, TORCH_LIBRARY_IMPL, etc.). Use these when building custom operators that need to maintain binary compatibility across PyTorch versions.

STABLE_TORCH_LIBRARY(ns, m)

Defines a library of operators in a namespace using the stable ABI.

This is the stable ABI equivalent of TORCH_LIBRARY. Use this macro to define operator schemas that will maintain binary compatibility across PyTorch versions. Only one STABLE_TORCH_LIBRARY block can exist per namespace; use STABLE_TORCH_LIBRARY_FRAGMENT for additional definitions in the same namespace from different translation units.

Parameters:

  • ns - The namespace in which to define operators (e.g., mylib).

  • m - The name of the StableLibrary variable available in the block.

Example:

STABLE_TORCH_LIBRARY(mylib, m) {
    m.def("my_op(Tensor input, int size) -> Tensor");
    m.def("another_op(Tensor a, Tensor b) -> Tensor");
}

Minimum compatible version: PyTorch 2.9.

STABLE_TORCH_LIBRARY_IMPL(ns, k, m)

Registers operator implementations for a specific dispatch key using the stable ABI.

This is the stable ABI equivalent of TORCH_LIBRARY_IMPL. Use this macro to provide implementations of operators for a specific dispatch key (e.g., CPU, CUDA) while maintaining binary compatibility across PyTorch versions.

Note

All kernel functions registered with this macro must be boxed using the TORCH_BOX macro.

Parameters:

  • ns - The namespace in which the operators are defined.

  • k - The dispatch key (e.g., CPU, CUDA).

  • m - The name of the StableLibrary variable available in the block.

Example:

STABLE_TORCH_LIBRARY_IMPL(mylib, CPU, m) {
    m.impl("my_op", TORCH_BOX(&my_cpu_kernel));
}

STABLE_TORCH_LIBRARY_IMPL(mylib, CUDA, m) {
    m.impl("my_op", TORCH_BOX(&my_cuda_kernel));
}

Minimum compatible version: PyTorch 2.9.

STABLE_TORCH_LIBRARY_FRAGMENT(ns, m)

Extends operator definitions in an existing namespace using the stable ABI.

This is the stable ABI equivalent of TORCH_LIBRARY_FRAGMENT. Use this macro to add operator definitions to a namespace that was already created with STABLE_TORCH_LIBRARY, for example from a different translation unit.

Parameters:

  • ns - The namespace to extend.

  • m - The name of the StableLibrary variable available in the block.

Minimum compatible version: PyTorch 2.9.

TORCH_BOX(&func)

Wraps a function to conform to the stable boxed kernel calling convention.

This macro takes an unboxed kernel function pointer and generates a boxed wrapper that can be registered with the stable library API.

Parameters:

  • func - The unboxed kernel function to wrap.

Example:

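#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>

using torch::stable::Tensor;

// stable::Tensor has no reshape() member; use the free function instead.
Tensor my_kernel(const Tensor& input, int64_t size) {
    return torch::stable::reshape(input, {size});
}

STABLE_TORCH_LIBRARY_IMPL(my_namespace, CPU, m) {
    m.impl("my_op", TORCH_BOX(&my_kernel));
}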

Minimum compatible version: PyTorch 2.9.
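
The following is a minimal, self-contained sketch in plain C++ that makes the idea of boxing concrete. It does not use PyTorch.

- The stack of BoxedValue entries (a std::variant) and the boxed_call helper are hypothetical.
- The real stable ABI uses its own stack representation, and TORCH_BOX generates the equivalent wrapper for you.

The sketch shows only the general technique. The wrapper reads each argument off a generic stack using the type the kernel expects. It then calls the unboxed function and replaces the arguments with the result.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical boxed value and stack types (not the stable ABI's own types).
using BoxedValue = std::variant<int64_t, double, bool>;
using Stack = std::vector<BoxedValue>;

// Extract stack[I] as the I-th parameter type of the kernel and call it.
// std::get throws std::bad_variant_access if a stack entry has the wrong type.
template <typename Ret, typename... Args, std::size_t... I>
Ret call_unboxed(Ret (*fn)(Args...), const Stack& stack,
                 std::index_sequence<I...>) {
  return fn(std::get<std::decay_t<Args>>(stack[I])...);
}

// Boxed wrapper: consumes the arguments on the stack and replaces them
// with the kernel's single (non-void) return value.
template <typename Ret, typename... Args>
void boxed_call(Ret (*fn)(Args...), Stack& stack) {
  assert(stack.size() == sizeof...(Args));
  Ret result = call_unboxed(fn, stack, std::index_sequence_for<Args...>{});
  stack.clear();
  stack.emplace_back(std::in_place_type<Ret>, result);
}
```

For example, calling boxed_call with a kernel int64_t add(int64_t, int64_t) on a stack holding 2 and 40 leaves a single entry, 42, on the stack.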

Tensor Class

The torch::stable::Tensor class offers a user-friendly C++ interface similar to torch::Tensor while maintaining binary compatibility across PyTorch versions.

class Tensor

An ABI stable wrapper around PyTorch tensors.

This class is modeled after TensorBase, as custom op kernels primarily need to interact with tensor metadata (sizes, strides, device, dtype). Other tensor operations (like empty_like) exist as standalone functions outside of this class.

Minimum compatible version: PyTorch 2.9.

Public Functions

inline Tensor()

Constructs a Tensor with an uninitialized AtenTensorHandle.

Creates a new stable::Tensor by allocating an uninitialized tensor handle. The ownership of the handle is managed internally via shared_ptr.

Minimum compatible version: PyTorch 2.9.

inline explicit Tensor(AtenTensorHandle ath)

Constructs a Tensor from an existing AtenTensorHandle.

Steals ownership of the provided AtenTensorHandle.

Minimum compatible version: PyTorch 2.9.

Parameters:

ath – The AtenTensorHandle to wrap. Ownership is transferred to this Tensor.

inline AtenTensorHandle get() const

Returns a borrowed reference to the underlying AtenTensorHandle.

Minimum compatible version: PyTorch 2.9.

Returns:

The underlying AtenTensorHandle.

inline void *data_ptr() const

Returns a pointer to the tensor’s data.

Minimum compatible version: PyTorch 2.9.

Returns:

A void pointer to the tensor’s data storage.

inline void *mutable_data_ptr() const

Returns a mutable pointer to the tensor’s data.

Minimum compatible version: PyTorch 2.10.

Returns:

A mutable void pointer to the tensor’s data storage.

inline const void *const_data_ptr() const

Returns a const pointer to the tensor’s data.

Minimum compatible version: PyTorch 2.10.

Returns:

A const void pointer to the tensor’s data storage.

template<typename T>
T *mutable_data_ptr() const

Returns a typed mutable pointer to the tensor’s data.

Minimum compatible version: PyTorch 2.10.

Template Parameters:

T – The type to cast the data pointer to.

Returns:

A mutable pointer to the tensor’s data cast to type T*.

template<typename T, std::enable_if_t<!std::is_const_v<T>, int> = 0>
const T *const_data_ptr() const

Returns a typed const pointer to the tensor’s data.

Minimum compatible version: PyTorch 2.10.

Template Parameters:

T – The type to cast the data pointer to. Must not be const-qualified.

Returns:

A const pointer to the tensor’s data cast to type const T*.

inline const Tensor &set_requires_grad(bool requires_grad) const

Sets whether this tensor requires gradient computation.

Minimum compatible version: PyTorch 2.10.

Parameters:

requires_grad – If true, gradients will be computed for this tensor during backpropagation.

Returns:

A reference to this Tensor.

inline int64_t dim() const

Returns the number of dimensions of the tensor.

Minimum compatible version: PyTorch 2.9.

Returns:

The number of dimensions (rank) of the tensor.

inline int64_t numel() const

Returns the total number of elements in the tensor.

Minimum compatible version: PyTorch 2.9.

Returns:

The total number of elements across all dimensions.

inline IntHeaderOnlyArrayRef sizes() const

Returns the sizes (shape) of the tensor.

Returns a borrowed reference to the dimension sizes of the tensor.

Minimum compatible version: PyTorch 2.9.

Returns:

An IntHeaderOnlyArrayRef containing the size of each dimension.

inline IntHeaderOnlyArrayRef strides() const

Returns the strides of the tensor.

Returns a borrowed reference to the strides of the tensor.

Minimum compatible version: PyTorch 2.9.

Returns:

An IntHeaderOnlyArrayRef containing the stride of each dimension.

inline bool is_contiguous() const

Checks if the tensor is contiguous in memory.

Minimum compatible version: PyTorch 2.9.

Note

This is a subset of the original TensorBase API. It takes no arguments whereas the original API takes a memory format argument. Here, we assume the default contiguous memory format.

Returns:

true if the tensor is contiguous, false otherwise.
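
The default (row-major) contiguity rule can be checked from sizes and strides alone. The following self-contained sketch illustrates that rule.

- is_row_major_contiguous is a hypothetical helper, not part of the stable API.
- PyTorch's own implementation may handle additional edge cases.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch of the default (row-major) contiguity rule, using metadata only.
// Assumes sizes.size() == strides.size().
bool is_row_major_contiguous(const std::vector<int64_t>& sizes,
                             const std::vector<int64_t>& strides) {
  // Tensors with zero elements are treated as contiguous.
  for (int64_t s : sizes) {
    if (s == 0) return true;
  }
  int64_t expected_stride = 1;
  // Walk from the innermost (last) dimension outward.
  for (std::size_t i = sizes.size(); i-- > 0;) {
    // A size-1 dimension never moves between elements, so its stride is ignored.
    if (sizes[i] != 1 && strides[i] != expected_stride) return false;
    expected_stride *= sizes[i];
  }
  return true;
}
```

For example, a 2x3 tensor with strides {3, 1} is contiguous. Its transposed view (sizes {3, 2}, strides {1, 3}) is not.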

inline int64_t stride(int64_t dim) const

Returns the stride of a specific dimension.

Minimum compatible version: PyTorch 2.9.

Parameters:

dim – The dimension index to query.

Returns:

The stride of the specified dimension.

inline DeviceIndex get_device_index() const

Returns the device index of the tensor.

Minimum compatible version: PyTorch 2.9.

Returns:

The device index as DeviceIndex (int32_t).

inline bool is_cuda() const

Checks if the tensor is on a CUDA device.

Minimum compatible version: PyTorch 2.9.

Returns:

true if the tensor is on a CUDA device, false otherwise.

inline bool is_cpu() const

Checks if the tensor is on the CPU.

Minimum compatible version: PyTorch 2.9.

Returns:

true if the tensor is on the CPU, false otherwise.

inline int64_t size(int64_t dim) const

Returns the size of a specific dimension.

Minimum compatible version: PyTorch 2.9.

Parameters:

dim – The dimension index to query.

Returns:

The size of the specified dimension.

inline bool defined() const

Checks if the tensor is defined (not null).

Minimum compatible version: PyTorch 2.9.

Returns:

true if the tensor is defined, false otherwise.

inline int64_t storage_offset() const

Returns the storage offset of the tensor.

The storage offset is the number of elements from the beginning of the underlying storage to the first element of the tensor.

Minimum compatible version: PyTorch 2.9.

Returns:

The storage offset in number of elements.
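
The following sketch shows how sizes, strides, and the storage offset combine to locate an element in the underlying storage.

- storage_index is a hypothetical helper written for illustration.
- The result is counted in elements. Multiply by element_size() to get a byte offset.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Position (in elements) of the element at `indices` within the underlying
// storage: storage_offset + sum over d of indices[d] * strides[d].
int64_t storage_index(const std::vector<int64_t>& indices,
                      const std::vector<int64_t>& strides,
                      int64_t storage_offset) {
  assert(indices.size() == strides.size());
  int64_t index = storage_offset;
  for (std::size_t d = 0; d < indices.size(); ++d) {
    index += indices[d] * strides[d];
  }
  return index;
}
```

In ATen, data_ptr() already points at the tensor's first element, because the storage offset has been applied. The stable data_ptr() is assumed to behave the same way. When indexing relative to data_ptr(), you would therefore pass a storage_offset of 0 here.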

inline size_t element_size() const

Returns the size in bytes of each element in the tensor.

Minimum compatible version: PyTorch 2.9.

Returns:

The element size in bytes.

ScalarType scalar_type() const

Returns the scalar type (dtype) of the tensor.

Minimum compatible version: PyTorch 2.9.

Returns:

The ScalarType of the tensor.

Device device() const

Returns the device of the tensor.

Minimum compatible version: PyTorch 2.9.

Returns:

The Device on which the tensor resides.

Device Class

The torch::stable::Device class provides a user-friendly C++ interface similar to c10::Device while maintaining binary compatibility across PyTorch versions. It represents a compute device (CPU, CUDA, etc.) with an optional device index.

class Device

A stable version of c10::Device.

Minimum compatible version: PyTorch 2.9.

Public Functions

inline Device(DeviceType type, DeviceIndex index = -1)

Constructs a Device from a DeviceType and optional device index.

Minimum compatible version: PyTorch 2.9.

Parameters:
  • type – The type of device (e.g., DeviceType::CPU, DeviceType::CUDA).

  • index – The device index. Default is -1 (current device).

Device(const std::string &device_string)

Constructs a stable::Device from a string description.

The string must follow the schema: (cpu|cuda|…)[:<device-index>]

Minimum compatible version: PyTorch 2.10.

Parameters:

device_string – A string describing the device (e.g., “cuda:0”, “cpu”).
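
To illustrate the (cpu|cuda|…)[:<device-index>] schema, here is a hypothetical parser sketch.

- It handles only the cpu and cuda types.
- ParsedDevice and parse_device_string are illustrative names, not stable API.
- The real constructor accepts more device types and reports errors in its own way.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>

// Hypothetical result type for this sketch only.
struct ParsedDevice {
  std::string type;
  int index;  // -1 when no index was given
};

// Sketch of parsing "(cpu|cuda)[:<device-index>]"; returns nullopt on bad input.
std::optional<ParsedDevice> parse_device_string(const std::string& s) {
  const std::size_t colon = s.find(':');
  const std::string type = s.substr(0, colon);  // whole string if no colon
  if (type != "cpu" && type != "cuda") return std::nullopt;
  if (colon == std::string::npos) return ParsedDevice{type, -1};
  const std::string digits = s.substr(colon + 1);
  // Require 1 to 3 decimal digits so std::stoi cannot overflow.
  if (digits.empty() || digits.size() > 3) return std::nullopt;
  for (char c : digits) {
    if (c < '0' || c > '9') return std::nullopt;
  }
  return ParsedDevice{type, std::stoi(digits)};
}
```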

inline bool operator==(const Device &other) const noexcept

Checks if two devices are equal.

Minimum compatible version: PyTorch 2.9.

Parameters:

other – The device to compare with.

Returns:

true if both type and index match, false otherwise.

inline bool operator!=(const Device &other) const noexcept

Checks if two devices are not equal.

Minimum compatible version: PyTorch 2.9.

Parameters:

other – The device to compare with.

Returns:

true if type or index differ, false otherwise.

inline void set_index(DeviceIndex index)

Sets the device index.

Minimum compatible version: PyTorch 2.9.

Parameters:

index – The new device index.

inline DeviceType type() const noexcept

Returns the device type.

Minimum compatible version: PyTorch 2.9.

Returns:

The DeviceType of this device.

inline DeviceIndex index() const noexcept

Returns the device index.

Minimum compatible version: PyTorch 2.9.

Returns:

The device index, or -1 if no specific index is set.

inline bool has_index() const noexcept

Checks if this device has a specific index.

Minimum compatible version: PyTorch 2.9.

Returns:

true if index is not -1, false otherwise.

inline bool is_cuda() const noexcept

Checks if this is a CUDA device.

Minimum compatible version: PyTorch 2.9.

Returns:

true if the device type is CUDA, false otherwise.

inline bool is_cpu() const noexcept

Checks if this is a CPU device.

Minimum compatible version: PyTorch 2.9.

Returns:

true if the device type is CPU, false otherwise.

DeviceGuard Class

The torch::stable::accelerator::DeviceGuard provides a user-friendly C++ interface similar to c10::DeviceGuard while maintaining binary compatibility across PyTorch versions.

class DeviceGuard

A stable ABI version of c10::DeviceGuard.

RAII class that sets the current device to the specified device index on construction and restores the previous device on destruction.

Minimum compatible version: PyTorch 2.9.

Public Functions

inline explicit DeviceGuard(DeviceIndex device_index)

Constructs a DeviceGuard that sets the current device.

Minimum compatible version: PyTorch 2.9.

Parameters:

device_index – The device index to set as the current device.

inline void set_index(DeviceIndex device_index)

Changes the current device to the specified device index.

Minimum compatible version: PyTorch 2.9.

Parameters:

device_index – The new device index to set.
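
The RAII behavior described above can be sketched without PyTorch.

- A hypothetical global, g_current_device, stands in for the accelerator's current-device state.
- DeviceGuardSketch is an illustrative class, not the stable DeviceGuard. The real class manages actual accelerator state through the stable ABI.

```cpp
#include <cassert>
#include <cstdint>

using DeviceIndex = int32_t;

// Hypothetical stand-in for the accelerator's "current device" state.
static DeviceIndex g_current_device = 0;

// Saves the current device, switches to `index`, and restores on destruction.
class DeviceGuardSketch {
 public:
  explicit DeviceGuardSketch(DeviceIndex index) : original_(g_current_device) {
    g_current_device = index;
  }
  // Switch again while the guard is alive; the original is still restored later.
  void set_index(DeviceIndex index) { g_current_device = index; }
  ~DeviceGuardSketch() { g_current_device = original_; }

  DeviceGuardSketch(const DeviceGuardSketch&) = delete;
  DeviceGuardSketch& operator=(const DeviceGuardSketch&) = delete;

 private:
  DeviceIndex original_;
};
```

In a real kernel, the documented equivalent would be a scoped guard such as torch::stable::accelerator::DeviceGuard guard(tensor.get_device_index()); placed before any device work.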

inline DeviceIndex torch::stable::accelerator::getCurrentDeviceIndex()

Gets the current device index.

Returns the index of the currently active device for the accelerator.

Minimum compatible version: PyTorch 2.9.

Returns:

The current device index.

Stream Utilities

For CUDA stream access, we currently recommend the ABI-stable C shim API. A more ergonomic wrapper will be provided in a future release.

Getting the Current CUDA Stream

To obtain the current cudaStream_t for use in CUDA kernels:

#include <cuda_runtime.h>  // for cudaStream_t
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/headeronly/util/shim_utils.h>

// For now, we rely on the ABI stable C shim API to get the current CUDA stream.
// This will be improved in a future release.
// When using a C shim API, we need to use TORCH_ERROR_CODE_CHECK to
// check the error code and throw an appropriate runtime_error otherwise.
void* stream_ptr = nullptr;
TORCH_ERROR_CODE_CHECK(
    aoti_torch_get_current_cuda_stream(tensor.get_device_index(), &stream_ptr));
cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);

// Now you can use 'stream' in your CUDA kernel launches
my_kernel<<<blocks, threads, 0, stream>>>(args...);

Note

The TORCH_ERROR_CODE_CHECK macro is required when using C shim APIs to properly check error codes and throw appropriate exceptions.

CUDA Error Checking Macros

These macros provide stable ABI equivalents for CUDA error checking. They wrap CUDA API calls and kernel launches, providing detailed error messages using PyTorch’s error formatting.

STD_CUDA_CHECK(EXPR)

Checks the result of a CUDA API call and throws an exception on error. Users of this macro are expected to include cuda_runtime.h.

Example:

STD_CUDA_CHECK(cudaMalloc(&ptr, size));
STD_CUDA_CHECK(cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost));

Minimum compatible version: PyTorch 2.10.

STD_CUDA_KERNEL_LAUNCH_CHECK()

Checks for errors from the most recent CUDA kernel launch. Equivalent to STD_CUDA_CHECK(cudaGetLastError()).

Example:

my_kernel<<<blocks, threads, 0, stream>>>(args...);
STD_CUDA_KERNEL_LAUNCH_CHECK();

Minimum compatible version: PyTorch 2.10.

Header-Only Utilities

The torch::headeronly namespace provides header-only versions of common PyTorch types and utilities. These can be used without linking against libtorch, making them ideal for maintaining binary compatibility across PyTorch versions.

Error Checking

STD_TORCH_CHECK is a header-only macro for runtime assertions:

#include <torch/headeronly/util/Exception.h>

STD_TORCH_CHECK(condition, "Error message with ", variable, " interpolation");

Core Types

The following c10:: types are available as header-only versions under torch::headeronly:

  • torch::headeronly::ScalarType - Tensor data types (Float, Double, Int, etc.)

  • torch::headeronly::DeviceType - Device types (CPU, CUDA, etc.)

  • torch::headeronly::MemoryFormat - Memory layout formats (Contiguous, ChannelsLast, etc.)

  • torch::headeronly::Layout - Tensor layouts (Strided, Sparse, etc.)

#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/core/DeviceType.h>
#include <torch/headeronly/core/MemoryFormat.h>
#include <torch/headeronly/core/Layout.h>

auto dtype = torch::headeronly::ScalarType::Float;
auto device_type = torch::headeronly::DeviceType::CUDA;
auto memory_format = torch::headeronly::MemoryFormat::Contiguous;
auto layout = torch::headeronly::Layout::Strided;

TensorAccessor

TensorAccessor provides efficient, typed, multi-dimensional indexing into tensor data. You can construct one from a stable tensor’s data pointer, sizes, and strides:

#include <torch/headeronly/core/TensorAccessor.h>

// Create a TensorAccessor for a 2D float tensor
auto sizes = tensor.sizes();
auto strides = tensor.strides();
torch::headeronly::TensorAccessor<float, 2> accessor(
    static_cast<float*>(tensor.mutable_data_ptr()),
    sizes.data(),
    strides.data());

// Access elements
float value = accessor[i][j];

Dispatch Macros

Header-only dispatch macros (THO = Torch Header Only) are available for dtype and device dispatching:

#include <torch/headeronly/core/Dispatch.h>

THO_DISPATCH_FLOATING_TYPES(tensor.scalar_type(), "my_kernel", [&] {
    // scalar_t is the resolved type
    auto* data = tensor.mutable_data_ptr<scalar_t>();
});
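
Conceptually, a dispatch macro maps a runtime dtype to a compile-time C++ type. It then runs the same lambda body with scalar_t bound to that type.

The sketch below illustrates the idea using a hypothetical DType enum and a dispatch_floating_types function. The real THO_DISPATCH_* macros are implemented differently and cover many more types.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>

// Hypothetical dtype enum for this sketch (not torch::headeronly::ScalarType).
enum class DType { Float, Double, Int };

// Calls `f` with a value of the C++ type matching `dtype`; inside `f`,
// `using scalar_t = decltype(tag);` recovers that type at compile time.
template <typename F>
void dispatch_floating_types(DType dtype, F&& f) {
  switch (dtype) {
    case DType::Float:
      f(float{});
      break;
    case DType::Double:
      f(double{});
      break;
    default:
      throw std::invalid_argument("dispatch_floating_types: unsupported dtype");
  }
}
```

Each case instantiates the lambda separately, so code inside it that depends on scalar_t is compiled once per supported type.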

Full API List

For the complete list of header-only APIs, see torch/header_only_apis.txt in the PyTorch source tree.

Stable Operators

Tensor Creation

inline torch::stable::Tensor torch::stable::empty(torch::headeronly::IntHeaderOnlyArrayRef size, std::optional<torch::headeronly::ScalarType> dtype = std::nullopt, std::optional<torch::headeronly::Layout> layout = std::nullopt, std::optional<torch::stable::Device> device = std::nullopt, std::optional<bool> pin_memory = std::nullopt, std::optional<torch::headeronly::MemoryFormat> memory_format = std::nullopt)

Stable version of the empty.memory_format op.

Creates a new uninitialized tensor with the specified size and options. This function supports full tensor creation options including device, dtype, layout, and memory format.

Minimum compatible version: PyTorch 2.10.

Parameters:
  • size – The desired size of the output tensor.

  • dtype – Optional scalar type for the tensor elements.

  • layout – Optional memory layout (e.g., strided, sparse).

  • device – Optional device to place the tensor on.

  • pin_memory – Optional flag to allocate in pinned (page-locked) host memory, which speeds up transfers to CUDA devices. Applies to CPU tensors.

  • memory_format – Optional memory format for the tensor.

Returns:

A new uninitialized tensor with the specified properties.

inline torch::stable::Tensor torch::stable::empty_like(const torch::stable::Tensor &self)

Stable version of the empty_like op.

Creates a new uninitialized tensor with the same size, dtype, layout, and device as the input tensor. This version does not support kwargs (device, dtype, layout, memory_format); kwargs support may be added in the future.

Minimum compatible version: PyTorch 2.9.

Parameters:

self – The input tensor whose properties will be used for the new tensor.

Returns:

A new uninitialized tensor with the same properties as self.

inline torch::stable::Tensor torch::stable::new_empty(const torch::stable::Tensor &self, torch::headeronly::IntHeaderOnlyArrayRef size, std::optional<torch::headeronly::ScalarType> dtype = std::nullopt, std::optional<torch::headeronly::Layout> layout = std::nullopt, std::optional<torch::stable::Device> device = std::nullopt, std::optional<bool> pin_memory = std::nullopt)

Stable version of the new_empty op (2.10 version with full kwargs).

Creates a new uninitialized tensor with the specified size and options. This version supports all tensor creation kwargs. For versions < 2.10, a simpler overload that only takes dtype is available.

Minimum compatible version: PyTorch 2.10.

Parameters:
  • self – The input tensor whose properties may be inherited if kwargs are not provided.

  • size – The desired size of the output tensor.

  • dtype – Optional scalar type for the tensor elements.

  • layout – Optional memory layout (e.g., strided, sparse).

  • device – Optional device to place the tensor on.

  • pin_memory – Optional flag to allocate in pinned (page-locked) host memory, which speeds up transfers to CUDA devices. Applies to CPU tensors.

Returns:

A new uninitialized tensor with the specified properties.

inline torch::stable::Tensor torch::stable::new_zeros(const torch::stable::Tensor &self, torch::headeronly::IntHeaderOnlyArrayRef size, std::optional<torch::headeronly::ScalarType> dtype = std::nullopt, std::optional<torch::headeronly::Layout> layout = std::nullopt, std::optional<torch::stable::Device> device = std::nullopt, std::optional<bool> pin_memory = std::nullopt)

Stable version of the new_zeros op (2.10 version with full kwargs).

Creates a new zero-filled tensor with the specified size and options. This version supports all tensor creation kwargs. For versions < 2.10, a simpler overload that only takes dtype is available.

Minimum compatible version: PyTorch 2.10.

Parameters:
  • self – The input tensor whose properties may be inherited if kwargs are not provided.

  • size – The desired size of the output tensor.

  • dtype – Optional scalar type for the tensor elements.

  • layout – Optional memory layout (e.g., strided, sparse).

  • device – Optional device to place the tensor on.

  • pin_memory – Optional flag to allocate in pinned (page-locked) host memory, which speeds up transfers to CUDA devices. Applies to CPU tensors.

Returns:

A new zero-filled tensor with the specified properties.

inline torch::stable::Tensor torch::stable::full(torch::headeronly::IntHeaderOnlyArrayRef size, double fill_value, std::optional<torch::headeronly::ScalarType> dtype = std::nullopt, std::optional<torch::headeronly::Layout> layout = std::nullopt, std::optional<torch::stable::Device> device = std::nullopt, std::optional<bool> pin_memory = std::nullopt)

Stable version of the full.default op.

Creates a tensor of the specified size filled with the given value.

Minimum compatible version: PyTorch 2.10.

Note

The fill_value parameter is typed as double because the C shim API uses double for the Scalar parameter.

Parameters:
  • size – The desired size of the output tensor.

  • fill_value – The value to fill the tensor with.

  • dtype – Optional scalar type for the tensor elements.

  • layout – Optional memory layout.

  • device – Optional device to place the tensor on.

  • pin_memory – Optional flag to use pinned memory.

Returns:

A new tensor filled with the specified value.

inline torch::stable::Tensor torch::stable::from_blob(void *data, torch::headeronly::IntHeaderOnlyArrayRef sizes, torch::headeronly::IntHeaderOnlyArrayRef strides, torch::stable::Device device, torch::headeronly::ScalarType dtype, int64_t storage_offset = 0, torch::headeronly::Layout layout = torch::headeronly::Layout::Strided)

Creates a tensor from an existing data blob.

Creates a tensor that uses the provided data pointer as its storage. The tensor does not own the data, so the caller must ensure the data remains valid for the lifetime of the tensor.

Minimum compatible version: PyTorch 2.10.

Parameters:
  • data – Pointer to the data buffer.

  • sizes – The size of each dimension of the tensor.

  • strides – The stride for each dimension.

  • device – The device where the data resides.

  • dtype – The scalar type of the data.

  • storage_offset – The offset into the data buffer. Defaults to 0.

  • layout – The memory layout. Defaults to Strided.

Returns:

A tensor backed by the provided data.

Tensor Manipulation

inline torch::stable::Tensor torch::stable::clone(const torch::stable::Tensor &self)

Stable version of the clone op.

Returns a copy of the input tensor. The returned tensor has the same data and type as the input, but is stored in a new memory location.

Minimum compatible version: PyTorch 2.9.

Note

Optional memory_format kwarg support is not yet implemented.

Parameters:

self – The input tensor to clone.

Returns:

A new tensor with copied data.

inline torch::stable::Tensor torch::stable::contiguous(const torch::stable::Tensor &self, torch::headeronly::MemoryFormat memory_format = torch::headeronly::MemoryFormat::Contiguous)

Stable version of the contiguous op.

Returns a tensor that is contiguous in memory and contains the same data as the input tensor. If the input tensor is already contiguous in the specified memory format, the input tensor itself is returned.

Minimum compatible version: PyTorch 2.10.

Parameters:
  • self – The input tensor.

  • memory_format – The desired memory format.

Returns:

A contiguous tensor.

inline torch::stable::Tensor torch::stable::reshape(const torch::stable::Tensor &self, torch::headeronly::IntHeaderOnlyArrayRef shape)

Stable version of the reshape op.

Returns a tensor with the same data and number of elements as the input, but with the specified shape. When possible, the returned tensor will be a view of the input.

Minimum compatible version: PyTorch 2.10.

Parameters:
  • self – The input tensor.

  • shape – The desired output shape.

Returns:

A tensor with the specified shape.

inline torch::stable::Tensor torch::stable::view(const torch::stable::Tensor &self, torch::headeronly::IntHeaderOnlyArrayRef size)

Stable version of the view op.

Returns a new tensor with the same data as the input tensor but with a different shape. The returned tensor shares the same data and must have the same number of elements.

Minimum compatible version: PyTorch 2.10.

Parameters:
  • self – The input tensor.

  • size – The desired output shape.

Returns:

A view tensor with the specified shape.

inline torch::stable::Tensor torch::stable::flatten(const torch::stable::Tensor &self, int64_t start_dim = 0, int64_t end_dim = -1)

Stable version of the flatten.using_ints op.

Flattens the input tensor by reshaping it into a one-dimensional tensor. If start_dim or end_dim are specified, only dimensions starting from start_dim to end_dim are flattened.

Minimum compatible version: PyTorch 2.9.

Parameters:
  • self – The input tensor to flatten.

  • start_dim – The first dimension to flatten. Defaults to 0.

  • end_dim – The last dimension to flatten. Defaults to -1 (last dim).

Returns:

A flattened tensor.
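
The output shape of flatten follows directly from start_dim and end_dim. This self-contained sketch computes it. flatten_shape is a hypothetical helper, and negative dims count from the end as they do for the op itself.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Output shape of flatten(start_dim, end_dim): dims start_dim..end_dim
// (inclusive) are merged into one dimension whose size is their product.
std::vector<int64_t> flatten_shape(const std::vector<int64_t>& sizes,
                                   int64_t start_dim = 0, int64_t end_dim = -1) {
  const int64_t ndim = static_cast<int64_t>(sizes.size());
  if (ndim == 0) return {1};  // a 0-d tensor flattens to a single element
  if (start_dim < 0) start_dim += ndim;
  if (end_dim < 0) end_dim += ndim;
  assert(0 <= start_dim && start_dim <= end_dim && end_dim < ndim);
  std::vector<int64_t> out(sizes.begin(), sizes.begin() + start_dim);
  int64_t merged = 1;
  for (int64_t d = start_dim; d <= end_dim; ++d) merged *= sizes[d];
  out.push_back(merged);
  out.insert(out.end(), sizes.begin() + end_dim + 1, sizes.end());
  return out;
}
```

For example, flattening a {2, 3, 4} tensor from dim 1 gives {2, 12}.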

inline torch::stable::Tensor torch::stable::squeeze(const torch::stable::Tensor &self, int64_t dim)

Stable version of the squeeze.dim op.

Returns a tensor with the dimension of size one at the specified position removed. The returned tensor shares the same underlying data with the input tensor.

Minimum compatible version: PyTorch 2.9.

Parameters:
  • self – The input tensor.

  • dim – The dimension to squeeze. If this dimension does not have size 1, the tensor is returned unchanged.

Returns:

A tensor with the specified dimension removed (if size was 1).

inline torch::stable::Tensor torch::stable::unsqueeze(const torch::stable::Tensor &self, int64_t dim)

Stable version of the unsqueeze op.

Returns a new tensor with a dimension of size one inserted at the specified position. The returned tensor shares the same underlying data with the input tensor.

Minimum compatible version: PyTorch 2.9.

Parameters:
  • self – The input tensor.

  • dim – The index at which to insert the new dimension. Negative values are supported.

Returns:

A tensor with an additional dimension.
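
The shape rules for squeeze and unsqueeze can be summarized in a short sketch. squeeze_shape and unsqueeze_shape are hypothetical helpers. The sketch assumes at least a 1-D input for squeeze and ignores strides; the real ops also produce matching strides for the view.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Shape after squeeze(dim): `dim` is removed only if its size is 1.
// Valid dims are [-ndim, ndim - 1]; assumes ndim >= 1.
std::vector<int64_t> squeeze_shape(std::vector<int64_t> sizes, int64_t dim) {
  const int64_t ndim = static_cast<int64_t>(sizes.size());
  if (dim < 0) dim += ndim;
  assert(0 <= dim && dim < ndim);
  if (sizes[dim] == 1) sizes.erase(sizes.begin() + dim);
  return sizes;
}

// Shape after unsqueeze(dim): a size-1 dim is inserted at position `dim`.
// Valid dims are [-ndim - 1, ndim], since the result has ndim + 1 dims.
std::vector<int64_t> unsqueeze_shape(std::vector<int64_t> sizes, int64_t dim) {
  const int64_t ndim = static_cast<int64_t>(sizes.size());
  if (dim < 0) dim += ndim + 1;
  assert(0 <= dim && dim <= ndim);
  sizes.insert(sizes.begin() + dim, int64_t{1});
  return sizes;
}
```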

inline torch::stable::Tensor torch::stable::transpose(const torch::stable::Tensor &self, int64_t dim0, int64_t dim1)

Stable version of the transpose.int op.

Returns a tensor that is a transposed version of the input, with dimensions dim0 and dim1 swapped. The returned tensor shares storage with the input.

Minimum compatible version: PyTorch 2.9.

Parameters:
  • self – The input tensor.

  • dim0 – The first dimension to transpose.

  • dim1 – The second dimension to transpose.

Returns:

A transposed view of the input tensor.
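
Because the result shares storage with the input, transpose only needs to rewrite metadata. This sketch shows that idea using a hypothetical ViewMeta struct (sizes, strides, storage offset). No data is copied.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical view metadata for this sketch.
struct ViewMeta {
  std::vector<int64_t> sizes;
  std::vector<int64_t> strides;
  int64_t storage_offset = 0;
};

// transpose(dim0, dim1) swaps the two sizes and the two strides;
// the storage offset and the underlying data are unchanged.
ViewMeta transpose_meta(ViewMeta v, int64_t dim0, int64_t dim1) {
  const int64_t ndim = static_cast<int64_t>(v.sizes.size());
  if (dim0 < 0) dim0 += ndim;
  if (dim1 < 0) dim1 += ndim;
  assert(0 <= dim0 && dim0 < ndim && 0 <= dim1 && dim1 < ndim);
  std::swap(v.sizes[dim0], v.sizes[dim1]);
  std::swap(v.strides[dim0], v.strides[dim1]);
  return v;
}
```

This is also why a transposed tensor is usually not contiguous. A 2x3 row-major tensor becomes a 3x2 view with strides {1, 3}.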

inline torch::stable::Tensor torch::stable::select(const torch::stable::Tensor &self, int64_t dim, int64_t index)

Stable version of the select.int op.

Slices the input tensor along the specified dimension at the given index. This function returns a view of the original tensor with the given dimension removed.

Minimum compatible version: PyTorch 2.9.

Note

The index parameter is typed as int64_t because SymInt is not yet header-only.

Parameters:
  • self – The input tensor.

  • dim – The dimension to slice.

  • index – The index to select along the dimension.

Returns:

A tensor with one fewer dimension.

inline torch::stable::Tensor torch::stable::narrow(torch::stable::Tensor &self, int64_t dim, int64_t start, int64_t length)

Stable version of the narrow.default op.

Returns a new tensor that is a narrowed version of the input tensor. The dimension dim is narrowed from start to start + length.

Minimum compatible version: PyTorch 2.9.

Note

The start and length parameters are typed as int64_t because SymInt is not yet header-only.

Parameters:
  • self – The input tensor to narrow.

  • dim – The dimension along which to narrow.

  • start – The starting index for the narrowed dimension.

  • length – The length of the narrowed dimension.

Returns:

A new tensor that is a narrowed view of the input.
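
Both select and narrow return views, so they can be described purely as metadata changes. This sketch uses a hypothetical ViewMeta struct.

- select moves the storage offset to the chosen index and drops the dimension.
- narrow moves the offset to start and shrinks the dimension to length.
- Negative dims and negative select/narrow positions are wrapped, as the ops allow.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical view metadata for this sketch.
struct ViewMeta {
  std::vector<int64_t> sizes;
  std::vector<int64_t> strides;
  int64_t storage_offset = 0;
};

// select(dim, index): jump to `index` along `dim`, then remove `dim`.
ViewMeta select_meta(ViewMeta v, int64_t dim, int64_t index) {
  const int64_t ndim = static_cast<int64_t>(v.sizes.size());
  if (dim < 0) dim += ndim;
  assert(0 <= dim && dim < ndim);
  if (index < 0) index += v.sizes[dim];
  assert(0 <= index && index < v.sizes[dim]);
  v.storage_offset += index * v.strides[dim];
  v.sizes.erase(v.sizes.begin() + dim);
  v.strides.erase(v.strides.begin() + dim);
  return v;
}

// narrow(dim, start, length): keep elements [start, start + length) along `dim`.
ViewMeta narrow_meta(ViewMeta v, int64_t dim, int64_t start, int64_t length) {
  const int64_t ndim = static_cast<int64_t>(v.sizes.size());
  if (dim < 0) dim += ndim;
  assert(0 <= dim && dim < ndim);
  if (start < 0) start += v.sizes[dim];
  assert(0 <= start && 0 <= length && start + length <= v.sizes[dim]);
  v.storage_offset += start * v.strides[dim];
  v.sizes[dim] = length;
  return v;
}
```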

inline torch::stable::Tensor torch::stable::pad(const torch::stable::Tensor &self, torch::headeronly::IntHeaderOnlyArrayRef pad, const std::string &mode = "constant", double value = 0.0)

Stable version of the pad.default op.

Pads the input tensor according to the specified padding sizes. Padding sizes are given as (before, after) pairs, one pair per padded dimension, starting from the last dimension and moving toward the first. The two sides of a pair may differ.

Minimum compatible version: PyTorch 2.9.

Note

The pad parameter is typed as IntHeaderOnlyArrayRef because SymInt is not yet header-only.

Parameters:
  • self – The input tensor to pad.

  • pad – The padding sizes for each dimension (in pairs, starting from the last dimension).

  • mode – The padding mode: “constant”, “reflect”, “replicate”, or “circular”. Defaults to “constant”.

  • value – The fill value for constant padding. Defaults to 0.0.

Returns:

A new padded tensor.
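
The pair ordering of the pad argument is easy to get backwards. This sketch computes the padded output shape. padded_shape is a hypothetical helper, and it only covers the shape arithmetic, not the fill modes.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Output shape of pad(): `pad` holds (before, after) pairs, last dimension first.
std::vector<int64_t> padded_shape(std::vector<int64_t> sizes,
                                  const std::vector<int64_t>& pad) {
  assert(pad.size() % 2 == 0 && pad.size() / 2 <= sizes.size());
  for (std::size_t pair = 0; pair < pad.size() / 2; ++pair) {
    // Pair 0 applies to the last dimension, pair 1 to the one before it, etc.
    const std::size_t dim = sizes.size() - 1 - pair;
    sizes[dim] += pad[2 * pair] + pad[2 * pair + 1];
  }
  return sizes;
}
```

For example, padding a {2, 3} tensor with {1, 1, 2, 0} adds 1 + 1 to the last dimension and 2 + 0 to the first, giving {4, 5}.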

Device and Type Conversion

inline torch::stable::Tensor torch::stable::to(const torch::stable::Tensor &self, std::optional<torch::headeronly::ScalarType> dtype = std::nullopt, std::optional<torch::headeronly::Layout> layout = std::nullopt, std::optional<torch::stable::Device> device = std::nullopt, std::optional<bool> pin_memory = std::nullopt, bool non_blocking = false, bool copy = false, std::optional<torch::headeronly::MemoryFormat> memory_format = std::nullopt)

Stable version of the to.dtype_layout op.

Converts a tensor to the specified dtype, layout, device, and/or memory format. Returns a new tensor with the specified properties.

Minimum compatible version: PyTorch 2.10.

Parameters:
  • self – The input tensor.

  • dtype – Optional target scalar type.

  • layout – Optional target memory layout.

  • device – Optional target device.

  • pin_memory – Optional flag to use pinned memory.

  • non_blocking – If true, the operation may be asynchronous. Defaults to false.

  • copy – If true, always create a copy. Defaults to false.

  • memory_format – Optional target memory format.

Returns:

A tensor with the specified properties.

inline torch::stable::Tensor torch::stable::to(const torch::stable::Tensor &self, torch::stable::Device device, bool non_blocking = false, bool copy = false)

Convenience overload for moving a tensor to a device.

Moves the tensor to the specified device. This is a convenience wrapper around the full to() function.

Minimum compatible version: PyTorch 2.10.

Parameters:
  • self – The input tensor.

  • device – The target device.

  • non_blocking – If true, the operation may be asynchronous. Defaults to false.

  • copy – If true, always create a copy. Defaults to false.

Returns:

A tensor on the specified device.

inline torch::stable::Tensor torch::stable::fill_(const torch::stable::Tensor &self, double value)

Stable version of the fill_.Scalar op.

Fills the input tensor with the specified scalar value in-place and returns it. This has identical semantics to the existing fill_.Scalar op.

Minimum compatible version: PyTorch 2.9.

Note

The value parameter is typed as double because Scalar.h is currently not header-only.

Parameters:
  • self – The tensor to fill.

  • value – The scalar value to fill the tensor with.

Returns:

The input tensor, now filled with the specified value.

inline torch::stable::Tensor torch::stable::zero_(torch::stable::Tensor &self)

Stable version of the zero_ op.

Fills the input tensor with zeros in-place and returns it. Unlike the tensor method version (t.zero_()), this is called as a function: zero_(t).

Minimum compatible version: PyTorch 2.9.

Parameters:

self – The tensor to fill with zeros.

Returns:

The input tensor, now filled with zeros.

inline torch::stable::Tensor torch::stable::copy_(torch::stable::Tensor &self, const torch::stable::Tensor &src, std::optional<bool> non_blocking = std::nullopt)#

Stable version of the copy_ op.

Copies the elements of src into self in-place and returns self. src must be broadcastable to the shape of self; its dtype and device may differ from those of self.

Minimum compatible version: PyTorch 2.9.

Parameters:
  • self – The destination tensor (modified in-place).

  • src – The source tensor to copy from.

  • non_blocking – If true, the copy may occur asynchronously with respect to the host. Defaults to std::nullopt, which is treated as false.

Returns:

The destination tensor with copied values.
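The broadcasting requirement on src can be sketched in plain C++ as a shape check. This is an illustrative model, not the stable API; `broadcastable_to` is a hypothetical helper. It aligns src with the trailing dimensions of self and, like torch's expand, rejects a src that has more dimensions than the destination.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model: can a tensor of shape `src` be broadcast to shape `dst`?
// src is aligned with the trailing dimensions of dst; each src dimension must
// equal the matching dst dimension or be 1. Missing leading src dimensions
// are treated as 1.
bool broadcastable_to(const std::vector<int64_t>& src,
                      const std::vector<int64_t>& dst) {
    if (src.size() > dst.size()) {
        return false;  // cannot broadcast to fewer dimensions
    }
    const std::size_t offset = dst.size() - src.size();
    for (std::size_t i = 0; i < src.size(); ++i) {
        const int64_t s = src[i];
        const int64_t d = dst[offset + i];
        if (s != d && s != 1) {
            return false;
        }
    }
    return true;
}
```

Under this model, a src of shape {3} or {2, 1} can be copied into a self of shape {2, 3}, while a src of shape {2} cannot, because its trailing dimension 2 does not match 3.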

inline torch::stable::Tensor torch::stable::matmul(const torch::stable::Tensor &self, const torch::stable::Tensor &other)#

Stable version of the matmul op.

Performs matrix multiplication between two tensors. The behavior depends on the dimensionality of the tensors (see PyTorch documentation for details on broadcasting rules for matmul).

Minimum compatible version: PyTorch 2.9.

Parameters:
  • self – The first input tensor.

  • other – The second input tensor.

Returns:

The result of matrix multiplication.
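Because the result shape of matmul depends on the input ranks, a plain C++ model of the documented torch.matmul shape rules may help. This sketch only computes shapes; `Shape`, `broadcast_shapes`, and `matmul_shape` are hypothetical names, not part of the stable API. A 1-D first argument is treated as a row vector and a 1-D second argument as a column vector, and those temporary dimensions are dropped from the result; leading batch dimensions are broadcast.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>;

// Broadcast two batch shapes from the right; std::nullopt if incompatible.
std::optional<Shape> broadcast_shapes(const Shape& a, const Shape& b) {
    const std::size_t n = std::max(a.size(), b.size());
    Shape out(n, 1);
    for (std::size_t i = 0; i < n; ++i) {
        // Index from the right; missing dimensions count as 1.
        const int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
        const int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
        if (da != db && da != 1 && db != 1) return std::nullopt;
        out[n - 1 - i] = (da == 1) ? db : da;
    }
    return out;
}

// Hypothetical model of the output shape of matmul(self, other).
std::optional<Shape> matmul_shape(Shape a, Shape b) {
    if (a.empty() || b.empty()) return std::nullopt;  // both must be at least 1-D
    const bool a_vec = a.size() == 1;
    const bool b_vec = b.size() == 1;
    if (a_vec) a.insert(a.begin(), 1);  // vector -> 1 x n matrix
    if (b_vec) b.push_back(1);          // vector -> n x 1 matrix
    if (a.back() != b[b.size() - 2]) return std::nullopt;  // inner dims must match
    auto batch = broadcast_shapes(Shape(a.begin(), a.end() - 2),
                                  Shape(b.begin(), b.end() - 2));
    if (!batch) return std::nullopt;
    Shape out = *batch;
    if (!a_vec) out.push_back(a[a.size() - 2]);  // rows, unless self was a vector
    if (!b_vec) out.push_back(b.back());         // columns, unless other was a vector
    return out;
}
```

For instance, in this model two 1-D inputs give a 0-dimensional result (a dot product), and shapes {10, 1, 2, 3} and {7, 3, 4} give {10, 7, 2, 4} after broadcasting the batch dimensions.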

inline torch::stable::Tensor torch::stable::amax(const torch::stable::Tensor &self, int64_t dim, bool keepdim = false)#

Stable version of the amax.default op (single dimension).

Computes the maximum value along the specified dimension. If keepdim is true, the output tensor has the same number of dimensions as the input, with the reduced dimension having size 1. Otherwise, the reduced dimension is removed.

Minimum compatible version: PyTorch 2.9.

Parameters:
  • self – The input tensor.

  • dim – The dimension along which to compute the maximum.

  • keepdim – Whether to retain the reduced dimension with size 1. Defaults to false.

Returns:

A tensor containing the maximum values along the specified dimension.
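The effect of dim and keepdim can be seen in a small plain C++ model restricted to a row-major 2-D matrix. `amax_2d` is a hypothetical helper, not the stable API; it assumes rows and cols are at least 1, accepts dim values 0, 1, -1, or -2, and does not model amax's NaN propagation.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical model of amax(self, dim, keepdim) for a rows x cols matrix
// stored row-major in `data`. Returns the maxima and writes the output shape.
std::vector<double> amax_2d(const std::vector<double>& data, int64_t rows,
                            int64_t cols, int64_t dim, bool keepdim,
                            std::vector<int64_t>& out_shape) {
    if (dim < 0) dim += 2;  // negative dims count from the end
    std::vector<double> result;
    if (dim == 0) {
        // Reduce over rows: one maximum per column.
        for (int64_t c = 0; c < cols; ++c) {
            double m = data[c];
            for (int64_t r = 1; r < rows; ++r) m = std::max(m, data[r * cols + c]);
            result.push_back(m);
        }
        out_shape = keepdim ? std::vector<int64_t>{1, cols} : std::vector<int64_t>{cols};
    } else {
        // Reduce over columns: one maximum per row.
        for (int64_t r = 0; r < rows; ++r) {
            auto first = data.begin() + r * cols;
            result.push_back(*std::max_element(first, first + cols));
        }
        out_shape = keepdim ? std::vector<int64_t>{rows, 1} : std::vector<int64_t>{rows};
    }
    return result;
}
```

For the 2 x 3 matrix {1, 5, 3, 4, 2, 6}, reducing dim 0 gives {4, 5, 6} with shape {3}, while reducing dim -1 with keepdim gives {5, 6} with shape {2, 1}.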

inline torch::stable::Tensor torch::stable::amax(const torch::stable::Tensor &self, torch::headeronly::IntHeaderOnlyArrayRef dims, bool keepdim = false)#

Stable version of the amax.default op (multiple dimensions).

Computes the maximum value reducing over all the specified dimensions. If keepdim is true, the output tensor has the same number of dimensions as the input, with the reduced dimensions having size 1. Otherwise, the reduced dimensions are removed.

Minimum compatible version: PyTorch 2.9.

Note

The dims parameter is typed as torch::headeronly::IntHeaderOnlyArrayRef rather than IntArrayRef, because IntArrayRef is not yet header-only.

Parameters:
  • self – The input tensor.

  • dims – The dimensions along which to compute the maximum.

  • keepdim – Whether to retain the reduced dimensions. Defaults to false.

Returns:

A tensor containing the maximum values.
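For several reduced dimensions, the output shape rule can be modelled on its own in plain C++. `reduced_shape` is a hypothetical helper, not the stable API. It assumes dims is non-empty, wraps negative dims, and rejects out-of-range or repeated dims.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical model of the output shape of amax(self, dims, keepdim).
std::vector<int64_t> reduced_shape(const std::vector<int64_t>& shape,
                                   const std::vector<int64_t>& dims,
                                   bool keepdim) {
    const int64_t ndim = static_cast<int64_t>(shape.size());
    std::vector<bool> reduce(shape.size(), false);
    for (int64_t d : dims) {
        const int64_t w = d < 0 ? d + ndim : d;  // wrap negative dims
        if (w < 0 || w >= ndim) throw std::out_of_range("dim out of range");
        if (reduce[w]) throw std::invalid_argument("dim appears more than once");
        reduce[w] = true;
    }
    std::vector<int64_t> out;
    for (int64_t i = 0; i < ndim; ++i) {
        if (!reduce[i]) {
            out.push_back(shape[i]);  // untouched dimension
        } else if (keepdim) {
            out.push_back(1);  // reduced dimension kept with size 1
        }
    }
    return out;
}
```

In this model, reducing dims {0, -1} of a {2, 3, 4} tensor gives {3}, or {1, 3, 1} with keepdim set.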

inline torch::stable::Tensor torch::stable::sum(const torch::stable::Tensor &self, std::optional<torch::headeronly::IntHeaderOnlyArrayRef> dim = std::nullopt, bool keepdim = false, std::optional<torch::headeronly::ScalarType> dtype = std::nullopt)#

Stable version of the sum.dim_IntList op.

Computes the sum of the input tensor along the specified dimensions. If dim is not provided, sums over all dimensions.

Minimum compatible version: PyTorch 2.10.

Parameters:
  • self – The input tensor.

  • dim – Optional dimensions to reduce. If not provided, reduces all dimensions.

  • keepdim – Whether to retain the reduced dimensions. Defaults to false.

  • dtype – Optional output dtype. If not provided, the output dtype follows the input dtype, except that integral and boolean inputs are promoted to int64.

Returns:

A tensor containing the sum.
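The choice of output dtype for sum can be summarized in a few lines of plain C++. This is a model, not the stable API: `DType` is a hypothetical, reduced stand-in for torch::headeronly::ScalarType, and `sum_result_dtype` is a hypothetical helper. It encodes the PyTorch rule that an explicit dtype wins, integral and boolean inputs accumulate into int64, and floating-point inputs keep their dtype.

```cpp
#include <cassert>
#include <optional>

// Hypothetical, reduced stand-in for torch::headeronly::ScalarType.
enum class DType { Bool, Int8, Int32, Int64, Float16, Float32, Float64 };

// Hypothetical model of how sum picks its output dtype.
DType sum_result_dtype(DType input, std::optional<DType> dtype = std::nullopt) {
    if (dtype) {
        return *dtype;  // an explicitly requested dtype always wins
    }
    switch (input) {
        case DType::Bool:
        case DType::Int8:
        case DType::Int32:
        case DType::Int64:
            return DType::Int64;  // integral and boolean inputs promote to int64
        default:
            return input;  // floating-point inputs keep their dtype
    }
}
```

So, in this model, summing an int32 tensor yields int64 unless a dtype such as float64 is passed explicitly.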

inline torch::stable::Tensor &torch::stable::sum_out(torch::stable::Tensor &out, const torch::stable::Tensor &self, std::optional<torch::headeronly::IntHeaderOnlyArrayRef> dim = std::nullopt, bool keepdim = false, std::optional<torch::headeronly::ScalarType> dtype = std::nullopt)#

Stable version of the sum.IntList_out op.

Computes the sum of the input tensor along the specified dimensions, storing the result in the provided output tensor. Following the ATen C++ convention for out variants, the out parameter comes first (in the operator schema it is a trailing keyword-only argument).

Minimum compatible version: PyTorch 2.10.

Parameters:
  • out – The output tensor (modified in-place).

  • self – The input tensor.

  • dim – Optional dimensions to reduce.

  • keepdim – Whether to retain the reduced dimensions. Defaults to false.

  • dtype – Optional output dtype.

Returns:

Reference to the output tensor.
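The out-first convention, and the fact that the returned reference is the out argument itself, can be illustrated with a plain C++ model. `sum_out_model` is a hypothetical helper, not the stable API; it reduces over all elements and stores the result in a one-element vector standing in for a 0-dimensional output tensor.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Hypothetical model of the out-variant convention: the destination comes
// first, is overwritten with the result, and is returned by reference.
std::vector<double>& sum_out_model(std::vector<double>& out,
                                   const std::vector<double>& self) {
    out.assign(1, std::accumulate(self.begin(), self.end(), 0.0));
    return out;
}
```

Because the caller owns out, the same buffer can be passed to repeated calls, and the returned reference always refers to that buffer.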

inline torch::stable::Tensor torch::stable::subtract(const torch::stable::Tensor &self, const torch::stable::Tensor &other, double alpha = 1.0)#

Stable version of the subtract.Tensor op.

Subtracts the other tensor from self, with an optional scaling factor alpha. Computes: self - alpha * other.

Minimum compatible version: PyTorch 2.10.

Note

The alpha parameter is typed as double rather than Scalar, because the stable API uses double in place of the Scalar type.

Parameters:
  • self – The input tensor.

  • other – The tensor to subtract.

  • alpha – The scaling factor for other. Defaults to 1.0.

Returns:

The result of self - alpha * other.
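The formula self - alpha * other can be checked element by element with a plain C++ model. `subtract_model` is a hypothetical helper, not the stable API; it requires equally sized 1-D inputs and does not model broadcasting or dtype promotion.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical model of subtract(self, other, alpha) on equally sized 1-D data:
// out[i] = self[i] - alpha * other[i].
std::vector<double> subtract_model(const std::vector<double>& self,
                                   const std::vector<double>& other,
                                   double alpha = 1.0) {
    if (self.size() != other.size()) {
        throw std::invalid_argument("this model requires equal sizes");
    }
    std::vector<double> out(self.size());
    for (std::size_t i = 0; i < self.size(); ++i) {
        out[i] = self[i] - alpha * other[i];
    }
    return out;
}
```

With self = {10, 20}, other = {1, 2}, and alpha = 2, the model returns {8, 16}.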

Parallelization Utilities#

template<class F>
inline void torch::stable::parallel_for(const int64_t begin, const int64_t end, const int64_t grain_size, const F &f)

Stable parallel_for utility.

Provides a stable interface to at::parallel_for for parallel execution. The range [begin, end) is split into chunks, and f is called with the (begin, end) bounds of each chunk, possibly concurrently on several threads, so f must be safe to run in parallel. grain_size controls the minimum work size per thread for efficient parallelization.

Minimum compatible version: PyTorch 2.10.

Template Parameters:

F – The callable type, invoked as f(begin, end) with int64_t bounds.

Parameters:
  • begin – The start of the iteration range.

  • end – The end of the iteration range (exclusive).

  • grain_size – The minimum number of iterations per thread.

  • f – The function to execute in parallel.
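The chunking contract can be sketched with std::thread. This is a plain C++ model, not torch::stable::parallel_for: `parallel_for_model` is a hypothetical name, its explicit num_threads argument is an assumption of the model (the real utility takes the thread count from the parallel backend), and its splitting policy (run serially when the range is no larger than grain_size, otherwise use chunks of at least grain_size iterations) is illustrative.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <thread>
#include <vector>

// Hypothetical model: split [begin, end) into contiguous, non-overlapping
// chunks and call f(chunk_begin, chunk_end) once per chunk.
template <class F>
void parallel_for_model(int64_t begin, int64_t end, int64_t grain_size,
                        int64_t num_threads, const F& f) {
    const int64_t range = end - begin;
    if (range <= 0) return;
    if (range <= grain_size || num_threads <= 1) {
        f(begin, end);  // small ranges run serially on the calling thread
        return;
    }
    // Spread work across threads, but never below grain_size per chunk.
    const int64_t chunk =
        std::max<int64_t>(grain_size, (range + num_threads - 1) / num_threads);
    std::vector<std::thread> workers;
    for (int64_t b = begin; b < end; b += chunk) {
        const int64_t e = std::min(b + chunk, end);
        workers.emplace_back([&f, b, e] { f(b, e); });
    }
    for (auto& t : workers) t.join();  // wait for every chunk to finish
}
```

Since chunks never overlap, a body that writes only to indices in its own [b, e) range, such as out[i] = 2 * i, needs no locking.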

inline uint32_t torch::stable::get_num_threads()

Gets the number of threads for the parallel backend.

Provides a stable interface to at::get_num_threads.

Minimum compatible version: PyTorch 2.10.

Returns:

The number of threads used by the parallel backend.