:github_url: https://github.com/pytorch/pytorch

.. _program_listing_file_torch_csrc_api_include_torch_nn_modules_dropout.h:

Program Listing for File dropout.h
==================================

|exhale_lsh| :ref:`Return to documentation for file <file_torch_csrc_api_include_torch_nn_modules_dropout.h>` (``torch/csrc/api/include/torch/nn/modules/dropout.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include <torch/nn/cloneable.h>
   #include <torch/nn/options/dropout.h>
   #include <torch/nn/pimpl.h>
   #include <torch/types.h>

   #include <torch/csrc/Export.h>

   namespace torch::nn {

   namespace detail {

   /// Common base class for the dropout modules declared in this header.
   ///
   /// `Derived` is the concrete module type; it is forwarded to
   /// `torch::nn::Cloneable` so that each module below inherits through its own
   /// instantiation of this template.
   template <typename Derived>
   class _DropoutNd : public torch::nn::Cloneable<Derived> {
    public:
     /// Convenience constructor: builds a `DropoutOptions` whose `p` is the given
     /// value and delegates to the options constructor.
     _DropoutNd(double p) : _DropoutNd(DropoutOptions().p(p)) {}

     /// Stores `options_` and validates it immediately.
     explicit _DropoutNd(const DropoutOptions& options_ = {}) : options(options_) {
       // Qualified call: invoke this class's `reset()` directly instead of
       // dispatching virtually from inside a constructor.
       _DropoutNd::reset();
     }

     /// Validates the stored options: `options.p()` must lie in `[0, 1]`,
     /// otherwise the `TORCH_CHECK` below fails with the given message.
     void reset() override {
       TORCH_CHECK(
           options.p() >= 0. && options.p() <= 1.,
           "dropout probability has to be between 0 and 1, but got ",
           options.p());
     }

     /// The options this module was constructed with.
     DropoutOptions options;
   };

   } // namespace detail

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Dropout module implementation; inherits its constructors and `options`
   /// from `detail::_DropoutNd`.
   class TORCH_API DropoutImpl : public detail::_DropoutNd<DropoutImpl> {
    public:
     using detail::_DropoutNd<DropoutImpl>::_DropoutNd;

     /// Forward pass on `input` (taken by value); defined out of line.
     Tensor forward(Tensor input);

     /// Overrides the base-class `pretty_print` hook; defined out of line.
     void pretty_print(std::ostream& stream) const override;
   };

   TORCH_MODULE(Dropout);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Dropout2d module implementation; inherits its constructors and `options`
   /// from `detail::_DropoutNd`.
   class TORCH_API Dropout2dImpl : public detail::_DropoutNd<Dropout2dImpl> {
    public:
     using detail::_DropoutNd<Dropout2dImpl>::_DropoutNd;

     /// Forward pass on `input` (taken by value); defined out of line.
     Tensor forward(Tensor input);

     /// Overrides the base-class `pretty_print` hook; defined out of line.
     void pretty_print(std::ostream& stream) const override;
   };

   TORCH_MODULE(Dropout2d);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Dropout3d module implementation; inherits its constructors and `options`
   /// from `detail::_DropoutNd`.
   class TORCH_API Dropout3dImpl : public detail::_DropoutNd<Dropout3dImpl> {
    public:
     using detail::_DropoutNd<Dropout3dImpl>::_DropoutNd;

     /// Forward pass on `input` (taken by value); defined out of line.
     Tensor forward(Tensor input);

     /// Overrides the base-class `pretty_print` hook; defined out of line.
     void pretty_print(std::ostream& stream) const override;
   };

   TORCH_MODULE(Dropout3d);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AlphaDropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// AlphaDropout module implementation; inherits its constructors and
   /// `options` from `detail::_DropoutNd`.
   class TORCH_API AlphaDropoutImpl : public detail::_DropoutNd<AlphaDropoutImpl> {
    public:
     using detail::_DropoutNd<AlphaDropoutImpl>::_DropoutNd;

     /// Forward pass on `input` (taken by const reference); defined out of line.
     Tensor forward(const Tensor& input);

     /// Overrides the base-class `pretty_print` hook; defined out of line.
     void pretty_print(std::ostream& stream) const override;
   };

   TORCH_MODULE(AlphaDropout);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FeatureAlphaDropout
   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// FeatureAlphaDropout module implementation; inherits its constructors and
   /// `options` from `detail::_DropoutNd`.
   class TORCH_API FeatureAlphaDropoutImpl
       : public detail::_DropoutNd<FeatureAlphaDropoutImpl> {
    public:
     using detail::_DropoutNd<FeatureAlphaDropoutImpl>::_DropoutNd;

     /// Forward pass on `input` (taken by const reference); defined out of line.
     Tensor forward(const Tensor& input);

     /// Overrides the base-class `pretty_print` hook; defined out of line.
     void pretty_print(std::ostream& stream) const override;
   };

   TORCH_MODULE(FeatureAlphaDropout);

   } // namespace torch::nn