Utilities#
Additional utilities for building neural networks: parameter initialization, module cloning, type-erased containers, padding layers, and vision utilities.
Parameter Initialization#
The torch::nn::init namespace provides functions for initializing module parameters:
#include <torch/nn/init.h>
// Xavier/Glorot initialization
torch::nn::init::xavier_uniform_(linear->weight);
torch::nn::init::xavier_normal_(linear->weight);
// Kaiming/He initialization
torch::nn::init::kaiming_uniform_(conv->weight, /*a=*/0, torch::kFanIn, torch::kReLU);
torch::nn::init::kaiming_normal_(conv->weight);
// Other initializations
torch::nn::init::zeros_(linear->bias);
torch::nn::init::ones_(bn->weight);
torch::nn::init::constant_(linear->bias, 0.1);
torch::nn::init::normal_(linear->weight, /*mean=*/0, /*std=*/0.01);
torch::nn::init::uniform_(linear->weight, /*a=*/-0.1, /*b=*/0.1);
torch::nn::init::orthogonal_(rnn->weight_hh);
Cloneable#
-
template<typename Derived>
class Cloneable : public torch::nn::Module#
The clone() method in the base Module class does not have knowledge of the concrete runtime type of its subclasses. Therefore, clone() must either be called from within the subclass, or from a base class that has knowledge of the concrete type. Cloneable uses the CRTP to gain knowledge of the subclass’ static type and provide an implementation of the clone() method. We do not want to use this pattern in the base class, because then storing a module would always require templatizing it.
Public Functions
-
virtual void reset() = 0#
reset()must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.
-
inline virtual std::shared_ptr<Module> clone(const std::optional<Device> &device = std::nullopt) const override#
Performs a recursive “deep copy” of the Module, such that all parameters and submodules in the cloned module are different from those in the original module.
All torch::nn modules inherit from Cloneable, enabling deep copies:
auto model = torch::nn::Linear(10, 5);
auto model_copy = std::dynamic_pointer_cast<torch::nn::LinearImpl>(model->clone());
AnyModule#
AnyModule provides type-erased storage for any module, allowing you to
store heterogeneous modules in a single container.
-
class AnyModule#
Stores a type-erased Module.
The PyTorch C++ API does not impose an interface on the signature of forward() in Module subclasses. This gives you complete freedom to design your forward() methods to your liking. However, this also means there is no unified base type you could store in order to call forward() polymorphically for any module. This is where AnyModule comes in. Instead of inheritance, it relies on type erasure for polymorphism.
An AnyModule can store any nn::Module subclass that provides a forward() method. This forward() may accept any types and return any type. Once stored in an AnyModule, you can invoke the underlying module’s forward() by calling AnyModule::forward() with the arguments you would supply to the stored module (though see one important limitation below). Example:
struct GenericTrainer {
  torch::nn::AnyModule module;
  void train(torch::Tensor input) {
    module.forward(input);
  }
};
GenericTrainer trainer1{torch::nn::Linear(3, 4)};
GenericTrainer trainer2{torch::nn::Conv2d(3, 4, 2)};
As AnyModule erases the static type of the stored module (and its forward() method) to achieve polymorphism, type checking of arguments is moved to runtime. That is, passing an argument with an incorrect type to an AnyModule will compile, but throw an exception at runtime:
torch::nn::AnyModule module(torch::nn::Linear(3, 4));
// Linear takes a tensor as input, but we are passing an integer.
// This will compile, but throw a torch::Error exception at runtime.
module.forward(123);
Attention: One noteworthy limitation of AnyModule is that its forward() method does not support implicit conversion of argument types. For example, if the stored module’s forward() method accepts a float and you call any_module.forward(3.4) (where 3.4 is a double), this will throw an exception.
The return type of AnyModule’s forward() method is controlled via the first template argument to AnyModule::forward(). It defaults to torch::Tensor. To change it, you can write any_module.forward<int>(), for example:
torch::nn::AnyModule module(torch::nn::Linear(3, 4));
auto output = module.forward(torch::ones({2, 3}));
struct IntModule {
  int forward(int x) { return x; }
};
torch::nn::AnyModule module(IntModule{});
int output = module.forward<int>(5);
The only other method an AnyModule provides access to on the stored module is clone(). However, you may acquire a handle on the module via .ptr(), which returns a shared_ptr<nn::Module>. Further, if you know the concrete type of the stored module, you can get a concrete handle to it using .get<T>(), where T is the concrete module type.
torch::nn::AnyModule module(torch::nn::Linear(3, 4));
std::shared_ptr<nn::Module> ptr = module.ptr();
torch::nn::Linear linear(module.get<torch::nn::Linear>());
Public Functions
Constructs an AnyModule from a shared_ptr to a concrete module object.
-
template<typename ModuleType, typename = torch::detail::enable_if_module_t<ModuleType>>
explicit AnyModule(ModuleType &&module)#
Constructs an AnyModule from a concrete module object.
-
template<typename ModuleType>
explicit AnyModule(const ModuleHolder<ModuleType> &module_holder)#
Constructs an AnyModule from a module holder.
-
AnyModule(AnyModule&&) = default#
Move construction and assignment is allowed, and follows the default behavior of move for std::unique_ptr.
-
inline AnyModule clone(std::optional<Device> device = std::nullopt) const#
Creates a deep copy of the AnyModule if it contains a module, else returns an empty AnyModule.
Assigns a module to the AnyModule (to circumvent the explicit constructor).
-
template<typename ...ArgumentTypes>
AnyValue any_forward(ArgumentTypes&&... arguments)#
Invokes forward() on the contained module with the given arguments, and returns the return value as an AnyValue. Use this method when chaining AnyModules in a loop.
-
template<typename ReturnType = torch::Tensor, typename ...ArgumentTypes>
ReturnType forward(ArgumentTypes&&... arguments)#
Invokes forward() on the contained module with the given arguments, and casts the returned AnyValue to the supplied ReturnType (which defaults to torch::Tensor).
-
template<typename T, typename = torch::detail::enable_if_module_t<T>>
T &get()#
Attempts to cast the underlying module to the given module type.
Throws an exception if the types do not match.
-
template<typename T, typename = torch::detail::enable_if_module_t<T>>
const T &get() const#
Attempts to cast the underlying module to the given module type.
Throws an exception if the types do not match.
-
template<typename T, typename ContainedType = typename T::ContainedType>
T get() const#
Returns the contained module in an nn::ModuleHolder subclass if possible (i.e. if T has a constructor for the underlying module type).
-
inline std::shared_ptr<Module> ptr() const#
Returns a std::shared_ptr whose dynamic type is that of the underlying module.
Like ptr(), but casts the pointer to the given type.
-
inline const std::type_info &type_info() const#
Returns the type_info object of the contained value.
Example:
torch::nn::AnyModule any_module(torch::nn::Linear(10, 5));
auto output = any_module.forward(input);
Functional#
Wraps a function or callable as a module, useful for inserting arbitrary
functions into a Sequential container.
-
class FunctionalImpl : public torch::nn::Cloneable<FunctionalImpl>#
Wraps a function in a Module.
The Functional module allows wrapping an arbitrary function or function object in an nn::Module. This is primarily handy for usage in Sequential.
Sequential sequential(
    Linear(3, 4),
    Functional(torch::relu),
    BatchNorm1d(3),
    Functional(torch::elu, /*alpha=*/1));
