:github_url: https://github.com/pytorch/pytorch .. _program_listing_file_torch_csrc_stable_library.h: Program Listing for File library.h ================================== |exhale_lsh| :ref:`Return to documentation for file ` (``torch/csrc/stable/library.h``) .. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS .. code-block:: cpp #pragma once // this file can only have stable stuff! Akin to shim.h // but unlike shim.h, this file can contain header-only C++ // code for better UX. #include #include // Technically, this file doesn't use anything from stableivalue_conversions.h, // but we need to include it here as the contents of stableivalue_conversions.h // used to live here and so we need to expose them for backwards compatibility. #include HIDDEN_NAMESPACE_BEGIN(torch, stable, detail) class StableLibrary final { private: TorchLibraryHandle lib_; public: enum class Kind { DEF, IMPL, FRAGMENT, }; // constructor StableLibrary( Kind kind, const char* ns, const char* k, const char* file, uint32_t line) { if (kind == Kind::IMPL) { aoti_torch_library_init_impl(ns, k, file, line, &lib_); } else if (kind == Kind::DEF) { aoti_torch_library_init_def(ns, file, line, &lib_); } else { // kind == FRAGMENT aoti_torch_library_init_fragment(ns, file, line, &lib_); } } // do not permit copy StableLibrary(const StableLibrary&) = delete; StableLibrary& operator=(const StableLibrary&) = delete; // do not permit move StableLibrary(StableLibrary&& other) = delete; StableLibrary& operator=(StableLibrary&& other) = delete; ~StableLibrary() { aoti_torch_delete_library_object(lib_); } // corresponds to a limited, stable version of torch::library::impl() // Inputs: // name: the name of the function to implement // fn: a boxed function with schema // (StableIValue* stack, uint64_t num_inputs, uint64_t num_outputs) -> // void // fn should follow the calling convention of our boxed kernels that convert // to IValues. fn will be called with a StableIValue* array of length // max(num_inputs, num_outputs), where the first num_inputs entries are // populated with inputs. fn is responsible for stealing the memory of the // inputs, in effect "popping" them off the stack, and then populating the // stack with StableIValue outputs. Concretely, fn should: // 1. read StableIValue inputs from the given stack // 2. convert the inputs to the proper types // 3. call the function corresponding to name with the inputs // 4. convert the outputs to StableIValues // 5. populate the now empty stack with StableIValue outputs // If the operation corresponding to name takes in 4 inputs and returns 2 // outputs, fn should expect stack to contain 4 StableIValues: // [stable_arg1, stable_arg2, stable_arg3, stable_arg4] // to end, fn should fill the stack with 2 StableIValues representing outputs: // [stable_ret1, stable_ret2, -, -] StableLibrary& impl( const char* name, void (*fn)(StableIValue*, uint64_t, uint64_t)) { aoti_torch_library_impl(lib_, name, fn); return *this; } // corresponds to a limited, stable version of torch::library::def() StableLibrary& def(const char* schema) { aoti_torch_library_def(lib_, schema); return *this; } }; class StableTorchLibraryInit final { private: using InitFn = void(StableLibrary&); StableLibrary lib_; public: StableTorchLibraryInit( StableLibrary::Kind kind, InitFn* fn, const char* ns, const char* k, const char* file, uint32_t line) : lib_(kind, ns, k, file, line) { fn(lib_); } }; HIDDEN_NAMESPACE_END(torch, stable, detail) // macros copied from c10/macros/Macros.h #ifdef __COUNTER__ #define STABLE_UID __COUNTER__ #else #define STABLE_UID __LINE__ #endif #define STABLE_CONCATENATE_IMPL(s1, s2) s1##s2 #define STABLE_CONCATENATE(s1, s2) STABLE_CONCATENATE_IMPL(s1, s2) // end of macros copied from c10/macros/Macros.h #define STABLE_TORCH_LIBRARY_IMPL(ns, k, m) \ _STABLE_TORCH_LIBRARY_IMPL(ns, k, m, STABLE_UID) #define _STABLE_TORCH_LIBRARY_IMPL(ns, k, m, uid) \ static void STABLE_CONCATENATE( \ STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, \ uid)(torch::stable::detail::StableLibrary&); \ static const torch::stable::detail::StableTorchLibraryInit \ STABLE_CONCATENATE( \ STABLE_TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)( \ torch::stable::detail::StableLibrary::Kind::IMPL, \ &STABLE_CONCATENATE( \ STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid), \ #ns, \ #k, \ __FILE__, \ __LINE__); \ void STABLE_CONCATENATE(STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)( \ torch::stable::detail::StableLibrary & m) #define STABLE_TORCH_LIBRARY(ns, m) \ static void STABLE_TORCH_LIBRARY_init_##ns( \ torch::stable::detail::StableLibrary&); \ static const torch::stable::detail::StableTorchLibraryInit \ STABLE_TORCH_LIBRARY_static_init_##ns( \ torch::stable::detail::StableLibrary::Kind::DEF, \ &STABLE_TORCH_LIBRARY_init_##ns, \ #ns, \ nullptr, \ __FILE__, \ __LINE__); \ void STABLE_TORCH_LIBRARY_init_##ns(torch::stable::detail::StableLibrary& m) #define STABLE_TORCH_LIBRARY_FRAGMENT(ns, m) \ _STABLE_TORCH_LIBRARY_FRAGMENT(ns, m, STABLE_UID) #define _STABLE_TORCH_LIBRARY_FRAGMENT(ns, m, uid) \ static void STABLE_CONCATENATE( \ STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, \ uid)(torch::stable::detail::StableLibrary&); \ static const torch::stable::detail::StableTorchLibraryInit \ STABLE_CONCATENATE( \ STABLE_TORCH_LIBRARY_FRAGMENT_static_init_##ns##_, uid)( \ torch::stable::detail::StableLibrary::Kind::FRAGMENT, \ &STABLE_CONCATENATE( \ STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid), \ #ns, \ nullptr, \ __FILE__, \ __LINE__); \ void STABLE_CONCATENATE(STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)( \ torch::stable::detail::StableLibrary & m)