
Class AdaptiveLogSoftmaxWithLossImpl

Inheritance Relationships

Base Type

public torch::nn::Cloneable<AdaptiveLogSoftmaxWithLossImpl>

Class Documentation

class AdaptiveLogSoftmaxWithLossImpl : public torch::nn::Cloneable<AdaptiveLogSoftmaxWithLossImpl>

Efficient softmax approximation as described in "Efficient softmax approximation for GPUs" by Edouard Grave, Armand Joulin, Moustapha Cissé, David Grangier, and Hervé Jégou.

See https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveLogSoftmaxWithLoss to learn about the exact behavior of this module.

See the documentation for torch::nn::AdaptiveLogSoftmaxWithLossOptions class to learn what constructor arguments are supported for this module.

Example:

```cpp
AdaptiveLogSoftmaxWithLoss model(
    AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8}).div_value(2.).head_bias(true));
```

Public Functions

inline AdaptiveLogSoftmaxWithLossImpl(int64_t in_features, int64_t n_classes, std::vector<int64_t> cutoffs)
explicit AdaptiveLogSoftmaxWithLossImpl(AdaptiveLogSoftmaxWithLossOptions options_)
ASMoutput forward(const Tensor &input, const Tensor &target)
virtual void reset() override

reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.

void reset_parameters()
virtual void pretty_print(std::ostream &stream) const override

Pretty prints the AdaptiveLogSoftmaxWithLoss module into the given stream.

Tensor _get_full_log_prob(const Tensor &input, const Tensor &head_output)

Given the input tensor and the output of the head, computes the log of the full distribution.
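To make the head/tail factorization concrete, here is a minimal plain-C++ sketch for a single example, using `std::vector<double>` instead of tensors. It assumes the head produces `shortlist_size + n_clusters` scores (shortlist classes first, then one entry per tail cluster) and that each tail cluster has already produced one score per class it contains. `log_softmax` and `full_log_prob` are hypothetical helpers for illustration, not part of the libtorch API.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Numerically stable log-softmax over one vector of scores.
std::vector<double> log_softmax(const std::vector<double>& logits) {
  const double max_logit = *std::max_element(logits.begin(), logits.end());
  double sum = 0.0;
  for (double z : logits) sum += std::exp(z - max_logit);
  const double log_norm = max_logit + std::log(sum);
  std::vector<double> out;
  out.reserve(logits.size());
  for (double z : logits) out.push_back(z - log_norm);
  return out;
}

// Hypothetical helper: log p(c | x) for every class c of one example.
//   head_logits    : shortlist_size + n_clusters scores (shortlist, then clusters)
//   tail_logits[i] : cutoffs[i + 1] - cutoffs[i] scores for the classes of cluster i
//   cutoffs        : user cutoffs with n_classes appended (assumption), e.g. {4, 8, 10}
std::vector<double> full_log_prob(const std::vector<double>& head_logits,
                                  const std::vector<std::vector<double>>& tail_logits,
                                  const std::vector<std::int64_t>& cutoffs) {
  const std::int64_t shortlist_size = cutoffs.front();
  assert(head_logits.size() ==
         static_cast<std::size_t>(shortlist_size) + tail_logits.size());
  const std::vector<double> head_lp = log_softmax(head_logits);
  std::vector<double> out(static_cast<std::size_t>(cutoffs.back()));
  // Shortlist classes take their log-probability straight from the head.
  for (std::int64_t c = 0; c < shortlist_size; ++c) out[c] = head_lp[c];
  // A tail class gets log p(cluster i) + log p(class | cluster i).
  for (std::size_t i = 0; i + 1 < cutoffs.size(); ++i) {
    const double cluster_lp = head_lp[static_cast<std::size_t>(shortlist_size) + i];
    const std::vector<double> within = log_softmax(tail_logits[i]);
    for (std::int64_t c = cutoffs[i]; c < cutoffs[i + 1]; ++c)
      out[c] = cluster_lp + within[c - cutoffs[i]];
  }
  return out;
}
```

Because each cluster's within-cluster probabilities sum to one, exponentiating the result and summing over all classes gives back the head's total probability, which is one.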

Tensor log_prob(const Tensor &input)

Computes log probabilities for all n_classes.

Tensor predict(const Tensor &input)

This is equivalent to log_prob(input).argmax(1) but is more efficient in some cases.
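One reason a prediction can be cheaper than a full log_prob: if the head's best entry for a row is a shortlist class, no tail cluster needs to be evaluated for that row. The sketch below illustrates this argument on the head log-probabilities of one example; `predict_from_head` is a hypothetical plain-C++ helper, not libtorch code, and it does not reproduce how the module handles a batch mixing both cases.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <vector>

// If the head's best entry is a shortlist class, it is also the argmax of the
// full distribution: any tail class c in cluster k satisfies
//   p(c) = p(k) * p(c | k) <= p(k) <= p(best shortlist class).
// Returns -1 when the best head entry is a cluster token, i.e. the full
// log distribution would still have to be computed for this example.
std::int64_t predict_from_head(const std::vector<double>& head_log_probs,
                               std::int64_t shortlist_size) {
  assert(!head_log_probs.empty());
  // max_element returns the first maximum; shortlist entries come first.
  const auto best = std::max_element(head_log_probs.begin(), head_log_probs.end());
  const auto idx =
      static_cast<std::int64_t>(std::distance(head_log_probs.begin(), best));
  return idx < shortlist_size ? idx : -1;
}
```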

Public Members

AdaptiveLogSoftmaxWithLossOptions options

The options with which this Module was constructed.

std::vector<int64_t> cutoffs

Cutoffs used to assign targets to their buckets.

It should be a sequence of integers sorted in increasing order.
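A sketch of the ordering constraint on the cutoffs and of how a target is routed to a bucket, in plain C++. The validity rule mirrors the one stated for the Python `AdaptiveLogSoftmaxWithLoss` (unique, positive integers in increasing order, each between 1 and `n_classes - 1`); `valid_cutoffs` and `bucket_of` are hypothetical helpers that work on the user-supplied cutoffs, and rejecting an empty list is an assumption of this sketch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

// True if cutoffs are unique, positive, increasing, and at most n_classes - 1.
bool valid_cutoffs(const std::vector<std::int64_t>& cutoffs, std::int64_t n_classes) {
  if (cutoffs.empty()) return false;  // assumption: at least one cutoff is required
  const bool sorted = std::is_sorted(cutoffs.begin(), cutoffs.end());
  const bool unique =
      std::set<std::int64_t>(cutoffs.begin(), cutoffs.end()).size() == cutoffs.size();
  // Once sorted, front() is the minimum and back() the maximum.
  return sorted && unique && cutoffs.front() >= 1 && cutoffs.back() <= n_classes - 1;
}

// Hypothetical helper: bucket 0 is the head shortlist (target < cutoffs[0]);
// bucket i + 1 is tail cluster i, covering [cutoffs[i], cutoffs[i + 1]), with
// n_classes acting as the final boundary.
std::int64_t bucket_of(std::int64_t target, const std::vector<std::int64_t>& cutoffs) {
  assert(target >= 0);
  // Number of cutoffs <= target == index of the first cutoff > target.
  return static_cast<std::int64_t>(
      std::upper_bound(cutoffs.begin(), cutoffs.end(), target) - cutoffs.begin());
}
```

With the cutoffs {4, 8} from the class example and 10 classes, targets 0 to 3 stay in the head, 4 to 7 go to the first tail cluster, and 8 to 9 go to the second.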

int64_t shortlist_size
int64_t n_clusters

Number of clusters.

int64_t head_size

Output size of head classifier.

Linear head = nullptr
ModuleList tail
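The sizes of these members follow from the options. The plain-C++ sketch below computes them without allocating any modules, assuming the behavior documented for the Python `AdaptiveLogSoftmaxWithLoss`: `n_classes` is appended to the cutoffs, the head covers the shortlist plus one entry per tail cluster, and tail cluster `i` (0-based) projects `in_features` down to `floor(in_features / div_value^(i + 1))` before mapping to its classes. `compute_sizes`, `AdaptiveSizes`, and `TailProjection` are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical: dimensions of one tail cluster's two-step projection.
struct TailProjection {
  std::int64_t hidden;  // floor(in_features / div_value^(i + 1))
  std::int64_t out;     // number of classes in the cluster
};

// Hypothetical plain-struct mirror of the size-related members.
struct AdaptiveSizes {
  std::vector<std::int64_t> cutoffs;  // user cutoffs with n_classes appended
  std::int64_t shortlist_size;        // classes predicted directly by the head
  std::int64_t n_clusters;            // number of tail clusters
  std::int64_t head_size;             // shortlist_size + n_clusters
  std::vector<TailProjection> tail;
};

AdaptiveSizes compute_sizes(std::int64_t in_features, std::int64_t n_classes,
                            std::vector<std::int64_t> cutoffs, double div_value) {
  assert(!cutoffs.empty() && div_value != 0.0);
  AdaptiveSizes s;
  cutoffs.push_back(n_classes);
  s.cutoffs = cutoffs;
  s.shortlist_size = cutoffs[0];
  s.n_clusters = static_cast<std::int64_t>(cutoffs.size()) - 1;
  s.head_size = s.shortlist_size + s.n_clusters;
  for (std::int64_t i = 0; i < s.n_clusters; ++i) {
    // Less frequent clusters (larger i) get smaller hidden projections.
    const auto hidden = static_cast<std::int64_t>(
        std::floor(static_cast<double>(in_features) / std::pow(div_value, i + 1)));
    s.tail.push_back({hidden, cutoffs[i + 1] - cutoffs[i]});
  }
  return s;
}
```

For the options in the class example (8, 10, {4, 8}, div_value 2.) this gives cutoffs {4, 8, 10}, a shortlist of 4 classes, 2 clusters, a head size of 6, and tail projections 8 to 4 to 4 and 8 to 2 to 2.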