Program Listing for File library.h
↰ Return to documentation for file (torch/csrc/stable/library.h)
#pragma once
// this file can only have stable stuff! Akin to shim.h
// but unlike shim.h, this file can contain header-only C++
// code for better UX.
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/c/shim.h>
#include <torch/headeronly/macros/Macros.h>
#include <torch/headeronly/util/Metaprogramming.h>
// Technically, this file doesn't use anything from stableivalue_conversions.h,
// but we need to include it here as the contents of stableivalue_conversions.h
// used to live here and so we need to expose them for backwards compatibility.
#include <torch/csrc/stable/stableivalue_conversions.h>
#include <torch/csrc/stable/version.h>

#include <cstdint>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
HIDDEN_NAMESPACE_BEGIN(torch, stable, detail)
class StableLibrary final {
private:
TorchLibraryHandle lib_;
public:
enum class Kind {
DEF,
IMPL,
FRAGMENT,
};
// constructor
StableLibrary(
Kind kind,
const char* ns,
const char* k,
const char* file,
uint32_t line) {
if (kind == Kind::IMPL) {
aoti_torch_library_init_impl(ns, k, file, line, &lib_);
} else if (kind == Kind::DEF) {
aoti_torch_library_init_def(ns, file, line, &lib_);
} else { // kind == FRAGMENT
aoti_torch_library_init_fragment(ns, file, line, &lib_);
}
}
// do not permit copy
StableLibrary(const StableLibrary&) = delete;
StableLibrary& operator=(const StableLibrary&) = delete;
// do not permit move
StableLibrary(StableLibrary&& other) = delete;
StableLibrary& operator=(StableLibrary&& other) = delete;
~StableLibrary() {
aoti_torch_delete_library_object(lib_);
}
// corresponds to a limited, stable version of torch::library::impl()
// Inputs:
// name: the name of the function to implement
// fn: a boxed function with schema
// (StableIValue* stack, uint64_t num_inputs, uint64_t num_outputs) ->
// void
// fn should follow the calling convention of our boxed kernels that convert
// to IValues. fn will be called with a StableIValue* array of length
// max(num_inputs, num_outputs), where the first num_inputs entries are
// populated with inputs. fn is responsible for stealing the memory of the
// inputs, in effect "popping" them off the stack, and then populating the
// stack with StableIValue outputs. Concretely, fn should:
// 1. read StableIValue inputs from the given stack
// 2. convert the inputs to the proper types
// 3. call the function corresponding to name with the inputs
// 4. convert the outputs to StableIValues
// 5. populate the now empty stack with StableIValue outputs
// If the operation corresponding to name takes in 4 inputs and returns 2
// outputs, fn should expect stack to contain 4 StableIValues:
// [stable_arg1, stable_arg2, stable_arg3, stable_arg4]
// to end, fn should fill the stack with 2 StableIValues representing outputs:
// [stable_ret1, stable_ret2, -, -]
StableLibrary& impl(
const char* name,
void (*fn)(StableIValue*, uint64_t, uint64_t)) {
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
torch_library_impl(lib_, name, fn, TORCH_ABI_VERSION);
#else
aoti_torch_library_impl(lib_, name, fn);
#endif
return *this;
}
// corresponds to a limited, stable version of torch::library::def()
StableLibrary& def(const char* schema) {
aoti_torch_library_def(lib_, schema);
return *this;
}
};
class StableTorchLibraryInit final {
private:
using InitFn = void(StableLibrary&);
StableLibrary lib_;
public:
StableTorchLibraryInit(
StableLibrary::Kind kind,
InitFn* fn,
const char* ns,
const char* k,
const char* file,
uint32_t line)
: lib_(kind, ns, k, file, line) {
fn(lib_);
}
};
// Type mapper from a kernel parameter type to the type that is materialized
// from the StableIValue stack and stored in the argument tuple.
//   - HeaderOnlyArrayRef<T> is a non-owning view, so to<HeaderOnlyArrayRef<T>>
//     cannot exist. It is mapped to std::vector<T> so the unboxed argument
//     owns its data (ownership semantics are preserved).
//   - Every other type maps to itself.
// unbox_type_t is applied to each of the kernel's ParamTypes, so the tuple
// holding the arguments has proper ownership too. References and top-level
// cv-qualifiers are stripped before the mapping is applied. A `const X&`
// parameter is therefore unboxed exactly like a by-value `X`, and
// `const HeaderOnlyArrayRef<T>&` still reaches the std::vector<T>
// specialization instead of instantiating to<const ...&>.
template <typename T>
struct UnboxType {
  using type = T;
};

