:github_url: https://github.com/pytorch/pytorch .. _program_listing_file_torch_custom_class.h: Program Listing for File custom_class.h ======================================= |exhale_lsh| :ref:`Return to documentation for file ` (``torch/custom_class.h``) .. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS .. code-block:: cpp #pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace torch { template detail::types init() { return detail::types{}; } template struct InitLambda { Func f; }; template decltype(auto) init(Func&& f) { using InitTraits = c10::guts::infer_function_traits_t>; using ParameterTypeList = typename InitTraits::parameter_types; InitLambda init{std::forward(f)}; return init; } template class class_ : public ::torch::detail::class_base { static_assert( std::is_base_of_v, "torch::class_ requires T to inherit from CustomClassHolder"); public: explicit class_( const std::string& namespaceName, const std::string& className, std::string doc_string = "") : class_base( namespaceName, className, std::move(doc_string), typeid(c10::intrusive_ptr), typeid(c10::tagged_capsule)) {} template class_& def( torch::detail::types, std::string doc_string = "", std::initializer_list default_args = {}) { // Used in combination with // torch::init<...>() auto func = [](c10::tagged_capsule self, Types... args) { auto classObj = c10::make_intrusive(args...); auto object = self.ivalue.toObject(); object->setSlot(0, c10::IValue::make_capsule(std::move(classObj))); }; defineMethod( "__init__", std::move(func), std::move(doc_string), default_args); return *this; } // Used in combination with torch::init([]lambda(){......}) template class_& def( InitLambda> init, std::string doc_string = "", std::initializer_list default_args = {}) { auto init_lambda_wrapper = [func = std::move(init.f)]( c10::tagged_capsule self, ParameterTypes... 
arg) { c10::intrusive_ptr classObj = std::invoke(func, std::forward(arg)...); auto object = self.ivalue.toObject(); object->setSlot(0, c10::IValue::make_capsule(classObj)); }; defineMethod( "__init__", std::move(init_lambda_wrapper), std::move(doc_string), default_args); return *this; } template class_& def( std::string name, Func f, std::string doc_string = "", std::initializer_list default_args = {}) { auto wrapped_f = detail::wrap_func(std::move(f)); defineMethod( std::move(name), std::move(wrapped_f), std::move(doc_string), default_args); return *this; } template class_& def_static(std::string name, Func func, std::string doc_string = "") { auto qualMethodName = qualClassName + "." + name; auto schema = c10::inferFunctionSchemaSingleReturn(std::move(name), ""); auto wrapped_func = [func = std::move(func)](jit::Stack& stack) mutable -> void { using RetType = typename c10::guts::infer_function_traits_t::return_type; detail::BoxedProxy()(stack, func); }; auto method = std::make_unique( std::move(qualMethodName), std::move(schema), std::move(wrapped_func), std::move(doc_string)); classTypePtr->addStaticMethod(method.get()); registerCustomClassMethod(std::move(method)); return *this; } template class_& def_property( const std::string& name, GetterFunc getter_func, SetterFunc setter_func, std::string doc_string = "") { torch::jit::Function* getter{}; torch::jit::Function* setter{}; auto wrapped_getter = detail::wrap_func(std::move(getter_func)); getter = defineMethod(name + "_getter", wrapped_getter, doc_string); auto wrapped_setter = detail::wrap_func(std::move(setter_func)); setter = defineMethod(name + "_setter", wrapped_setter, doc_string); classTypePtr->addProperty(name, getter, setter); return *this; } template class_& def_property( const std::string& name, GetterFunc getter_func, std::string doc_string = "") { torch::jit::Function* getter{}; auto wrapped_getter = detail::wrap_func(std::move(getter_func)); getter = defineMethod(name + "_getter", wrapped_getter, 
doc_string); classTypePtr->addProperty(name, getter, nullptr); return *this; } template class_& def_readwrite(const std::string& name, T CurClass::*field) { auto getter_func = [field = field](const c10::intrusive_ptr& self) { return self.get()->*field; }; auto setter_func = [field = field]( const c10::intrusive_ptr& self, T value) { self.get()->*field = value; }; return def_property(name, getter_func, setter_func); } template class_& def_readonly(const std::string& name, T CurClass::*field) { auto getter_func = [field = std::move(field)](const c10::intrusive_ptr& self) { return self.get()->*field; }; return def_property(name, getter_func); } class_& _def_unboxed( const std::string& name, std::function func, c10::FunctionSchema schema, std::string doc_string = "") { auto method = std::make_unique( qualClassName + "." + name, std::move(schema), std::move(func), std::move(doc_string)); classTypePtr->addMethod(method.get()); registerCustomClassMethod(std::move(method)); return *this; } template class_& def_pickle(GetStateFn&& get_state, SetStateFn&& set_state) { static_assert( c10::guts::is_stateless_lambda>::value && c10::guts::is_stateless_lambda>::value, "def_pickle() currently only supports lambdas as " "__getstate__ and __setstate__ arguments."); def("__getstate__", std::forward(get_state)); // __setstate__ needs to be registered with some custom handling: // We need to wrap the invocation of the user-provided function // such that we take the return value (i.e. c10::intrusive_ptr) // and assign it to the `capsule` attribute. 
using SetStateTraits = c10::guts::infer_function_traits_t>; using SetStateArg = typename c10::guts::typelist::head_t< typename SetStateTraits::parameter_types>; auto setstate_wrapper = [set_state = std::forward(set_state)]( c10::tagged_capsule self, SetStateArg arg) { c10::intrusive_ptr classObj = std::invoke(set_state, std::move(arg)); auto object = self.ivalue.toObject(); object->setSlot(0, c10::IValue::make_capsule(classObj)); }; defineMethod( "__setstate__", detail::wrap_func( std::move(setstate_wrapper))); // type validation auto getstate_schema = classTypePtr->getMethod("__getstate__").getSchema(); #ifndef STRIP_ERROR_MESSAGES auto format_getstate_schema = [&getstate_schema]() { std::stringstream ss; ss << getstate_schema; return ss.str(); }; #endif TORCH_CHECK( getstate_schema.arguments().size() == 1, "__getstate__ should take exactly one argument: self. Got: ", format_getstate_schema()); auto first_arg_type = getstate_schema.arguments().at(0).type(); TORCH_CHECK( *first_arg_type == *classTypePtr, "self argument of __getstate__ must be the custom class type. Got ", first_arg_type->repr_str()); TORCH_CHECK( getstate_schema.returns().size() == 1, "__getstate__ should return exactly one value for serialization. Got: ", format_getstate_schema()); auto ser_type = getstate_schema.returns().at(0).type(); auto setstate_schema = classTypePtr->getMethod("__setstate__").getSchema(); auto arg_type = setstate_schema.arguments().at(1).type(); TORCH_CHECK( ser_type->isSubtypeOf(*arg_type), "__getstate__'s return type should be a subtype of " "input argument of __setstate__. Got ", ser_type->repr_str(), " but expected ", arg_type->repr_str()); return *this; } private: template torch::jit::Function* defineMethod( std::string name, Func func, std::string doc_string = "", std::initializer_list default_args = {}) { auto qualMethodName = qualClassName + "." 
+ name; auto schema = c10::inferFunctionSchemaSingleReturn(std::move(name), ""); // If default values are provided for function arguments, there must be // none (no default values) or default values for all function // arguments, except for self. This is because argument names are not // extracted by inferFunctionSchemaSingleReturn, and so there must be a // torch::arg instance in default_args even for arguments that do not // have an actual default value provided. TORCH_CHECK( default_args.size() == 0 || default_args.size() == schema.arguments().size() - 1, "Default values must be specified for none or all arguments"); // If there are default args, copy the argument names and default values to // the function schema. if (default_args.size() > 0) { schema = withNewArguments(schema, default_args); } auto wrapped_func = [func = std::move(func)](jit::Stack& stack) mutable -> void { // TODO: we need to figure out how to profile calls to custom functions // like this! Currently can't do it because the profiler stuff is in // libtorch and not ATen using RetType = typename c10::guts::infer_function_traits_t::return_type; detail::BoxedProxy()(stack, func); }; auto method = std::make_unique( qualMethodName, std::move(schema), std::move(wrapped_func), std::move(doc_string)); // Register the method here to keep the Method alive. // ClassTypes do not hold ownership of their methods (normally // those are held by the CompilationUnit), so we need a proxy for // that behavior here. auto method_val = method.get(); classTypePtr->addMethod(method_val); registerCustomClassMethod(std::move(method)); return method_val; } }; template c10::IValue make_custom_class(CtorArgs&&... args) { auto userClassInstance = c10::make_intrusive(std::forward(args)...); return c10::IValue(std::move(userClassInstance)); } // Alternative API for creating a torchbind class over torch::class_. This API is // preferred to prevent size regressions on Edge use cases. 
Must be used in // conjunction with the TORCH_SELECTIVE_CLASS macro, e.g. // selective_class_("foo_namespace", TORCH_SELECTIVE_CLASS("foo")) template inline class_ selective_class_( const std::string& namespace_name, detail::SelectiveStr className) { auto class_name = std::string(className.operator const char*()); return torch::class_(namespace_name, class_name); } template inline detail::ClassNotSelected selective_class_( const std::string&, detail::SelectiveStr) { return detail::ClassNotSelected(); } // jit namespace for backward-compatibility // We previously defined everything in torch::jit but moved it out to // better reflect that these features are not limited only to TorchScript namespace jit { using ::torch::class_; using ::torch::getCustomClass; using ::torch::init; using ::torch::isCustomClass; } // namespace jit template inline class_ Library::class_(const std::string& className) { TORCH_CHECK( kind_ == DEF || kind_ == FRAGMENT, "class_(\"", className, "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block. " "All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. " "(Error occurred at ", file_, ":", line_, ")"); TORCH_INTERNAL_ASSERT(ns_.has_value(), file_, ":", line_); return torch::class_(*ns_, className); } const std::unordered_set getAllCustomClassesNames(); template inline class_ Library::class_(detail::SelectiveStr className) { auto class_name = std::string(className.operator const char*()); TORCH_CHECK( kind_ == DEF || kind_ == FRAGMENT, "class_(\"", class_name, "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block. " "All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. 
" "(Error occurred at ", file_, ":", line_, ")"); TORCH_INTERNAL_ASSERT(ns_.has_value(), file_, ":", line_); return torch::class_(*ns_, class_name); } template inline detail::ClassNotSelected Library::class_(detail::SelectiveStr) { return detail::ClassNotSelected(); } } // namespace torch