Class FunctionalImpl
Defined in File functional.h
Inheritance Relationships
Base Type
public torch::nn::Cloneable< FunctionalImpl > (Template Class Cloneable)
Class Documentation
class FunctionalImpl : public torch::nn::Cloneable<FunctionalImpl>
Wraps a function in a Module.

The Functional module allows wrapping an arbitrary function or function object in an nn::Module. This is primarily handy for usage in Sequential:

    Sequential sequential(
        Linear(3, 4),
        Functional(torch::relu),
        BatchNorm1d(4),  // matches the 4 output features of Linear(3, 4)
        Functional(torch::elu, /*alpha=*/1, /*scale=*/1, /*input_scale=*/1));
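The Sequential example relies on Functional accepting any callable that fits its Function type, std::function<Tensor(Tensor)>. The following minimal sketch models that idea without LibTorch. Tensor is an alias for double, and MiniFunctional, MiniSequential, relu and Scale are hypothetical stand-ins, not LibTorch classes. The sketch wraps a free function, a function object and a lambda, then runs them in order.

```cpp
#include <functional>
#include <utility>
#include <vector>

using Tensor = double;  // stand-in for torch::Tensor (assumption)

// Hypothetical miniature of Functional: stores any callable as Tensor(Tensor).
class MiniFunctional {
 public:
  using Function = std::function<Tensor(Tensor)>;
  explicit MiniFunctional(Function function) : function_(std::move(function)) {}
  Tensor forward(Tensor input) const { return function_(input); }

 private:
  Function function_;
};

// Hypothetical miniature of Sequential: feeds each stage's output to the next.
class MiniSequential {
 public:
  explicit MiniSequential(std::vector<MiniFunctional> stages)
      : stages_(std::move(stages)) {}
  Tensor forward(Tensor input) const {
    for (const auto& stage : stages_) input = stage.forward(input);
    return input;
  }

 private:
  std::vector<MiniFunctional> stages_;
};

// A free function, analogous in spirit to torch::relu.
Tensor relu(Tensor x) { return x > 0.0 ? x : 0.0; }

// A function object carrying its own state.
struct Scale {
  double factor;
  Tensor operator()(Tensor x) const { return factor * x; }
};

MiniSequential make_pipeline() {
  return MiniSequential({
      MiniFunctional(relu),                              // free function
      MiniFunctional(Scale{2.0}),                        // function object
      MiniFunctional([](Tensor x) { return x + 1.0; }),  // lambda
  });
}
```

For example, make_pipeline().forward(2.0) applies relu, doubles the result and adds one, giving 5.0. A negative input is clamped to 0.0 by relu first, so make_pipeline().forward(-3.0) gives 1.0.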
While a Functional module only accepts a single Tensor as input, it is possible for the wrapped function to accept further arguments. However, these have to be bound at construction time. For example, if you want to wrap torch::leaky_relu, which accepts a slope scalar as its second argument, with a particular value for its slope in a Functional module, you could write:

    Functional(torch::leaky_relu, /*slope=*/0.5)
The value of 0.5 is then stored within the Functional object and supplied to the function call at invocation time. Note that such bound values are evaluated eagerly and stored a single time. See the documentation of std::bind for more information on the semantics of argument binding.
Attention
After passing any bound arguments, the function must accept a single tensor and return a single tensor.
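To illustrate the eager, store-once binding semantics described above, the sketch below binds a slope value with std::bind. This mirrors the Functional(torch::leaky_relu, /*slope=*/0.5) example, with double again standing in for Tensor. leaky_relu, compute_slope and slope_evaluations are hypothetical names. The counter shows that the bound value is computed once, when the binding is created, and not on every call. The resulting object takes one Tensor and returns one Tensor, as required.

```cpp
#include <functional>

using Tensor = double;  // stand-in for torch::Tensor (assumption)

// Hypothetical analogue of torch::leaky_relu(input, slope).
Tensor leaky_relu(Tensor x, double slope) { return x >= 0.0 ? x : slope * x; }

// Counts how many times the bound slope value is computed.
int slope_evaluations = 0;

double compute_slope() {
  ++slope_evaluations;
  return 0.5;
}

// compute_slope() runs once, right here; its result is copied into the bind
// object and reused on every later call.
std::function<Tensor(Tensor)> make_bound_leaky_relu() {
  return std::bind(leaky_relu, std::placeholders::_1, compute_slope());
}
```

Calling the returned function any number of times leaves slope_evaluations at 1. A bound argument is therefore a snapshot taken at construction time, not an expression re-evaluated per call.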
Note that Functional overloads the call operator (operator()) such that you can invoke it with my_func(...).
Public Types
using Function = std::function<Tensor(Tensor)>
Public Functions
explicit FunctionalImpl(Function function)
Constructs a Functional from a function object.
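template<typename SomeFunction, typename ...Args, typename = std::enable_if_t<(sizeof...(Args) > 0)>>
inline explicit FunctionalImpl(SomeFunction original_function, Args&&... args)
Constructs a Functional from original_function, binding args to the parameters that follow the input tensor (with std::bind semantics).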
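The templated constructor's signature suggests how extra arguments reach the wrapped function. Below is a hedged sketch of a class with the same two constructors. Its body uses std::bind with std::placeholders::_1 for the input, which is an assumption based on the reference to std::bind semantics, not LibTorch's actual source. FunctionalSketch and affine are hypothetical. The enable_if condition removes the template from overload resolution when no extra arguments are given, so a bare callable goes to the Function constructor.

```cpp
#include <functional>
#include <type_traits>
#include <utility>

using Tensor = double;  // stand-in for torch::Tensor (assumption)

// Hypothetical Functional-like class; the constructor bodies are assumptions.
class FunctionalSketch {
 public:
  using Function = std::function<Tensor(Tensor)>;

  // A callable that already has the shape Tensor(Tensor).
  explicit FunctionalSketch(Function function) : function_(std::move(function)) {}

  // Only viable with at least one extra argument; those arguments are bound
  // after the input placeholder _1.
  template <typename SomeFunction, typename... Args,
            typename = std::enable_if_t<(sizeof...(Args) > 0)>>
  explicit FunctionalSketch(SomeFunction original_function, Args&&... args)
      : function_(std::bind(original_function, std::placeholders::_1,
                            std::forward<Args>(args)...)) {}

  Tensor forward(Tensor input) { return function_(input); }

 private:
  Function function_;
};

// Hypothetical function with defaulted trailing parameters.
Tensor affine(Tensor x, double scale = 1.0, double shift = 0.0) {
  return scale * x + shift;
}
```

FunctionalSketch(affine, 3.0, 1.0) computes 3.0 * x + 1.0. Writing FunctionalSketch(affine, 3.0) instead would not compile. std::bind calls affine through a function pointer, and a call through a function pointer cannot use default arguments, so every trailing parameter must be supplied. Presumably the same applies when wrapping library functions that have defaulted trailing parameters.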
virtual void reset() override
reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.
virtual void pretty_print(std::ostream &stream) const override
Pretty prints the Functional module into the given stream.
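Because pretty_print is a virtual override, code that holds a module through a base-class reference still gets the Functional-specific description. The sketch below shows that dispatch pattern with hypothetical classes (ModuleSketch, FunctionalPrintSketch) and a hypothetical operator<< and describe helper. The strings it prints are invented for the example and are not LibTorch's actual output.

```cpp
#include <ostream>
#include <sstream>
#include <string>

// Hypothetical base class: modules describe themselves through a virtual hook.
class ModuleSketch {
 public:
  virtual ~ModuleSketch() = default;
  virtual void pretty_print(std::ostream& stream) const { stream << "Module"; }
};

// Hypothetical Functional-like module overriding the hook.
class FunctionalPrintSketch : public ModuleSketch {
 public:
  void pretty_print(std::ostream& stream) const override {
    stream << "Functional()";
  }
};

// Streaming any module dispatches to the most-derived pretty_print.
std::ostream& operator<<(std::ostream& stream, const ModuleSketch& module) {
  module.pretty_print(stream);
  return stream;
}

// Captures the printed form as a string.
std::string describe(const ModuleSketch& module) {
  std::ostringstream out;
  out << module;
  return out.str();
}
```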
Tensor forward(Tensor input)
Forwards the input tensor to the underlying (bound) function object.
Tensor operator()(Tensor input)
Calls forward(input).
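Since operator() only calls forward(input), a Functional instance can be used like an ordinary function object. The hypothetical CallableSketch below reproduces just that delegation, with double standing in for Tensor. The hypothetical apply_all helper shows the practical benefit: the instance can be handed directly to a standard algorithm such as std::transform.

```cpp
#include <algorithm>
#include <functional>
#include <utility>
#include <vector>

using Tensor = double;  // stand-in for torch::Tensor (assumption)

// Hypothetical sketch: operator() simply delegates to forward().
class CallableSketch {
 public:
  explicit CallableSketch(std::function<Tensor(Tensor)> function)
      : function_(std::move(function)) {}

  // Forwards the input to the stored function object.
  Tensor forward(Tensor input) { return function_(input); }

  // Calls forward(input), so an instance can be invoked like a function.
  Tensor operator()(Tensor input) { return forward(input); }

 private:
  std::function<Tensor(Tensor)> function_;
};

// Applies the module element-wise; std::transform invokes operator().
std::vector<Tensor> apply_all(CallableSketch module, std::vector<Tensor> inputs) {
  std::transform(inputs.begin(), inputs.end(), inputs.begin(), module);
  return inputs;
}
```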