torch.nn.functional.linear_cross_entropy
- torch.nn.functional.linear_cross_entropy(input, linear_weight, target, *, weight=None, reduction='mean', ignore_index=None, label_smoothing=0.0)[source]
Compute the cross entropy loss between the linearly transformed input and the target.
The statement:
loss = linear_cross_entropy(input, linear_weight, target, **kwargs)
is equivalent to the following reference implementation of linear_cross_entropy:
logits = linear(input, linear_weight)
loss = cross_entropy(logits, target, **kwargs)
provided that ignore_index is not explicitly set to None in kwargs (since cross_entropy() does not accept None for ignore_index). See Linear and CrossEntropyLoss for details.
- Parameters:
input (Tensor) – input samples.
linear_weight (Tensor) – linear weight.
target (Tensor) – Ground truth class indices or class probabilities.
weight (Tensor, optional) – a manual rescaling weight given to each class.
reduction (str, optional) – Specifies the reduction to apply to the output:
'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output, 'sum': the output will be summed. Default: 'mean'.
ignore_index (int, optional) – Specifies a target value that is ignored and does not contribute to the input gradient. Note that ignore_index is only applicable when the target contains class indices. Default: None. When target contains class indices, the default value is mapped to -100. Note: the default ignore_index in cross_entropy is -100 for both target types.
label_smoothing (float, optional) – A float in [0.0, 1.0]. Specifies the amount of smoothing when computing the loss, where 0.0 means no smoothing. The targets become a mixture of the original ground truth and a uniform distribution as described in Rethinking the Inception Architecture for Computer Vision. Default: 0.0.
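The ignore_index mapping described above can be illustrated with the reference path. This is a minimal sketch assuming class-index targets, where the default ignore_index=None behaves like cross_entropy's default of -100; the shapes (N=4, H=8, C=5) are chosen only for the example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 8)                   # input: (N, H)
w = torch.randn(5, 8)                   # linear weight: (C, H)
target = torch.tensor([1, 3, -100, 0])  # class indices; -100 is ignored

# Reference path: linear followed by cross_entropy with the mapped default.
logits = F.linear(x, w)                 # (N, C)
loss = F.cross_entropy(logits, target, ignore_index=-100)

# The ignored position contributes nothing: the 'mean' reduction averages
# only over the non-ignored targets.
keep = target != -100
manual = F.cross_entropy(logits[keep], target[keep])
assert torch.allclose(loss, manual)
```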
- Return type:
Tensor
- Shape:
Input: (H) or (N, H).
Linear weight: (C, H), or (C, H) preceded by additional dimensions (d_1, …, d_K) with K ≥ 1 in the case of K-dimensional loss. Note: multi-dimensional weights (K > 0) require batched input (N, H).
Target: If containing class indices, shape (), (N), or (N, d_1, …, d_K) when K ≥ 1, where each value should be between [0, C). The target data type is required to be long when using class indices. If containing class probabilities, the target must have shape (C), (N, C), or (N, C, d_1, …, d_K) when K ≥ 1, and each value should be between [0, 1]. This means the target data type is required to be float when using class probabilities. Note that PyTorch does not strictly enforce probability constraints on the class probabilities and that it is the user's responsibility to ensure target contains valid probability distributions.
Weight: (C).
Output: If reduction is 'none', shape (), (N), or (N, d_1, …, d_K) with K ≥ 1 in the case of K-dimensional loss, depending on the shape of the input. Otherwise, scalar.
where N is the batch size, C is the number of classes, and H is the input feature size.
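The reference implementation above can be sketched end-to-end with both target types. This is an illustrative sketch, not the fused implementation; on builds where F.linear_cross_entropy is available, calling it with the same arguments should produce the same values. The helper name and the shapes (N=4, H=8, C=5) are chosen only for the example.

```python
import torch
import torch.nn.functional as F

def reference_linear_cross_entropy(x, w, target, **kwargs):
    # Reference path: (N, H) @ (C, H).T -> logits of shape (N, C),
    # then the usual cross entropy.
    logits = F.linear(x, w)
    return F.cross_entropy(logits, target, **kwargs)

torch.manual_seed(0)
N, H, C = 4, 8, 5
x = torch.randn(N, H)   # input: (N, H)
w = torch.randn(C, H)   # linear weight: (C, H)

# Class-index target: shape (N), dtype long, values in [0, C).
idx_target = torch.randint(0, C, (N,))
loss_idx = reference_linear_cross_entropy(x, w, idx_target)

# Class-probability target: shape (N, C), dtype float, rows summing to 1.
prob_target = torch.softmax(torch.randn(N, C), dim=-1)
loss_prob = reference_linear_cross_entropy(x, w, prob_target)

# With the default reduction='mean', both losses reduce to scalars.
assert loss_idx.ndim == 0 and loss_prob.ndim == 0
```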