While a Functional module only accepts a single Tensor as input, it is possible for the wrapped function to accept further arguments. However, these have to be bound at construction time. For example, if you want to wrap torch::leaky_relu, which accepts a slope scalar as its second argument, with a particular value for its slope in a Functional module, you could write
Functional(torch::leaky_relu, /*slope=*/0.5)
The value of 0.5 is then stored within the Functional object and supplied to the function call at invocation time. Note that such bound values are evaluated eagerly and stored a single time. See the documentation of std::bind for more information on the semantics of argument binding.
Attention: After passing any bound arguments, the function must accept a single tensor and return a single tensor.
Note that Functional overloads the call operator (operator()) so that you can invoke it with my_func(...).
Public Functions
-
template<typename SomeFunction, typename ...Args, typename = std::enable_if_t<(sizeof...(Args) > 0)>>
inline explicit FunctionalImpl(SomeFunction original_function, Args&&... args)#
-
virtual void reset() override#
reset()must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.
-
virtual void pretty_print(std::ostream &stream) const override#
Pretty prints the Functional module into the given stream.
ModuleHolder#
-
template<typename Contained>
class ModuleHolder : private torch::detail::ModuleHolderIndicator#
A ModuleHolder is essentially a wrapper around std::shared_ptr<M>, where M is an nn::Module subclass, with convenient constructors defined for the kinds of construction we want to allow for our modules.
Public Functions
-
inline ModuleHolder()#
Default constructs the contained module if it has a default constructor, else produces a static error.
NOTE: This uses the behavior of template classes in C++ that constructors (or any methods) are only compiled when actually used.
-
inline ModuleHolder(std::nullptr_t)#
Constructs the ModuleHolder with an empty contained value. Access to the underlying module is not permitted and will throw an exception, until a value is assigned.
-
template<typename Head, typename ...Tail, typename = std::enable_if_t<!(torch::detail::is_module_holder_of<Head, ContainedType>::value && (sizeof...(Tail) == 0))>>
inline explicit ModuleHolder(Head &&head, Tail&&... tail)#
Constructs the ModuleHolder with a contained module, forwarding all arguments to its constructor.
Constructs the ModuleHolder from a pointer to the contained type. Example: Linear(std::make_shared<LinearImpl>(...)).
-
inline explicit operator bool() const noexcept#
Returns true if the ModuleHolder contains a module, or false if it is nullptr.
-
inline const std::shared_ptr<Contained> &ptr() const#
Returns a shared pointer to the underlying module.
-
template<typename ...Args>
inline auto operator()(Args&&... args) -> torch::detail::return_type_of_forward_t<Contained, Args...>#
Calls the forward() method of the contained module.
-
template<typename Arg>
inline auto operator[](Arg &&arg)#
Forwards to the subscript operator of the contained module.
NOTE: std::forward is qualified to prevent VS2017 emitting error C2872: ‘std’: ambiguous symbol
-
inline bool is_empty() const noexcept#
Returns true if the ModuleHolder does not contain a module.
CosineSimilarity#
-
class CosineSimilarity : public torch::nn::ModuleHolder<CosineSimilarityImpl>#
A ModuleHolder subclass for CosineSimilarityImpl.
See the documentation for the CosineSimilarityImpl class to learn what methods it provides, and examples of how to use CosineSimilarity with torch::nn::CosineSimilarityOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
Public Types
-
using Impl = CosineSimilarityImpl#
PairwiseDistance#
-
class PairwiseDistance : public torch::nn::ModuleHolder<PairwiseDistanceImpl>#
A ModuleHolder subclass for PairwiseDistanceImpl.
See the documentation for the PairwiseDistanceImpl class to learn what methods it provides, and examples of how to use PairwiseDistance with torch::nn::PairwiseDistanceOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
Public Types
-
using Impl = PairwiseDistanceImpl#
PackedSequence#
-
class torch::nn::utils::rnn::PackedSequence#
Holds the data and list of batch_sizes of a packed sequence. All RNN modules accept packed sequences as inputs.
-
const Tensor &sorted_indices() const#
Returns indices used to sort sequences by length (descending).
-
PackedSequence to(torch::Device device) const#
Moves the packed sequence to the specified device.
See also: torch::nn::utils::rnn::pack_padded_sequence and
torch::nn::utils::rnn::pad_packed_sequence.
Padding Layers#
ReflectionPad1d / ReflectionPad2d / ReflectionPad3d#
-
class ReflectionPad1d : public torch::nn::ModuleHolder<ReflectionPad1dImpl>#
A ModuleHolder subclass for ReflectionPad1dImpl.
See the documentation for the ReflectionPad1dImpl class to learn what methods it provides, and examples of how to use ReflectionPad1d with torch::nn::ReflectionPad1dOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
Public Types
-
using Impl = ReflectionPad1dImpl#
-
class ReflectionPad2d : public torch::nn::ModuleHolder<ReflectionPad2dImpl>#
A ModuleHolder subclass for ReflectionPad2dImpl.
See the documentation for the ReflectionPad2dImpl class to learn what methods it provides, and examples of how to use ReflectionPad2d with torch::nn::ReflectionPad2dOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
Public Types
-
using Impl = ReflectionPad2dImpl#
-
class ReflectionPad3d : public torch::nn::ModuleHolder<ReflectionPad3dImpl>#
A ModuleHolder subclass for ReflectionPad3dImpl.
See the documentation for the ReflectionPad3dImpl class to learn what methods it provides, and examples of how to use ReflectionPad3d with torch::nn::ReflectionPad3dOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
Public Types
-
using Impl = ReflectionPad3dImpl#
ReplicationPad1d / ReplicationPad2d / ReplicationPad3d#
-
class ReplicationPad1d : public torch::nn::ModuleHolder<ReplicationPad1dImpl>#
A ModuleHolder subclass for ReplicationPad1dImpl.
See the documentation for the ReplicationPad1dImpl class to learn what methods it provides, and examples of how to use ReplicationPad1d with torch::nn::ReplicationPad1dOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
Public Types
-
using Impl = ReplicationPad1dImpl#
-
class ReplicationPad2d : public torch::nn::ModuleHolder<ReplicationPad2dImpl>#
A ModuleHolder subclass for ReplicationPad2dImpl.
See the documentation for the ReplicationPad2dImpl class to learn what methods it provides, and examples of how to use ReplicationPad2d with torch::nn::ReplicationPad2dOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
Public Types
-
using Impl = ReplicationPad2dImpl#
-
class ReplicationPad3d : public torch::nn::ModuleHolder<ReplicationPad3dImpl>#
A ModuleHolder subclass for ReplicationPad3dImpl.
See the documentation for the ReplicationPad3dImpl class to learn what methods it provides, and examples of how to use ReplicationPad3d with torch::nn::ReplicationPad3dOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
Public Types
-
using Impl = ReplicationPad3dImpl#
ZeroPad1d / ZeroPad2d / ZeroPad3d#
-
class ZeroPad1d : public torch::nn::ModuleHolder<ZeroPad1dImpl>#
A ModuleHolder subclass for ZeroPad1dImpl.
See the documentation for the ZeroPad1dImpl class to learn what methods it provides, and examples of how to use ZeroPad1d with torch::nn::ZeroPad1dOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
Public Types
-
using Impl = ZeroPad1dImpl#
-
class ZeroPad2d : public torch::nn::ModuleHolder<ZeroPad2dImpl>#
A ModuleHolder subclass for ZeroPad2dImpl.
See the documentation for the ZeroPad2dImpl class to learn what methods it provides, and examples of how to use ZeroPad2d with torch::nn::ZeroPad2dOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
Public Types
-
using Impl = ZeroPad2dImpl#
-
class ZeroPad3d : public torch::nn::ModuleHolder<ZeroPad3dImpl>#
A ModuleHolder subclass for ZeroPad3dImpl.
See the documentation for the ZeroPad3dImpl class to learn what methods it provides, and examples of how to use ZeroPad3d with torch::nn::ZeroPad3dOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
Public Types
-
using Impl = ZeroPad3dImpl#
ConstantPad1d / ConstantPad2d / ConstantPad3d#
-
class ConstantPad1d : public torch::nn::ModuleHolder<ConstantPad1dImpl>#
A ModuleHolder subclass for ConstantPad1dImpl.
See the documentation for the ConstantPad1dImpl class to learn what methods it provides, and examples of how to use ConstantPad1d with torch::nn::ConstantPad1dOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
Public Types
-
using Impl = ConstantPad1dImpl#
-
class ConstantPad2d : public torch::nn::ModuleHolder<ConstantPad2dImpl>#
A ModuleHolder subclass for ConstantPad2dImpl.
See the documentation for the ConstantPad2dImpl class to learn what methods it provides, and examples of how to use ConstantPad2d with torch::nn::ConstantPad2dOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
Public Types
-
using Impl = ConstantPad2dImpl#
-
class ConstantPad3d : public torch::nn::ModuleHolder<ConstantPad3dImpl>#
A ModuleHolder subclass for ConstantPad3dImpl.
See the documentation for the ConstantPad3dImpl class to learn what methods it provides, and examples of how to use ConstantPad3d with torch::nn::ConstantPad3dOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
Public Types
-
using Impl = ConstantPad3dImpl#
Vision Layers#
PixelShuffle#
-
class PixelShuffle : public torch::nn::ModuleHolder<PixelShuffleImpl>#
A ModuleHolder subclass for PixelShuffleImpl.
See the documentation for the PixelShuffleImpl class to learn what methods it provides, and examples of how to use PixelShuffle with torch::nn::PixelShuffleOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
Public Types
-
using Impl = PixelShuffleImpl#
-
struct PixelShuffleOptions#
Options for the PixelShuffle module.
Example:
PixelShuffle model(PixelShuffleOptions(5));
Public Functions
-
inline PixelShuffleOptions(int64_t upscale_factor)#
-
inline auto upscale_factor(const int64_t &new_upscale_factor) -> decltype(*this)#
Factor to increase spatial resolution by.
-
inline auto upscale_factor(int64_t &&new_upscale_factor) -> decltype(*this)#
-
inline const int64_t &upscale_factor() const noexcept#
-
inline int64_t &upscale_factor() noexcept#
PixelUnshuffle#
-
class PixelUnshuffle : public torch::nn::ModuleHolder<PixelUnshuffleImpl>#
A ModuleHolder subclass for PixelUnshuffleImpl.
See the documentation for the PixelUnshuffleImpl class to learn what methods it provides, and examples of how to use PixelUnshuffle with torch::nn::PixelUnshuffleOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
Public Types
-
using Impl = PixelUnshuffleImpl#
-
struct PixelUnshuffleOptions#
Options for the PixelUnshuffle module.
Example:
PixelUnshuffle model(PixelUnshuffleOptions(5));
Public Functions
-
inline PixelUnshuffleOptions(int64_t downscale_factor)#
-
inline auto downscale_factor(const int64_t &new_downscale_factor) -> decltype(*this)#
Factor to decrease spatial resolution by.
-
inline auto downscale_factor(int64_t &&new_downscale_factor) -> decltype(*this)#
-
inline const int64_t &downscale_factor() const noexcept#
-
inline int64_t &downscale_factor() noexcept#
Upsample#
-
class Upsample : public torch::nn::ModuleHolder<UpsampleImpl>#
A ModuleHolder subclass for UpsampleImpl.
See the documentation for the UpsampleImpl class to learn what methods it provides, and examples of how to use Upsample with torch::nn::UpsampleOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
Public Types
-
using Impl = UpsampleImpl#
-
struct UpsampleOptions#
Options for the Upsample module.
Example:
Upsample model(UpsampleOptions().scale_factor(std::vector<double>({3})).mode(torch::kLinear).align_corners(false));
Public Functions
-
inline auto size(const std::optional<std::vector<int64_t>> &new_size) -> decltype(*this)#
output spatial sizes.
-
inline auto size(std::optional<std::vector<int64_t>> &&new_size) -> decltype(*this)#
-
inline const std::optional<std::vector<int64_t>> &size() const noexcept#
-
inline std::optional<std::vector<int64_t>> &size() noexcept#
-
inline auto scale_factor(const std::optional<std::vector<double>> &new_scale_factor) -> decltype(*this)#
multiplier for spatial size.
-
inline auto scale_factor(std::optional<std::vector<double>> &&new_scale_factor) -> decltype(*this)#
-
inline const std::optional<std::vector<double>> &scale_factor() const noexcept#
-
inline std::optional<std::vector<double>> &scale_factor() noexcept#
-
inline auto mode(const mode_t &new_mode) -> decltype(*this)#
-
inline auto mode(mode_t &&new_mode) -> decltype(*this)#
-
inline const mode_t &mode() const noexcept#
-
inline mode_t &mode() noexcept#
-
inline auto align_corners(const std::optional<bool> &new_align_corners) -> decltype(*this)#
If true, the corner pixels of the input and output tensors are aligned, thus preserving the values at those pixels.
This only has an effect when mode is “linear”, “bilinear”, “bicubic”, or “trilinear”. Default: false
-
inline auto align_corners(std::optional<bool> &&new_align_corners) -> decltype(*this)#
-
inline const std::optional<bool> &align_corners() const noexcept#
-
inline std::optional<bool> &align_corners() noexcept#
Fold / Unfold#
-
class Fold : public torch::nn::ModuleHolder<FoldImpl>#
A ModuleHolder subclass for FoldImpl.
See the documentation for the FoldImpl class to learn what methods it provides, and examples of how to use Fold with torch::nn::FoldOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
Public Types
-
using Impl = FoldImpl#
-
struct FoldOptions#
Options for the Fold module.
Example:
Fold model(FoldOptions({8, 8}, {3, 3}).dilation(2).padding({2, 1}).stride(2));
Public Functions
-
inline FoldOptions(ExpandingArray<2> output_size, ExpandingArray<2> kernel_size)#
-
inline auto output_size(const ExpandingArray<2> &new_output_size) -> decltype(*this)#
describes the spatial shape of the large containing tensor of the sliding local blocks.
It is useful to resolve the ambiguity when multiple input shapes map to the same number of sliding blocks, e.g., with stride > 0.
-
inline auto output_size(ExpandingArray<2> &&new_output_size) -> decltype(*this)#
-
inline const ExpandingArray<2> &output_size() const noexcept#
-
inline ExpandingArray<2> &output_size() noexcept#
-
inline auto kernel_size(const ExpandingArray<2> &new_kernel_size) -> decltype(*this)#
the size of the sliding blocks
-
inline auto kernel_size(ExpandingArray<2> &&new_kernel_size) -> decltype(*this)#
-
inline const ExpandingArray<2> &kernel_size() const noexcept#
-
inline ExpandingArray<2> &kernel_size() noexcept#
-
inline auto dilation(const ExpandingArray<2> &new_dilation) -> decltype(*this)#
controls the spacing between the kernel points; also known as the à trous algorithm.
-
inline auto dilation(ExpandingArray<2> &&new_dilation) -> decltype(*this)#
-
inline const ExpandingArray<2> &dilation() const noexcept#
-
inline ExpandingArray<2> &dilation() noexcept#
-
inline auto padding(const ExpandingArray<2> &new_padding) -> decltype(*this)#
controls the amount of implicit zero-paddings on both sides for padding number of points for each dimension before reshaping.
-
inline auto padding(ExpandingArray<2> &&new_padding) -> decltype(*this)#
-
inline const ExpandingArray<2> &padding() const noexcept#
-
inline ExpandingArray<2> &padding() noexcept#
-
inline auto stride(const ExpandingArray<2> &new_stride) -> decltype(*this)#
controls the stride for the sliding blocks.
-
inline auto stride(ExpandingArray<2> &&new_stride) -> decltype(*this)#
-
inline const ExpandingArray<2> &stride() const noexcept#
-
inline ExpandingArray<2> &stride() noexcept#
-
class Unfold : public torch::nn::ModuleHolder<UnfoldImpl>#
A ModuleHolder subclass for UnfoldImpl.
See the documentation for the UnfoldImpl class to learn what methods it provides, and examples of how to use Unfold with torch::nn::UnfoldOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
Public Types
-
using Impl = UnfoldImpl#
-
struct UnfoldOptions#
Options for the Unfold module.
Example:
Unfold model(UnfoldOptions({2, 4}).dilation(2).padding({2, 1}).stride(2));
Public Functions
-
inline UnfoldOptions(ExpandingArray<2> kernel_size)#
-
inline auto kernel_size(const ExpandingArray<2> &new_kernel_size) -> decltype(*this)#
the size of the sliding blocks
-
inline auto kernel_size(ExpandingArray<2> &&new_kernel_size) -> decltype(*this)#
-
inline const ExpandingArray<2> &kernel_size() const noexcept#
-
inline ExpandingArray<2> &kernel_size() noexcept#
-
inline auto dilation(const ExpandingArray<2> &new_dilation) -> decltype(*this)#
controls the spacing between the kernel points; also known as the à trous algorithm.
-
inline auto dilation(ExpandingArray<2> &&new_dilation) -> decltype(*this)#
-
inline const ExpandingArray<2> &dilation() const noexcept#
-
inline ExpandingArray<2> &dilation() noexcept#
-
inline auto padding(const ExpandingArray<2> &new_padding) -> decltype(*this)#
controls the amount of implicit zero-paddings on both sides for padding number of points for each dimension before reshaping.
-
inline auto padding(ExpandingArray<2> &&new_padding) -> decltype(*this)#
-
inline const ExpandingArray<2> &padding() const noexcept#
-
inline ExpandingArray<2> &padding() noexcept#
-
inline auto stride(const ExpandingArray<2> &new_stride) -> decltype(*this)#
controls the stride for the sliding blocks.
-
inline auto stride(ExpandingArray<2> &&new_stride) -> decltype(*this)#
-
inline const ExpandingArray<2> &stride() const noexcept#
-
inline ExpandingArray<2> &stride() noexcept#