:github_url: https://github.com/pytorch/pytorch


.. _program_listing_file_torch_csrc_api_include_torch_nn_pimpl.h:

Program Listing for File pimpl.h
================================

|exhale_lsh| :ref:`Return to documentation for file <file_torch_csrc_api_include_torch_nn_pimpl.h>` (``torch/csrc/api/include/torch/nn/pimpl.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include <torch/arg.h>
   #include <torch/detail/static.h>
   #include <torch/serialize/archive.h>
   #include <torch/types.h>

   #include <torch/csrc/utils/variadic.h>

   #include <memory>
   #include <type_traits>
   #include <utility>

   namespace torch {
   namespace detail {
   // Dump all the template metaprogramming in this file.
   #include <torch/csrc/api/include/torch/nn/pimpl-inl.h>
   } // namespace detail

   namespace nn {

   /// A `ModuleHolder` is a thin wrapper around `std::shared_ptr<Contained>`
   /// (where `Contained` is a module implementation type) that adds the
   /// constructors, accessors and call/subscript forwarding used by modules.
   /// A holder is either empty (holds `nullptr`) or owns a shared module.
   template <typename Contained>
   class ModuleHolder : torch::detail::ModuleHolderIndicator {
    protected:
     /// The wrapped module; `nullptr` when the holder is empty.
     // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
     std::shared_ptr<Contained> impl_;

    public:
     using ContainedType = Contained;

     /// Default-constructs the contained module when `Contained` is default
     /// constructible; otherwise the `static_assert` below rejects the call at
     /// compile time (it is only instantiated when this constructor is used).
     ModuleHolder() : impl_(default_construct()) {
       static_assert(
           std::is_default_constructible_v<Contained>,
           "You are trying to default construct a module which has "
           "no default constructor. Use = nullptr to give it the empty state "
           "(e.g. `Linear linear = nullptr;` instead of `Linear linear;`).");
     }

     /// Constructs an empty holder. Accessors that go through `get()`/`ptr()`
     /// fail their `TORCH_CHECK` until a module is assigned.
     /* implicit */ ModuleHolder(std::nullptr_t) : impl_(nullptr) {}

     /// Constructs the contained module by forwarding all arguments to its
     /// constructor. Disabled (via `enable_if`) when called with a single
     /// argument that is itself a holder of `ContainedType`, so that such
     /// calls resolve to the copy/move constructors instead.
     template <
         typename Head,
         typename... Tail,
         typename = std::enable_if_t<
             !(torch::detail::is_module_holder_of<Head, ContainedType>::value &&
               (sizeof...(Tail) == 0))>>
     explicit ModuleHolder(Head&& head, Tail&&... tail)
         : impl_(new Contained(
               std::forward<Head>(head),
               std::forward<Tail>(tail)...)) {}

     /// Wraps an existing shared pointer to the contained type, e.g.
     /// `Linear(std::make_shared<LinearImpl>(...))`.
     /* implicit */ ModuleHolder(std::shared_ptr<Contained> module)
         : impl_(std::move(module)) {}

     /// True if the holder contains a module, false if it is empty.
     explicit operator bool() const noexcept {
       return !is_empty();
     }

     /// Member access on the contained module (checked via `get()`).
     Contained* operator->() {
       return get();
     }

     /// Const member access on the contained module (checked via `get()`).
     const Contained* operator->() const {
       return get();
     }

     /// Reference to the contained module (checked via `get()`).
     Contained& operator*() {
       return *get();
     }

     /// Const reference to the contained module (checked via `get()`).
     const Contained& operator*() const {
       return *get();
     }

     /// The underlying shared pointer; `TORCH_CHECK`s that the holder is
     /// not empty.
     const std::shared_ptr<Contained>& ptr() const {
       TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder");
       return impl_;
     }

     /// Raw pointer to the contained module; `TORCH_CHECK`s that the holder
     /// is not empty.
     Contained* get() {
       TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder");
       return impl_.get();
     }

     /// Const raw pointer to the contained module; `TORCH_CHECK`s that the
     /// holder is not empty.
     const Contained* get() const {
       TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder");
       return impl_.get();
     }

     /// Calls the contained module's `forward()` with the given arguments.
     /// NOTE(review): dereferences `impl_` directly rather than via `get()`,
     /// so unlike `operator->` it performs no empty-holder check.
     template <typename... Args>
     auto operator()(Args&&... args)
         -> torch::detail::return_type_of_forward_t<Contained, Args...> {
       // This will not compile if the module does not have a `forward()` method
       // (as expected).
       // NOTE: `std::forward` is qualified to prevent VS2017 emitting
       // error C2872: 'std': ambiguous symbol
       return impl_->forward(::std::forward<Args>(args)...);
     }

     /// Forwards to the contained module's subscript operator.
     /// NOTE(review): like `operator()`, this skips the empty-holder check.
     template <typename Arg>
     auto operator[](Arg&& arg) {
       return (*impl_)[::std::forward<Arg>(arg)];
     }

     /// True if the holder does not contain a module.
     bool is_empty() const noexcept {
       return impl_ == nullptr;
     }

    private:
     // Makes a default-constructed module when possible, otherwise yields an
     // empty pointer; the defaulted `T` keeps the `if constexpr` condition
     // dependent on this member template's own parameter.
     template <typename T = Contained>
     std::shared_ptr<Contained> default_construct() {
       if constexpr (std::is_default_constructible_v<T>) {
         return std::make_shared<Contained>();
       } else {
         return nullptr;
       }
     }
   };

   /// Prints the contained module into `stream` (via `*module`).
   template <typename ModuleType>
   std::ostream& operator<<(
       std::ostream& stream,
       const nn::ModuleHolder<ModuleType>& module) {
     return stream << *module;
   }

   /// Serializes the holder's module pointer into an `OutputArchive`.
   template <typename ModuleType>
   serialize::OutputArchive& operator<<(
       serialize::OutputArchive& archive,
       const nn::ModuleHolder<ModuleType>& module) {
     return archive << module.ptr();
   }

   /// Deserializes into the holder's module pointer from an `InputArchive`.
   template <typename ModuleType>
   serialize::InputArchive& operator>>(
       serialize::InputArchive& archive,
       nn::ModuleHolder<ModuleType>& module) {
     return archive >> module.ptr();
   }

   } // namespace nn
   } // namespace torch

   // Workaround for CUDA 10.2 and below not allowing attribute unused on
   // using declarations.
   #ifdef __CUDACC__
   #define TORCH_UNUSED_EXCEPT_CUDA
   #else
   #define TORCH_UNUSED_EXCEPT_CUDA [[maybe_unused]]
   #endif

   /// Defines a class `Name` deriving from `torch::nn::ModuleHolder<ImplType>`
   /// that inherits all of its constructors; `Name::Impl` aliases `ImplType`.
   #define TORCH_MODULE_IMPL(Name, ImplType)                              \
     class Name : public torch::nn::ModuleHolder<ImplType> { /* NOLINT */ \
      public:                                                             \
       using torch::nn::ModuleHolder<ImplType>::ModuleHolder;             \
       using Impl TORCH_UNUSED_EXCEPT_CUDA = ImplType;                    \
     }

   /// Like `TORCH_MODULE_IMPL`, with the implementation type named `Name##Impl`.
   #define TORCH_MODULE(Name) TORCH_MODULE_IMPL(Name, Name##Impl)