template <typename T>
struct UnboxType<torch::headeronly::HeaderOnlyArrayRef<T>> {
  using type = std::vector<T>;
};

template <typename T>
using unbox_type_t =
    typename UnboxType<std::remove_cv_t<std::remove_reference_t<T>>>::type;
// Builds a std::tuple<T...> whose I-th element is to<T_I>(stack[I]).
// The braced initializer converts the stack slots in left-to-right order.
template <class... T, std::size_t... I>
std::tuple<T...> unbox_to_tuple_impl(
    StableIValue* stack,
    std::index_sequence<I...>) {
  return std::tuple<T...>{to<T>(stack[I])...};
}

// Unboxes the first sizeof...(T) entries of `stack` into a tuple of T...,
// one stack slot per type.
template <class... T>
std::tuple<T...> unbox_to_tuple(StableIValue* stack) {
  using Indices = std::index_sequence_for<T...>;
  return unbox_to_tuple_impl<T...>(stack, Indices{});
}
template <class... T, std::size_t... I>
void box_from_tuple_impl(
StableIValue* stack,
std::tuple<T...> vals,
std::index_sequence<I...>) {
((stack[I] = from<T>(std::get<I>(vals))), ...);
}
template <class... T>
void box_from_tuple(StableIValue* stack, std::tuple<T...> vals) {
box_from_tuple_impl<T...>(
stack, vals, std::make_index_sequence<sizeof...(T)>());
}
// Fallback used when no partial specialization below matches the function's
// return/parameter shape. Instantiating it is always a compile-time error.
// The static_assert condition depends on a template parameter so that it only
// fires on instantiation.
template <typename Ret, typename ParamList, typename Fn, Fn* fn_ptr>
struct boxer_impl {
  static_assert(
      torch::headeronly::guts::false_t<Ret>::value,
      "Unsupported function schema for TORCH_BOX.");
};
// Specialization for kernels returning std::tuple<ReturnTypes...>: each tuple
// element becomes one output slot on the stack.
template <
    typename... ReturnTypes,
    typename... ParameterTypes,
    typename FuncT,
    FuncT* func>
struct boxer_impl<
    std::tuple<ReturnTypes...>,
    torch::headeronly::guts::typelist::typelist<ParameterTypes...>,
    FuncT,
    func> {
  static void boxed_fn(
      StableIValue* stack,
      uint64_t num_args,
      uint64_t num_outputs) {
    // The registered schema must agree with the C++ signature.
    STD_TORCH_CHECK(
        num_args == sizeof...(ParameterTypes),
        "Registered schema has ",
        num_args,
        " args, but the kernel to box has ",
        sizeof...(ParameterTypes));
    STD_TORCH_CHECK(
        num_outputs == sizeof...(ReturnTypes),
        "Registered schema has ",
        num_outputs,
        " outputs, but the kernel to box has ",
        sizeof...(ReturnTypes));

    // Pull owning argument values off the stack, run the kernel, then
    // overwrite the stack with its outputs.
    auto unboxed_inputs =
        unbox_to_tuple<unbox_type_t<ParameterTypes>...>(stack);
    auto outputs = std::apply(func, unboxed_inputs);
    box_from_tuple<ReturnTypes...>(stack, outputs);
  }
};
// Specialization for kernels returning a single (non-tuple, non-void) value,
// which is written to stack[0].
template <
    typename ReturnType,
    typename... ParameterTypes,
    typename FuncT,
    FuncT* func>
