
LinearCrossEntropyLoss

class torch.nn.LinearCrossEntropyLoss(in_features, num_classes, *, out_features=(), bias=False, device=None, dtype=None, reduction='mean', weight=None, ignore_index=None, label_smoothing=0.0, options=None)

This criterion linearly transforms the input into logits and computes the cross entropy loss between those logits and the target.

See CrossEntropyLoss for the definition of cross entropy loss.
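As a rough mental model, the unchunked computation can be sketched as torch.nn.functional.linear followed by torch.nn.functional.cross_entropy. This is an assumption drawn from the parameter descriptions, not the module's actual (possibly chunked) code. The helper name reference_linear_cross_entropy and its linear_weight argument are hypothetical. The sketch also assumes that the internal weight is stored flat as (C * prod(out_features), in_features), like the bias, and that logits are reshaped row-major into (C, *out_features).

```python
import math

import torch
import torch.nn.functional as F


def reference_linear_cross_entropy(input, target, linear_weight, linear_bias=None,
                                   *, num_classes, out_features=(), reduction="mean"):
    """Hypothetical sketch: project to logits, reshape, then apply cross entropy."""
    # Assumed flat layout: linear_weight is (C * prod(out_features), in_features).
    assert linear_weight.shape[0] == num_classes * math.prod(out_features)
    # Flat logits of shape (..., C * prod(out_features)); a flat bias is added here.
    logits = F.linear(input, linear_weight, linear_bias)
    # Row-major reshape to (N, C, *out_features), or to (C,) for an unbatched
    # input with out_features=() (the only unbatched case this sketch handles).
    logits = logits.reshape(*input.shape[:-1], num_classes, *out_features)
    return F.cross_entropy(logits, target, reduction=reduction)
```

Because the flat bias is added before the reshape, this is equivalent to adding a bias of logical shape (C, *out_features) after it, given the row-major assumption.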

Parameters:
  • in_features (int) – Size of each input sample.

  • num_classes (int) – Number of classes, C.

  • out_features (tuple[int], optional) – Specifies the dimensions (d_1, d_2, ..., d_K) for a K-dimensional loss. Default: ().

  • bias (bool, optional) – If True, the internal Linear adds a learnable bias to the logits. Logical shape is (C, *out_features); storage is flat (self.linear.bias.shape == (C * prod(out_features),)) for the same reason as self.linear.weight – reshaping happens in forward() before passing through to linear_cross_entropy() as linear_bias. With options != None, K-dimensional bias (out_features != ()) falls back to the reference implementation with a warning; the chunked path supports only (C,)-shaped bias. Default: False.

  • device (torch.device, optional) – the desired device of the internal linear weight. Default: None.

  • dtype (torch.dtype, optional) – the desired dtype of the internal linear weight. Default: None.

  • weight (Tensor, optional) – a manual rescaling weight given to each class. If given, has to be a Tensor of size C.

  • reduction (str, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the weighted mean of the output is taken, 'sum': the output will be summed. Default: 'mean'.

  • ignore_index (int, optional) – Specifies a target value that is ignored and does not contribute to the input gradient. Note that ignore_index is only applicable when the target contains class indices. Default: None. When target contains class indices, the default value is mapped to -100. Note: the default ignore_index in CrossEntropyLoss is -100 for both target types.

  • label_smoothing (float, optional) – A float in [0.0, 1.0]. Specifies the amount of smoothing when computing the loss, where 0.0 means no smoothing. The targets become a mixture of the original ground truth and a uniform distribution as described in Rethinking the Inception Architecture for Computer Vision. Default: 0.0.

  • options (LinearCrossEntropyOptions, optional) – Specifies chunking strategy options; see LinearCrossEntropyOptions for details. To use the reference implementation of linear_cross_entropy with chunking disabled, pass options=None. Note: passing a non-None options makes the module incompatible with torch.jit.script(); see the note below.
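The weight, reduction, ignore_index and label_smoothing parameters are described in terms of CrossEntropyLoss, so their effect can be sketched with the standard torch.nn.functional.cross_entropy. The sketch below shows two points for class-index targets. First, 'mean' with class weights is a weighted mean over the non-ignored targets. Second, label smoothing replaces the one-hot target with (1 - eps) * one_hot + eps / C. It assumes that LinearCrossEntropyLoss applies the same semantics to its projected logits. The helper names are hypothetical.

```python
import torch
import torch.nn.functional as F


def weighted_mean_nll(logits, target, class_weight, ignore_index=-100):
    # Hypothetical helper: 'mean' with class weights is
    # sum_i w[y_i] * nll_i / sum_i w[y_i], over non-ignored targets only.
    keep = target != ignore_index
    t = target[keep]
    logp = F.log_softmax(logits[keep], dim=1)
    nll = -logp.gather(1, t.unsqueeze(1)).squeeze(1)
    w = class_weight[t]
    return (w * nll).sum() / w.sum()


def smoothed_ce(logits, target, label_smoothing):
    # Hypothetical helper (unweighted, nothing ignored): mix the one-hot target
    # with a uniform distribution, q = (1 - eps) * one_hot + eps / C.
    num_classes = logits.shape[1]
    q = F.one_hot(target, num_classes).to(logits.dtype)
    q = q * (1.0 - label_smoothing) + label_smoothing / num_classes
    return -(q * F.log_softmax(logits, dim=1)).sum(dim=1).mean()
```

Both helpers should agree with F.cross_entropy under the stated conditions. The smoothing helper deliberately omits class weights and ignore_index, which change how the smoothing term is averaged.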

Warning

With non-None options, the chunked path consumes its precomputed gradients in-place, so a second backward() call raises an error (even with retain_graph=True). Use LinearCrossEntropyOptions(allow_retain_graph=True) to allow repeated backward passes, at the cost of one extra gradient-sized allocation per call.

Note

options=None is scriptable; an options instance is not. The chunked path does not support higher-order AD, forward-mode AD, torch.func.grad / vmap(grad(...)), or Inductor lowering; see torch.nn.functional.linear_cross_entropy().

Shape:
  • Input: Shape (in_features) or (N, in_features).

  • Target: If containing class indices, shape (), (N) or (N, *out_features), where each value should be in the range [0, C). The target data type is required to be long when using class indices. If containing class probabilities, the target must have shape (C) or (N, C, *out_features), and each value should be in the range [0, 1]. This means the target data type is required to be float when using class probabilities. Note that PyTorch does not strictly enforce probability constraints on the class probabilities, and it is the user's responsibility to ensure that target contains valid probability distributions (see the examples section below for more details).

  • Output: If reduction is 'none', shape (), (N) or (N, *out_features), depending on the shape of the input. Otherwise, scalar.

where N is the batch size.
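The two target forms above are interchangeable when the probabilities are one-hot. The sketch below illustrates this with torch.nn.functional.cross_entropy applied directly to logits already laid out as (N, C, *out_features), assuming that the module uses the same target conventions after its projection. It also checks the reduction='none' output shape. The sizes are arbitrary illustrative values.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, out_features = 2, 10, (4, 3)

# Logits in the (N, C, *out_features) layout the K-dimensional loss expects.
logits = torch.randn(N, C, *out_features)

# Class-index target: long, shape (N, *out_features), values in [0, C).
index_target = torch.randint(0, C, (N, *out_features))

# Equivalent class-probability target: float, shape (N, C, *out_features).
# one_hot appends the class dimension last, so move it to position 1.
prob_target = F.one_hot(index_target, C).movedim(-1, 1).float()

loss_idx = F.cross_entropy(logits, index_target, reduction="none")
loss_prob = F.cross_entropy(logits, prob_target, reduction="none")
```

Both per-element losses have shape (N, *out_features) and match. This equivalence holds only for one-hot probabilities with no class weights or label smoothing.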

Examples

>>> import torch
>>> from torch import nn
>>> _ = torch.manual_seed(283)
>>> # Example of target with class indices
>>> loss = nn.LinearCrossEntropyLoss(5, 10, out_features=(4, 3))
>>> input = torch.randn(2, 5, requires_grad=True)
>>> target = torch.randint(0, 10, (2, 4, 3))
>>> output = loss(input, target)
>>> output.backward()
>>>
>>> # Example of target with class probabilities
>>> input = torch.randn(2, 5, requires_grad=True)
>>> target = torch.randn(2, 10, 4, 3).softmax(dim=1)
>>> output = loss(input, target)
>>> output.backward()

forward(input, target)

Runs the forward pass.

Return type:

Tensor