:github_url: https://github.com/pytorch/pytorch

.. _program_listing_file_torch_library.h:

Program Listing for File library.h
==================================

|exhale_lsh| :ref:`Return to documentation for file <file_torch_library.h>` (``torch/library.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   TORCH_LIBRARY(myops, m) {
     // m is a torch::Library; methods on it will define
     // operators in the myops namespace
     m.def("add", add_impl);
   }

   TORCH_LIBRARY_IMPL(myops, XLA, m) {
     m.fallback(xla_fallback);
   }

   TORCH_LIBRARY_IMPL(myops, CPU, m) {
     // m is a torch::Library; methods on it will define
     // CPU implementations of operators in the myops namespace.
     // It is NOT valid to call torch::Library::def()
     // in this context.
     m.impl("add", add_cpu_impl);
   }

   // You must define all of the operators for this library in
   // this namespace.
   TORCH_LIBRARY(myops, m) {
     // Define an operator with exactly one implementation for all backends.
     m.def("add(Tensor self, Tensor other) -> Tensor", &add_impl);

     // Define a schema for an operator, but provide no implementation
     // (use this syntax if you want to use the dispatcher)
     m.def("mul(Tensor self, Tensor other) -> Tensor");

     // Provide an implementation for a defined operator (you can
     // provide multiple; one per backend).  The dispatcher takes care of
     // calling the correct implementation depending on if we get a CPU
     // tensor or a CUDA tensor
     m.impl("mul", torch::kCPU, &mul_cpu_impl);
     m.impl("mul", torch::kCUDA, &mul_cuda_impl);
   }

   // Define implementations for operators for a non-standard backend,
   // e.g., XLA (valid values are entries of DispatchKey).  This can
   // be used to define operators in a different file than the initial
   // TORCH_LIBRARY definition (e.g., if it is in an external library)
   TORCH_LIBRARY_IMPL(myops, XLA, m) {
     m.impl("mul", &mul_xla_impl);
   }
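   // Illustrative sketch (not part of torch/library.h): the snippets above
   // assume kernels such as add_impl / mul_cpu_impl exist. A minimal kernel
   // matching "mul(Tensor self, Tensor other) -> Tensor" could look like the
   // following; the name and body are only an example.
   //
   //   at::Tensor mul_cpu_impl(const at::Tensor& self, const at::Tensor& other) {
   //     // plain eager math; a real backend kernel may do something different
   //     return self * other;
   //   }
   //
   //   TORCH_LIBRARY_IMPL(myops, CPU, m) {
   //     m.impl("mul", &mul_cpu_impl);
   //   }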
   #pragma once

   #include
   #include
   #include
   #include
   #include
   // Just for inferFunctionSchemaFromFunctor
   #include
   #include

   namespace torch {

   #if defined C10_MOBILE
   struct NoInferSchemaTag {};
   #endif

   #define HAS_PT2_COMPLIANT_TAG

   // For multipy/torchdeploy use case
   enum class _RegisterOrVerify { REGISTER, VERIFY };

   template <class CurClass>
   class class_;

   #define HAS_IMPL_ABSTRACT_PYSTUB

   class TORCH_API CppFunction final {
     // TODO: This is morally the same thing as KernelRegistrationConfig, but
     // it's opaque to the user.

    public:
     template <typename Func>
     explicit CppFunction(
         Func* f,
         std::enable_if_t<
             c10::guts::is_function_type<Func>::value,
             std::nullptr_t> = nullptr)
         : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f)),
           cpp_signature_(c10::impl::CppSignature::make<Func>()),
           schema_(c10::detail::inferFunctionSchemaFromFunctor<
                   std::decay_t<Func>>()) {}

     template <typename FuncPtr>
     explicit CppFunction(
         FuncPtr f,
         std::enable_if_t<
             c10::is_compile_time_function_pointer<FuncPtr>::value,
             std::nullptr_t> = nullptr)
         : func_(c10::KernelFunction::makeFromUnboxedFunction(f)),
           cpp_signature_(
               c10::impl::CppSignature::make<typename FuncPtr::FuncType>()),
           schema_(c10::detail::inferFunctionSchemaFromFunctor<
                   typename FuncPtr::FuncType>()) {}

     template <typename Lambda>
     explicit CppFunction(
         Lambda&& f,
         std::enable_if_t<
             c10::guts::is_functor<std::decay_t<Lambda>>::value,
             std::nullptr_t> = nullptr)
         : func_(c10::KernelFunction::makeFromUnboxedLambda(
               std::forward<Lambda>(f))),
           cpp_signature_(c10::impl::CppSignature::make<Lambda>()),
           schema_(c10::detail::inferFunctionSchemaFromFunctor<
                   std::decay_t<Lambda>>()) {}

   #if defined C10_MOBILE
     template <typename Func>
     explicit CppFunction(
         Func* f,
         NoInferSchemaTag,
         std::enable_if_t<
             c10::guts::is_function_type<Func>::value,
             std::nullptr_t> = nullptr)
         : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f)),
           cpp_signature_(c10::impl::CppSignature::make<Func>())
           // TODO: Don't go through WrapRuntimeKernelFunctor
           ,
           schema_(nullptr),
           debug_() {}

     template <typename FuncPtr>
     explicit CppFunction(
         FuncPtr f,
         NoInferSchemaTag,
         std::enable_if_t<
             c10::is_compile_time_function_pointer<FuncPtr>::value,
             std::nullptr_t> = nullptr)
         : func_(c10::KernelFunction::makeFromUnboxedFunction(f)),
           cpp_signature_(
               c10::impl::CppSignature::make<typename FuncPtr::FuncType>())
           // TODO: Don't go through WrapRuntimeKernelFunctor
           ,
           schema_(nullptr),
           debug_() {}

     template <typename Lambda>
     explicit CppFunction(
         Lambda&& f,
         NoInferSchemaTag,
         std::enable_if_t<
             c10::guts::is_functor<std::decay_t<Lambda>>::value,
             std::nullptr_t> = nullptr)
         : func_(c10::KernelFunction::makeFromUnboxedLambda(
               std::forward<Lambda>(f))),
           cpp_signature_(c10::impl::CppSignature::make<Lambda>())
           // TODO: Don't go through WrapRuntimeKernelFunctor
           ,
           schema_(nullptr),
           debug_() {}
   #endif

     ~CppFunction();

     CppFunction(const CppFunction&) = delete;
     CppFunction& operator=(const CppFunction&) = delete;

     CppFunction(CppFunction&&) noexcept = default;
     CppFunction& operator=(CppFunction&&) = default;

     static CppFunction makeFromBoxedKernel(c10::BoxedKernel kernel) {
       return CppFunction(
           c10::KernelFunction::makeFromBoxedKernel(std::move(kernel)),
           /* cpp_signature */ std::nullopt, // not known for boxed functions
           /* schema */ nullptr);
     }

     static CppFunction makeFallthrough() {
       return makeFromBoxedKernel(c10::BoxedKernel::makeFallthrough());
     }

     static CppFunction makeNamedNotSupported() {
       return makeFromBoxedKernel(c10::BoxedKernel::makeNamedNotSupported());
     }

     template <c10::BoxedKernel::BoxedKernelFunction* func>
     static CppFunction makeFromBoxedFunction() {
       return makeFromBoxedKernel(c10::BoxedKernel::makeFromFunction<func>());
     }
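     // Illustrative sketch (not part of this header): a CppFunction built with
     // makeFallthrough() can be registered as a per-backend fallback so that
     // calls simply fall through to the next available kernel, e.g.
     //
     //   TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) {
     //     m.fallback(torch::CppFunction::makeFallthrough());
     //   }
     //
     // The "_" namespace means "all namespaces"; the dispatch key above is
     // only an example.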
     // Variant that takes in a boxed kernel function with a plumbed
     // DispatchKeySet. See Note [Plumbing Keys Through The Dispatcher] for
     // details.
     template <c10::BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
     static CppFunction makeFromBoxedFunction() {
       return makeFromBoxedKernel(c10::BoxedKernel::makeFromFunction<func>());
     }

     template <class KernelFunctor>
     static CppFunction makeFromBoxedFunctor(
         std::unique_ptr<KernelFunctor> kernelFunctor) {
       return makeFromBoxedKernel(
           c10::BoxedKernel::makeFromFunctor(std::move(kernelFunctor)));
     }

     template <
         typename FuncPtr,
         std::enable_if_t<
             c10::guts::is_function_type<FuncPtr>::value,
             std::nullptr_t> = nullptr>
     static CppFunction makeFromUnboxedFunction(FuncPtr* f) {
       return CppFunction(f);
     }

     template <
         typename FuncPtr,
         std::enable_if_t<
             c10::is_compile_time_function_pointer<FuncPtr>::value,
             std::nullptr_t> = nullptr>
     static CppFunction makeFromUnboxedFunction(FuncPtr f) {
       return CppFunction(f);
     }

     CppFunction&& debug(std::string d) && {
       debug_ = std::move(d);
       return std::move(*this);
     }

    private:
     std::optional<c10::DispatchKey> dispatch_key_;
     c10::KernelFunction func_;
     std::optional<c10::impl::CppSignature> cpp_signature_;
     std::unique_ptr<c10::FunctionSchema> schema_;
     std::string debug_;

     // The "setter" for dispatch_key_
     template <typename Func>
     friend CppFunction dispatch(c10::DispatchKey, Func&&);

     // The only class which actually pulls out values from CppFunction (does so
     // destructively, felt too lazy to write accessors that I don't even
     // want users to use)
     friend class Library;

     CppFunction(
         c10::KernelFunction func,
         std::optional<c10::impl::CppSignature> cpp_signature,
         std::unique_ptr<c10::FunctionSchema> schema);
   };

   template <typename Func>
   inline CppFunction dispatch(c10::DispatchKey k, Func&& raw_f) {
     CppFunction f(std::forward<Func>(raw_f));
     if (k == c10::DispatchKey::CatchAll) {
       f.dispatch_key_ = std::nullopt;
     } else {
       f.dispatch_key_ = k;
     }
     return f;
   }

   template <typename Func>
   inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) {
     auto deviceTypeToDispatchKey = [](c10::DeviceType t) {
       switch (t) {
         // This list is synchronized with the k-constants in c10/core/DeviceType.h
         case c10::DeviceType::CPU:
           return c10::DispatchKey::CPU;
         case c10::DeviceType::CUDA:
           return c10::DispatchKey::CUDA;
         case c10::DeviceType::IPU:
           return c10::DispatchKey::IPU;
         case c10::DeviceType::XLA:
           return c10::DispatchKey::XLA;
         case c10::DeviceType::Lazy:
           return c10::DispatchKey::Lazy;
         case c10::DeviceType::XPU:
           return c10::DispatchKey::XPU;
         case c10::DeviceType::MPS:
           return c10::DispatchKey::MPS;
         case c10::DeviceType::Meta:
           return c10::DispatchKey::Meta;
         case c10::DeviceType::HIP:
           return c10::DispatchKey::HIP;
         case c10::DeviceType::MAIA:
           return c10::DispatchKey::MAIA;
         case c10::DeviceType::HPU:
           return c10::DispatchKey::HPU;
         case c10::DeviceType::MTIA:
           return c10::DispatchKey::MTIA;
         case c10::DeviceType::PrivateUse1:
           return c10::DispatchKey::PrivateUse1;
         default:
           TORCH_CHECK(
               false,
               "Device type ",
               t,
               " cannot be overloaded at dispatch time, "
               "please file a bug report explaining what you were trying to do.");
       }
     };
     return dispatch(deviceTypeToDispatchKey(type), std::forward<Func>(raw_f));
   }

   inline c10::FunctionSchema schema(
       const char* str,
       c10::AliasAnalysisKind k,
       bool allow_typevars = false) {
     c10::FunctionSchema s =
         torch::jit::parseSchema(str, /*allow_typevars*/ allow_typevars);
     s.setAliasAnalysis(k);
     return s;
   }

   inline c10::FunctionSchema schema(const char* s, bool allow_typevars = false) {
     return schema(s, c10::AliasAnalysisKind::FROM_SCHEMA, allow_typevars);
   }

   inline c10::FunctionSchema&& schema(c10::FunctionSchema&& s) {
     return std::move(s);
   }
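   // Illustrative sketch (not from this header): schema() and dispatch() attach
   // extra metadata at registration time. Both are typically used inside a
   // TORCH_LIBRARY block, e.g.
   //
   //   TORCH_LIBRARY(myops, m) {
   //     // parse the schema string but override the alias analysis kind
   //     m.def(torch::schema(
   //         "relu6(Tensor x) -> Tensor",
   //         c10::AliasAnalysisKind::PURE_FUNCTION));
   //     // bind a kernel to an explicit device type / dispatch key
   //     m.impl("relu6", torch::dispatch(torch::kCPU, &relu6_cpu_impl));
   //   }
   //
   // relu6_cpu_impl is a hypothetical kernel; any unboxed function with a
   // matching signature works.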
   namespace detail {

   inline std::variant<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(
       c10::FunctionSchema&& s) {
     return std::move(s);
   }
   inline std::variant<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(
       c10::OperatorName&& n) {
     return std::move(n);
   }
   inline std::variant<c10::OperatorName, c10::FunctionSchema>
   constructSchemaOrName(const char* str) {
     auto s = torch::jit::parseSchemaOrName(str);
     if (std::holds_alternative<c10::FunctionSchema>(s)) {
       std::get<c10::FunctionSchema>(s).setAliasAnalysis(
           c10::AliasAnalysisKind::FROM_SCHEMA);
     }
     return s;
   }

   class TorchLibraryInit;

   } // namespace detail

   // Note [Selective build]
   // ~~~~~~~~~~~~~~~~~~~~~~
   // In some settings, especially mobile, it is important to avoid compiling any
   // references to functions that you aren't actually going to use, so that they
   // can be eliminated by the linker.  We call this capability "selective build".
   //
   // A very easy way to implement selective build which results in a lot of
   // boilerplate is to just add ifdef's around every registration call, but this
   // means you have to write a lot of extra lines of code at every registration
   // site, and it also means you have to define some munging scheme to map
   // operators to macros.
   //
   // Instead of doing this, we have a different mechanism centered around the
   // concept of a SelectiveStr.  A selective name is like a const char* string,
   // except it also carries at compile time a boolean saying whether or not a
   // registration should actually happen or not.  We then have extra overloads
   // which bypass registration entirely if a selective name is disabled.  We do a
   // constexpr test to see if an operator should be enabled or not; this is
   // currently implemented in ATen/core/op_registration/op_allowlist.h

   namespace detail {

   // dummy class for non selected custom torchbind classes
   class ClassNotSelected {
    public:
     ClassNotSelected& def_pickle(...) {
       return *this;
     }
     ClassNotSelected& def(...) {
       return *this;
     }
   };

   // A SelectiveStr is like a const char*, except that it also comes
   // with a type brand that says whether or not the name is enabled or
   // not.  If the string is disabled, then (at compile time) we DON'T generate
   // a registration call for it.  This class is not intended to be called
   // directly; use TORCH_SELECTIVE_NAME or TORCH_SELECTIVE_SCHEMA macros below
   // to create it.
   template <bool enabled>
   class SelectiveStr {
    public:
     constexpr explicit SelectiveStr(const char* name) : name_(name) {}
     constexpr operator const char*() {
       return name_;
     }

    private:
     const char* name_;
   };

   #define TORCH_SELECTIVE_CLASS(n) \
     torch::detail::SelectiveStr<c10::impl::custom_class_allowlist_check(n)>(n)
   #define TORCH_SELECTIVE_NAME(n) \
     torch::detail::SelectiveStr<c10::impl::op_allowlist_check(n)>(n)
   #define TORCH_SELECTIVE_SCHEMA(n) \
     torch::detail::SelectiveStr<c10::impl::schema_allowlist_check(n)>(n)

   } // namespace detail
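   // Illustrative sketch (not from this header): with selective build, schemas
   // and names are wrapped in the TORCH_SELECTIVE_* macros so that disabled
   // operators compile down to no-ops, e.g.
   //
   //   TORCH_LIBRARY(myops, m) {
   //     m.def(TORCH_SELECTIVE_SCHEMA(
   //         "myops::mul(Tensor self, Tensor other) -> Tensor"));
   //   }
   //
   //   TORCH_LIBRARY_IMPL(myops, CPU, m) {
   //     m.impl(TORCH_SELECTIVE_NAME("myops::mul"), TORCH_FN(mul_cpu_impl));
   //   }
   //
   // mul_cpu_impl is a stand-in kernel; the constexpr allowlist check that
   // decides whether a name is enabled lives in
   // ATen/core/op_registration/op_allowlist.h.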
   class TORCH_API Library final {
    public:
     enum Kind {
       DEF, // from TORCH_LIBRARY (no qualifier)
       IMPL,
       FRAGMENT,
     };

     Library(
         Kind kind,
         std::string ns,
         std::optional<c10::DispatchKey> k,
         const char* file,
         uint32_t line);

     Library(const Library&) = delete;
     Library& operator=(const Library&) = delete;
     Library(Library&&) = default;
     Library& operator=(Library&&) = default;
     ~Library() = default;

     // Some notes about the API design here.  We had the following constraints:
     //
     //  - We need to support multiple "types" of arguments for schema and
     //    functions (e.g., unnamed lambda types, regular functions, const char*,
     //    fully instantiated schemas)
     //  - We don't want to write exponentially many overloads
     //  - We don't want to rely on implicit conversion to a common type,
     //    because the C++ compiler will only be willing to do a single
     //    implicit conversion (reducing the set of valid types which you
     //    can invoke with); also error messages are worse when an implicit
     //    conversion is not selected (as the compiler will not explain
     //    why it didn't select an implicit conversion; this is different
     //    from overloads where it will explain each candidate overload and
     //    why it didn't apply)
     //
     // To solve all of these constraints at the same time, we use a trick taken
     // from the pybind11 library: template over the argument in the user visible
     // API, and inside of the templated function explicitly call an overloaded
     // function to resolve the argument to a real type.  You get the good error
     // messages from overloads, but at the same time you only need to write the
     // overload for any given argument type once.

     Library& def(
         c10::FunctionSchema&& s,
         const std::vector<at::Tag>& tags = {},
         _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & {
       return _def(std::move(s), nullptr, tags, rv);
     }

     Library& def(
         const char* raw_schema,
         const std::vector<at::Tag>& tags = {},
         _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & {
       return _def(schema(raw_schema), nullptr, tags, rv);
     }

     Library& set_python_module(const char* pymodule, const char* context = "") {
       python_module_ = {pymodule, context};
       return *this;
     }

     Library& impl_abstract_pystub(const char* pymodule, const char* context = "") {
       return set_python_module(pymodule, context);
     }

     template <typename NameOrSchema, typename Func>
     Library& def(
         NameOrSchema&& raw_name_or_schema,
         Func&& raw_f,
         const std::vector<at::Tag>& tags = {}) & {
       CppFunction f(std::forward<Func>(raw_f));
       return _def(
           detail::constructSchemaOrName(
               ::std::forward<NameOrSchema>(raw_name_or_schema)),
           ::std::move(f),
           tags);
     }

     template <typename Name, typename Func>
     Library& impl(
         Name name,
         Func&& raw_f,
         _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & {
       // TODO: need to raise an error when you impl a function that has a
       // catch all def
   #if defined C10_MOBILE
       CppFunction f(std::forward<Func>(raw_f), NoInferSchemaTag());
   #else
       CppFunction f(std::forward<Func>(raw_f));
   #endif
       return _impl(name, std::move(f), rv);
     }

   #if defined C10_MOBILE
     // Note: This overload is needed only for C10_MOBILE, since the automatically
     // defined copy constructor for the CppFunction doesn't have the additional
     // NoInferSchemaTag argument. We define the overload for the impl() function
     // to accept a CppFunction&& argument. The already constructed CppFunction
     // object may or may not have the inferred schema, but it doesn't matter
     // for our purposes since if it already has the inferred schema, then we
     // might as well just pass it through directly.
     //
     template <typename Name>
     Library& impl(Name name, CppFunction&& raw_f) & {
       // TODO: need to raise an error when you impl a function that has a
       // catch all def
       CppFunction f(std::forward<CppFunction>(raw_f));
       return _impl(name, std::move(f));
     }
   #endif

     // Helper for getting an OperatorName for a const char*.  You probably
     // don't need this.
     c10::OperatorName _resolve(const char* name) const;

     template <typename Name, typename Dispatch, typename Func>
     Library& impl(Name name, Dispatch&& key, Func&& raw_f) & {
       return impl(
           name,
           dispatch(std::forward<Dispatch>(key), std::forward<Func>(raw_f)));
     }
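     // Illustrative sketch (not from this header): the templated def() accepts
     // either an operator name or a full schema string together with a kernel,
     // and set_python_module() records which Python module backs these ops.
     // The module name below is hypothetical.
     //
     //   TORCH_LIBRARY(myops, m) {
     //     // schema inferred from the lambda's signature
     //     m.def("square", [](const at::Tensor& x) { return x * x; });
     //     // explicit schema; kernels registered separately per backend
     //     m.def("mul(Tensor self, Tensor other) -> Tensor");
     //     m.set_python_module("myops._meta_registrations");
     //   }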
     template <typename Name, typename Func>
     Library& impl_UNBOXED(Name /*name*/, Func* /*raw_f*/) & {
       static_assert(
           c10::guts::false_t<Func>(),
           ".impl_UNBOXED(...) was removed. Please use .impl(...) instead.");
       return *this;
     }

     // These overloads cover cases when a SelectiveStr (see Note [Selective
     // build]) has been disabled at compile time. In that case, don't generate
     // any code referencing the passed in functions at all.
     Library& def(
         detail::SelectiveStr<false>,
         const std::vector<at::Tag>& tags [[maybe_unused]] = {}) & {
       return *this;
     }
     Library& def(
         detail::SelectiveStr<true> raw_schema,
         const std::vector<at::Tag>& tags = {}) & {
       return def(raw_schema.operator const char*(), tags);
     }
     template <typename Func>
     Library& def(
         detail::SelectiveStr<false>,
         Func&& /*raw_f*/,
         const std::vector<at::Tag>& tags [[maybe_unused]] = {}) & {
       return *this;
     }
     template <typename Func>
     Library& def(
         detail::SelectiveStr<true> raw_name_or_schema,
         Func&& raw_f,
         const std::vector<at::Tag>& tags = {}) & {
       return def(
           raw_name_or_schema.operator const char*(),
           std::forward<Func>(raw_f),
           tags);
     }
     template <typename Func>
     // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
     Library& impl(detail::SelectiveStr<false>, Func&& /*raw_f*/) & {
       return *this;
     }
     template <typename Dispatch, typename Func>
     Library& impl(
         detail::SelectiveStr<false>,
         // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
         Dispatch&& /*key*/,
         // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
         Func&& /*raw_f*/) & {
       return *this;
     }
     template <typename Func>
     Library& impl_UNBOXED(
         detail::SelectiveStr<false> /*name*/,
         Func* /*raw_f*/) & {
       static_assert(
           c10::guts::false_t<Func>(),
           ".impl_UNBOXED(...) was removed. Please use .impl(...) instead.");
       return *this;
     }

     template <typename Func>
     Library& impl(detail::SelectiveStr<true> name, Func&& raw_f) & {
       return impl(name.operator const char*(), std::forward<Func>(raw_f));
     }
     template <typename Dispatch, typename Func>
     Library& impl(
         detail::SelectiveStr<true> name,
         Dispatch&& key,
         Func&& raw_f) & {
       return impl(
           name.operator const char*(),
           std::forward<Dispatch>(key),
           std::forward<Func>(raw_f));
     }
     template <typename Func>
     Library& impl_UNBOXED(
         detail::SelectiveStr<true> /*name*/,
         Func* /*raw_f*/) & {
       static_assert(
           c10::guts::false_t<Func>(),
           ".impl_UNBOXED(...) was removed. Please use .impl(...) instead.");
       return *this;
     }

     template <typename Func>
     Library& fallback(Func&& raw_f) & {
       CppFunction f((std::forward<Func>(raw_f)));
       return _fallback(std::move(f));
     }

     template <class CurClass>
     inline torch::class_<CurClass> class_(const std::string& className);

     // These overloads enable the use of selective build on classes registered
     // within a library. The API is the same as before with 1 minor change.
     // Instead of m.class_<foo>("foo") you instead do
     // m.class_<foo>(TORCH_SELECTIVE_CLASS("foo"))
     template <class CurClass>
     inline torch::class_<CurClass> class_(detail::SelectiveStr<true> className);

     template <class CurClass>
     inline detail::ClassNotSelected class_(detail::SelectiveStr<false> className);

     // De-registers all registrations created with this Library
     void reset();
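     // Illustrative sketch (not from this header): fallback() registers a kernel
     // that handles *every* operator for this library's dispatch key, which is
     // how backend-wide boxed fallbacks are installed, e.g.
     //
     //   // boxed calling convention; the function name is only an example
     //   void xla_fallback(
     //       const c10::OperatorHandle& op,
     //       torch::jit::Stack* stack);
     //
     //   TORCH_LIBRARY_IMPL(_, XLA, m) {
     //     m.fallback(torch::CppFunction::makeFromBoxedFunction<&xla_fallback>());
     //   }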
    private:
     Kind kind_;
     std::optional<std::string> ns_;
     std::optional<c10::DispatchKey> dispatch_key_;
     std::optional<std::pair<const char*, const char*>> python_module_;
     const char* file_;
     uint32_t line_;

     std::vector<c10::RegistrationHandleRAII> registrars_;

     friend class detail::TorchLibraryInit;

     // Non-user visible actual implementations of functions.  These aren't
     // public because we only implement & qualifier and not && qualifier
     Library& _def(
         c10::FunctionSchema&& schema,
         c10::OperatorName* out_name = nullptr,
         const std::vector<at::Tag>& tags = {},
         _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) &;
     Library& _def(
         std::variant<c10::OperatorName, c10::FunctionSchema>&&,
         CppFunction&& f,
         const std::vector<at::Tag>& tags = {}) &;
     Library& _impl(
         const char* name,
         CppFunction&& f,
         _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) &;
     Library& _fallback(CppFunction&& f) &;

     at::OperatorName _parseNameForLib(const char* name_str) const;
   };

   #if defined(TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT) && defined(C10_MOBILE)
   void initialize_torch_libraries();
   #endif

   namespace detail {

   #if defined(TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT) && defined(C10_MOBILE)
   // This is an experimental feature to defer TorchLibraryInit cost to run either
   // at model load time, or when a client application explicitly calls
   // torch::initialize_torch_libraries().
   //
   // This is not thread safe, the client is required to ensure that libraries
   // containing TORCH_LIBRARY initializers are loaded in a thread safe manner.
   extern std::vector<TorchLibraryInit*> torch_library_initializers;
   class TorchLibraryInit final {
    private:
     using InitFn = void(Library&);
     Library::Kind kind;
     InitFn* init_function;
     const char* ns;
     std::optional<c10::DispatchKey> key;
     const char* file;
     uint32_t line;
     std::unique_ptr<Library> lib = nullptr;

    public:
     TorchLibraryInit(
         Library::Kind kind,
         InitFn* fn,
         const char* ns,
         std::optional<c10::DispatchKey> k,
         const char* file,
         uint32_t line)
         : kind(kind),
           init_function(fn),
           ns(ns),
           key(k),
           file(file),
           line(line) {
       torch_library_initializers.push_back(this);
     }

     void initialize() {
       lib = std::unique_ptr<Library>(new Library(kind, ns, key, file, line));
       init_function(*lib);
     }
   };
   #else
   class TorchLibraryInit final {
    private:
     using InitFn = void(Library&);
     Library lib_;

    public:
     TorchLibraryInit(
         Library::Kind kind,
         InitFn* fn,
         const char* ns,
         std::optional<c10::DispatchKey> k,
         const char* file,
         uint32_t line)
         : lib_(kind, ns, k, file, line) {
       fn(lib_);
     }
   };
   #endif

   } // namespace detail

   } // namespace torch

   // NB: The EXACT NAMING of the initializer functions (e.g.,
   // TORCH_LIBRARY_init_aten) matters for the code analyzer;
   // see the regexes at tools/code_analyzer/run_analyzer.sh

   #define TORCH_LIBRARY(ns, m) \
     static void TORCH_LIBRARY_init_##ns(torch::Library&); \
     static const torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_##ns( \
         torch::Library::DEF, \
         &TORCH_LIBRARY_init_##ns, \
         #ns, \
         std::nullopt, \
         __FILE__, \
         __LINE__); \
     void TORCH_LIBRARY_init_##ns(torch::Library& m)

   #define TORCH_LIBRARY_FRAGMENT(ns, m) _TORCH_LIBRARY_FRAGMENT(ns, m, C10_UID)

   #define _TORCH_LIBRARY_FRAGMENT(ns, m, uid) \
     static void C10_CONCATENATE( \
         TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(torch::Library&); \
     static const torch::detail::TorchLibraryInit C10_CONCATENATE( \
         TORCH_LIBRARY_FRAGMENT_static_init_##ns##_, uid)( \
         torch::Library::FRAGMENT, \
         &C10_CONCATENATE(TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid), \
         #ns, \
         std::nullopt, \
         __FILE__, \
         __LINE__); \
     void C10_CONCATENATE( \
         TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(torch::Library & m)
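   // Illustrative sketch (not from this header): TORCH_LIBRARY_FRAGMENT lets a
   // second translation unit add definitions to a namespace whose TORCH_LIBRARY
   // block lives elsewhere, e.g.
   //
   //   // in another .cpp file of the same extension
   //   TORCH_LIBRARY_FRAGMENT(myops, m) {
   //     m.def("sub(Tensor self, Tensor other) -> Tensor");
   //   }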
   // NB: if the dispatch key is not whitelisted, we simply omit the Library
   // call entirely
   #define TORCH_LIBRARY_IMPL(ns, k, m) _TORCH_LIBRARY_IMPL(ns, k, m, C10_UID)

   #define _TORCH_LIBRARY_IMPL(ns, k, m, uid) \
     static void C10_CONCATENATE( \
         TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library&); \
     static const torch::detail::TorchLibraryInit C10_CONCATENATE( \
         TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)( \
         torch::Library::IMPL, \
         &C10_CONCATENATE(TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid), \
         #ns, \
         std::make_optional(c10::DispatchKey::k), \
         __FILE__, \
         __LINE__); \
     void C10_CONCATENATE( \
         TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library & m)

   // These are variants of the macros above which are to be used for testing
   // (they don't set up the static initializer, so you can control the
   // visibility of the allocated library yourself).
   //
   // DO NOT use these in production code, they are NOT understood by the
   // code analyzer and will be incorrectly analyzed in those situations.

   #define MAKE_TORCH_LIBRARY(ns) \
     torch::Library(torch::Library::DEF, #ns, std::nullopt, __FILE__, __LINE__)
   #define MAKE_TORCH_LIBRARY_IMPL(ns, k) \
     torch::Library( \
         torch::Library::IMPL, \
         #ns, \
         std::make_optional(c10::DispatchKey::k), \
         __FILE__, \
         __LINE__)

   // Make the custom class API visible, so it is available from
   // torch::Library.

   #include
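   // Illustrative sketch (not from this header): the MAKE_TORCH_LIBRARY macros
   // are handy in tests, where you want to control the library's lifetime
   // yourself rather than rely on static initialization, e.g.
   //
   //   auto m = MAKE_TORCH_LIBRARY(myops_test);
   //   m.def("noop(Tensor x) -> Tensor");
   //   auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(myops_test, CPU);
   //   m_cpu.impl("noop", [](const at::Tensor& x) { return x; });
   //   // registrations are removed again when m / m_cpu go out of scope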