struct boxer_impl<
    ReturnType,
    torch::headeronly::guts::typelist::typelist<ParameterTypes...>,
    FuncT,
    func> {
  static void boxed_fn(
      StableIValue* stack,
      uint64_t num_args,
      uint64_t num_outputs) {
    // The registered schema must agree with the C++ signature.
    STD_TORCH_CHECK(
        num_args == sizeof...(ParameterTypes),
        "Registered schema has ",
        num_args,
        " args, but the kernel to box has ",
        sizeof...(ParameterTypes));
    STD_TORCH_CHECK(
        num_outputs == 1,
        "Registered schema has ",
        num_outputs,
        " outputs, but the kernel to box has ",
        1);

    // Pull owning argument values off the stack, run the kernel, then place
    // its single result in the first slot.
    auto unboxed_inputs =
        unbox_to_tuple<unbox_type_t<ParameterTypes>...>(stack);
    auto result = std::apply(func, unboxed_inputs);
    stack[0] = from<ReturnType>(result);
  }
};
// Specialization for kernels returning void: nothing is written back to the
// stack.
template <typename... ParameterTypes, typename FuncT, FuncT* func>
struct boxer_impl<
    void,
    torch::headeronly::guts::typelist::typelist<ParameterTypes...>,
    FuncT,
    func> {
  static void boxed_fn(
      StableIValue* stack,
      uint64_t num_args,
      uint64_t num_outputs) {
    // The registered schema must agree with the C++ signature.
    STD_TORCH_CHECK(
        num_args == sizeof...(ParameterTypes),
        "Registered schema has ",
        num_args,
        " args, but the kernel to box has ",
        sizeof...(ParameterTypes));
    STD_TORCH_CHECK(
        num_outputs == 0,
        "Registered schema has ",
        num_outputs,
        " outputs, but the kernel to box has ",
        0);

    // Pull owning argument values off the stack and run the kernel.
    auto unboxed_inputs =
        unbox_to_tuple<unbox_type_t<ParameterTypes>...>(stack);
    std::apply(func, unboxed_inputs);
  }
};
template <typename FuncT, FuncT* func>
struct boxer {
using FunctionTraits =
torch::headeronly::guts::infer_function_traits_t<FuncT>;
static void boxed_fn(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
boxer_impl<
typename FunctionTraits::return_type,
typename FunctionTraits::parameter_types,
FuncT,
func>::boxed_fn(stack, num_args, num_outputs);
}
};
HIDDEN_NAMESPACE_END(torch, stable, detail)
// TORCH_BOX(func) names a boxed wrapper around the unboxed C++ function
// `func`. The wrapper has the void(StableIValue*, uint64_t, uint64_t)
// signature that StableLibrary::impl() expects, e.g.
//   m.impl("my_op", TORCH_BOX(&my_op_kernel));
// `func` may be a function name or a pointer to a function. The
// remove_reference/remove_pointer pair recovers the plain function type in
// both cases. `func` is also used as a non-type template argument, so it must
// be a constant expression. The wrapper checks, on every call, that the
// registered schema's argument/output counts match func's signature.
#define TORCH_BOX(func) \
torch::stable::detail::boxer< \
std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, \
(func)>::boxed_fn
// macros copied from c10/macros/Macros.h
// STABLE_UID expands to a fresh integer per use where the compiler provides
// __COUNTER__. Otherwise it falls back to __LINE__, which is only unique per
// source line.
#ifdef __COUNTER__
#define STABLE_UID __COUNTER__
#else
#define STABLE_UID __LINE__
#endif
// Two-level concatenation: the outer macro's arguments are macro-expanded
// before the inner one pastes them, so STABLE_CONCATENATE(x_, STABLE_UID)
// yields e.g. x_7 rather than x_STABLE_UID.
#define STABLE_CONCATENATE_IMPL(s1, s2) s1##s2
#define STABLE_CONCATENATE(s1, s2) STABLE_CONCATENATE_IMPL(s1, s2)
// end of macros copied from c10/macros/Macros.h
// STABLE_TORCH_LIBRARY_IMPL(ns, k, m) { ... } registers kernels for namespace
// `ns` under `k` (stringized; presumably a dispatch key such as CPU, by analogy
// with TORCH_LIBRARY_IMPL. TODO confirm). The braces that follow become the
// body of a function receiving the StableLibrary as `m`, e.g.
//   STABLE_TORCH_LIBRARY_IMPL(myns, CPU, m) {
//     m.impl("my_op", TORCH_BOX(&my_op_cpu));
//   }
// The expansion declares a static init function, defines a static
// StableTorchLibraryInit (Kind::IMPL) whose constructor creates the library
// and calls that function, and then opens the function's definition. A
// STABLE_UID suffix keeps several uses with the same ns/k in one translation
// unit from colliding.
#define STABLE_TORCH_LIBRARY_IMPL(ns, k, m) \
_STABLE_TORCH_LIBRARY_IMPL(ns, k, m, STABLE_UID)
// Implementation layer. Because STABLE_UID is passed in as the `uid`
// argument, it is expanded once before substitution, so all three uses of
// `uid` below refer to the same value.
#define _STABLE_TORCH_LIBRARY_IMPL(ns, k, m, uid) \
static void STABLE_CONCATENATE( \
STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, \
uid)(torch::stable::detail::StableLibrary&); \
static const torch::stable::detail::StableTorchLibraryInit \
STABLE_CONCATENATE( \
STABLE_TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)( \
torch::stable::detail::StableLibrary::Kind::IMPL, \
&STABLE_CONCATENATE( \
STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid), \
#ns, \
#k, \
__FILE__, \
__LINE__); \
void STABLE_CONCATENATE(STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)( \
torch::stable::detail::StableLibrary & m)
// STABLE_TORCH_LIBRARY(ns, m) { ... } defines the operators of namespace
// `ns`. The braces that follow become the body of a function receiving the
// StableLibrary as `m`, e.g.
//   STABLE_TORCH_LIBRARY(myns, m) {
//     m.def("my_op(Tensor x) -> Tensor");
//   }
// The expansion defines a static StableTorchLibraryInit of Kind::DEF (with
// k == nullptr) that creates the library and runs the body. Unlike the IMPL
// and FRAGMENT macros, the generated names depend only on `ns`, so using it
// twice for the same namespace in one translation unit is a redefinition
// error.
#define STABLE_TORCH_LIBRARY(ns, m) \
static void STABLE_TORCH_LIBRARY_init_##ns( \
torch::stable::detail::StableLibrary&); \
static const torch::stable::detail::StableTorchLibraryInit \
STABLE_TORCH_LIBRARY_static_init_##ns( \
torch::stable::detail::StableLibrary::Kind::DEF, \
&STABLE_TORCH_LIBRARY_init_##ns, \
#ns, \
nullptr, \
__FILE__, \
__LINE__); \
void STABLE_TORCH_LIBRARY_init_##ns(torch::stable::detail::StableLibrary& m)
// STABLE_TORCH_LIBRARY_FRAGMENT(ns, m) { ... } is like STABLE_TORCH_LIBRARY,
// but creates a library of Kind::FRAGMENT (with k == nullptr). The braces that
// follow become the body of a function receiving the StableLibrary as `m`. A
// STABLE_UID suffix makes the generated names unique, so several fragments for
// the same `ns` may appear in one translation unit.
#define STABLE_TORCH_LIBRARY_FRAGMENT(ns, m) \
_STABLE_TORCH_LIBRARY_FRAGMENT(ns, m, STABLE_UID)
// Implementation layer. Because STABLE_UID is passed in as the `uid`
// argument, it is expanded once before substitution, so all three uses of
// `uid` below refer to the same value.
#define _STABLE_TORCH_LIBRARY_FRAGMENT(ns, m, uid) \
static void STABLE_CONCATENATE( \
STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, \
uid)(torch::stable::detail::StableLibrary&); \
static const torch::stable::detail::StableTorchLibraryInit \
STABLE_CONCATENATE( \
STABLE_TORCH_LIBRARY_FRAGMENT_static_init_##ns##_, uid)( \
torch::stable::detail::StableLibrary::Kind::FRAGMENT, \
&STABLE_CONCATENATE( \
STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid), \
#ns, \
nullptr, \
__FILE__, \
__LINE__); \
void STABLE_CONCATENATE(STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)( \
torch::stable::detail::StableLibrary & m)