Class AdaptiveLogSoftmaxWithLossImpl
Defined in File adaptive.h
Inheritance Relationships
Base Type
public torch::nn::Cloneable< AdaptiveLogSoftmaxWithLossImpl >
(Template Class Cloneable)
Class Documentation
-
class AdaptiveLogSoftmaxWithLossImpl : public torch::nn::Cloneable<AdaptiveLogSoftmaxWithLossImpl>
Efficient softmax approximation as described in "Efficient softmax approximation for GPUs" by Edouard Grave, Armand Joulin, Moustapha Cissé, David Grangier, and Hervé Jégou.
See https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveLogSoftmaxWithLoss to learn about the exact behavior of this module.
See the documentation for the torch::nn::AdaptiveLogSoftmaxWithLossOptions class to learn what constructor arguments are supported for this module.
Example:
AdaptiveLogSoftmaxWithLoss model(AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8}).div_value(2.).head_bias(true));
Public Functions
-
inline AdaptiveLogSoftmaxWithLossImpl(int64_t in_features, int64_t n_classes, std::vector<int64_t> cutoffs)
-
explicit AdaptiveLogSoftmaxWithLossImpl(AdaptiveLogSoftmaxWithLossOptions options_)
-
virtual void reset() override
reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.
-
void reset_parameters()
-
virtual void pretty_print(std::ostream &stream) const override
Pretty prints the AdaptiveLogSoftmaxWithLoss module into the given stream.
-
Tensor _get_full_log_prob(const Tensor &input, const Tensor &head_output)
Given the input tensor and the output of head, computes the log of the full distribution.
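Written out as a formula (the notation here is ours, not the header's, and assumes the decomposition used by the Python module): let c_0 < c_1 < ... < c_K be the cutoffs with n_classes appended as c_K, so s = c_0 is the shortlist size and K is the number of tail clusters. Write h(x) for the log-softmax of the head output (s + K entries) and t_i(x) for the log-softmax of tail cluster i. Then the full log-distribution over a class index j is:

```latex
\log p(y = j \mid x) =
\begin{cases}
  h_j(x), & 0 \le j < c_0, \\
  h_{s+i}(x) + t_{i,\, j - c_i}(x), & c_i \le j < c_{i+1}, \quad i = 0, \dots, K-1.
\end{cases}
```

Because each t_i is normalized within its cluster, the probabilities of the classes in cluster i sum to exp(h_{s+i}(x)), and the whole distribution sums to 1.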
-
Tensor log_prob(const Tensor &input)
Computes log probabilities for all n_classes.
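A hypothetical single-row sketch in plain C++ (no libtorch) of computing all n_classes log probabilities. The names log_softmax and full_log_prob are illustrative, not part of the library, and the sketch assumes, as in the Python module, that log_prob feeds the head's output to _get_full_log_prob, which combines the head's log-softmax with each tail cluster's log-softmax.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Numerically stable log-softmax over one row of scores.
std::vector<double> log_softmax(const std::vector<double>& logits) {
  assert(!logits.empty());
  const double max_logit = *std::max_element(logits.begin(), logits.end());
  double sum = 0.0;
  for (double z : logits) sum += std::exp(z - max_logit);
  const double log_norm = max_logit + std::log(sum);
  std::vector<double> out;
  out.reserve(logits.size());
  for (double z : logits) out.push_back(z - log_norm);
  return out;
}

// Hypothetical single-row analogue of _get_full_log_prob / log_prob.
//   head_logits:    shortlist_size + n_clusters raw head scores
//   tail_logits[i]: raw scores of cluster i, one per class in [cutoffs[i], cutoffs[i+1])
//   cutoffs:        user cutoffs with n_classes appended, e.g. {4, 8, 10}
std::vector<double> full_log_prob(const std::vector<double>& head_logits,
                                  const std::vector<std::vector<double>>& tail_logits,
                                  const std::vector<int64_t>& cutoffs) {
  assert(cutoffs.size() >= 2);
  const auto shortlist_size = static_cast<std::size_t>(cutoffs.front());
  const std::size_t n_clusters = cutoffs.size() - 1;
  assert(head_logits.size() == shortlist_size + n_clusters);
  assert(tail_logits.size() == n_clusters);

  const std::vector<double> head_logprob = log_softmax(head_logits);
  std::vector<double> out(static_cast<std::size_t>(cutoffs.back()));

  // Shortlist classes read their log-probability directly from the head.
  for (std::size_t j = 0; j < shortlist_size; ++j) out[j] = head_logprob[j];

  // Tail classes: log p(cluster i) from the head + log p(class | cluster i).
  for (std::size_t i = 0; i < n_clusters; ++i) {
    const auto start = static_cast<std::size_t>(cutoffs[i]);
    const auto stop = static_cast<std::size_t>(cutoffs[i + 1]);
    assert(tail_logits[i].size() == stop - start);
    const std::vector<double> cluster_logprob = log_softmax(tail_logits[i]);
    for (std::size_t j = 0; j < cluster_logprob.size(); ++j)
      out[start + j] = head_logprob[shortlist_size + i] + cluster_logprob[j];
  }
  return out;
}
```

Since every cluster's log-softmax is normalized, exponentiating the returned vector sums to 1 across all n_classes entries.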
-
Tensor predict(const Tensor &input)
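This is equivalent to log_prob(input).argmax(1) but is more efficient in some cases: rows whose highest-scoring head output is a shortlist class can be returned directly, without evaluating the tail clusters.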
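To make the efficiency claim concrete, here is a hypothetical single-row sketch in plain C++ (no libtorch). The names argmax, predict_row, and the full_log_prob callback are illustrative, not part of the library; the sketch assumes, as in the Python module, that predict first takes the argmax of the head output and only computes the full distribution when that argmax is a cluster entry.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

// Index of the largest element (first one on ties), like argmax over dim 1.
std::size_t argmax(const std::vector<double>& v) {
  assert(!v.empty());
  return static_cast<std::size_t>(
      std::distance(v.begin(), std::max_element(v.begin(), v.end())));
}

// Hypothetical single-row sketch of the predict() shortcut.
// head_logits has shortlist_size + n_clusters entries; full_log_prob is a
// callable returning all n_classes log-probabilities (the expensive path).
template <typename FullLogProb>
std::size_t predict_row(const std::vector<double>& head_logits,
                        std::size_t shortlist_size, FullLogProb full_log_prob) {
  // argmax of raw scores equals argmax of their log-softmax (monotonic).
  const std::size_t best_head = argmax(head_logits);
  // A shortlist winner is also the global winner: every tail class has
  // log p <= log p(its cluster) <= log p(best_head).
  if (best_head < shortlist_size) return best_head;
  // Otherwise fall back to the full distribution.
  return argmax(full_log_prob());
}
```

In a batch the same test can be applied per row, so only rows whose head winner is a cluster entry pay for the tail computation.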
Public Members
-
AdaptiveLogSoftmaxWithLossOptions options
The options with which this Module was constructed.
-
std::vector<int64_t> cutoffs
Cutoffs used to assign targets to their buckets.
They should form an ordered sequence of integers sorted in increasing order.
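To illustrate the partitioning, here is a minimal sketch; bucket_of is a hypothetical helper, not part of the library. Following the partitioning described in the Python AdaptiveLogSoftmaxWithLoss documentation, it maps a 0-based target to bucket 0 (the head shortlist, [0, cutoffs[0])) or to k >= 1 for the k-th tail cluster, with the last cluster ending at n_classes.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: returns 0 for the head shortlist [0, cutoffs[0]),
// or k >= 1 for the k-th tail cluster [cutoffs[k-1], cutoffs[k]),
// where the last cluster ends at n_classes.
int bucket_of(int64_t target, const std::vector<int64_t>& cutoffs, int64_t n_classes) {
  assert(std::is_sorted(cutoffs.begin(), cutoffs.end()));
  assert(target >= 0 && target < n_classes);
  // upper_bound finds the first cutoff strictly greater than target; its
  // position equals the number of cutoffs <= target, i.e. the bucket index.
  return static_cast<int>(
      std::upper_bound(cutoffs.begin(), cutoffs.end(), target) - cutoffs.begin());
}
```

With the class example's cutoffs {4, 8} and n_classes = 10, targets 0-3 go to the head, 4-7 to the first cluster, and 8-9 to the second.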
-
int64_t shortlist_size
-
int64_t n_clusters
Number of clusters.
-
int64_t head_size
Output size of head classifier.
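The sizes in the members above follow from the constructor arguments. A minimal sketch, assuming the relationships described for the Python module: n_classes is appended to the cutoffs, the shortlist size is the first cutoff, there is one tail cluster per remaining interval, and the head has one extra output per tail cluster. DerivedSizes and derive_sizes are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical bundle of the sizes derived from the constructor options.
struct DerivedSizes {
  std::vector<int64_t> cutoffs;  // user cutoffs with n_classes appended
  int64_t shortlist_size;        // classes scored directly by the head
  int64_t n_clusters;            // number of tail clusters
  int64_t head_size;             // head outputs: shortlist + one per cluster
};

DerivedSizes derive_sizes(int64_t n_classes, std::vector<int64_t> cutoffs) {
  assert(!cutoffs.empty() && cutoffs.back() < n_classes);
  cutoffs.push_back(n_classes);
  const int64_t shortlist_size = cutoffs.front();
  const int64_t n_clusters = static_cast<int64_t>(cutoffs.size()) - 1;
  return DerivedSizes{cutoffs, shortlist_size, n_clusters, shortlist_size + n_clusters};
}
```

For the class example's AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8}), this gives shortlist_size = 4, n_clusters = 2 and head_size = 6.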
-
ModuleList tail
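Each element of tail handles one cluster. A minimal sketch of the resulting shapes, assuming the Python module's description: cluster idx (1-based) projects in_features down to floor(in_features / div_value^idx) features and then to one output per class in the cluster (in the Python module this is a two-layer, bias-free Linear stack). TailClusterShape and tail_shapes are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical description of one tail cluster: in_features -> hidden_size
// -> output_size (one output per class in the cluster).
struct TailClusterShape {
  int64_t hidden_size;
  int64_t output_size;
};

// cutoffs must already include n_classes as the last element, e.g. {4, 8, 10}.
std::vector<TailClusterShape> tail_shapes(int64_t in_features,
                                          const std::vector<int64_t>& cutoffs,
                                          double div_value) {
  assert(cutoffs.size() >= 2 && div_value > 0.0);
  std::vector<TailClusterShape> shapes;
  for (std::size_t i = 0; i + 1 < cutoffs.size(); ++i) {
    // 1-based cluster index idx = i + 1: hidden = floor(in_features / div_value^idx).
    const auto hidden = static_cast<int64_t>(std::floor(
        static_cast<double>(in_features) / std::pow(div_value, static_cast<double>(i + 1))));
    shapes.push_back({hidden, cutoffs[i + 1] - cutoffs[i]});
  }
  return shapes;
}
```

With the class example (in_features = 8, cutoffs {4, 8} plus n_classes = 10, div_value = 2), this gives hidden sizes 4 and 2 and output sizes 4 and 2.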