LinearCrossEntropyLoss¶
- class torchtune.modules.loss.LinearCrossEntropyLoss(num_output_chunks: int = 8, ignore_index: int = -100)[source]¶
Memory-efficient cross-entropy loss that incrementally computes the loss over chunks of tokens by masking ignored tokens, computing the logits, and then applying cross-entropy. Combines the linear projection with the cross-entropy calculation for further memory savings.
Linear cross-entropy masks out ignored tokens before the projection layer to save memory. You therefore need to skip the final projection layer in your model and pass it to the loss instead. You can set up the loss with the model and compile it as shown below.
>>> model = Transformer(...)
>>> loss = LinearCrossEntropyLoss(...)
>>> loss.set_model_output(model)
>>> loss.apply_compile_strategy()
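For additional context, a minimal training-step sketch (illustrative only; the model, dataloader, and optimizer objects and the batch keys are assumptions, and the model is expected to return pre-projection hidden states once its output layer has been handed to the loss):

>>> loss_fn = LinearCrossEntropyLoss(num_output_chunks=8, ignore_index=-100)
>>> loss_fn.set_model_output(model)
>>> loss_fn.apply_compile_strategy()
>>> for batch in dataloader:
...     hidden = model(batch["tokens"])                 # [bsz, seq_len, emb_dim], no final projection
...     batch_loss = loss_fn(hidden, batch["labels"])   # labels use ignore_index (-100) for masked positions
...     batch_loss.backward()
...     optimizer.step()
...     optimizer.zero_grad()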
- apply_compile_strategy(*args, **kwargs)[source]¶
Applies compilation only to the compute_cross_entropy function; compiling the cross-entropy computation together with the chunking operation would increase the memory requirement.
- compute_cross_entropy(hidden_chunk: Tensor, target_chunk: Tensor) Tensor [source]¶
Computes cross-entropy for a chunk by masking ignored tokens, computing the logits, and then applying cross-entropy loss (a standalone sketch of the idea follows this method's reference below).
- Parameters:
hidden_chunk (torch.Tensor) – Hidden states for the chunk, shape [batch_size, chunk_size, embed_dim]
target_chunk (torch.Tensor) – Target token IDs for the chunk, shape [batch_size, chunk_size]
- Returns:
Sum of cross-entropy loss for non-ignored tokens in the chunk
- Return type:
torch.Tensor
- Raises:
AttributeError – if called before set_model_output
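For intuition, the sketch below shows the core idea, not the torchtune implementation: ignored tokens are dropped before the output projection, so logits are only materialized for tokens that contribute to the loss. The output_proj argument is an illustrative stand-in for the model's final linear layer handed over via set_model_output.

>>> import torch.nn.functional as F
>>> def masked_chunk_ce(hidden_chunk, target_chunk, output_proj, ignore_index=-100):
...     hidden = hidden_chunk.reshape(-1, hidden_chunk.size(-1))  # [num_tokens, embed_dim]
...     targets = target_chunk.reshape(-1)                        # [num_tokens]
...     keep = targets != ignore_index                            # drop ignored tokens before projecting
...     logits = output_proj(hidden[keep]).float()                # [num_kept, vocab_size]
...     return F.cross_entropy(logits, targets[keep], reduction="sum")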
- forward(outputs: Tensor, targets: Tensor) Tensor [source]¶
- Parameters:
outputs (torch.Tensor) – Hidden state of the model, pre-projection. Shape [bsz, seq_len, emb_dim]
targets (torch.Tensor) – Labels for the model. Shape [bsz, seq_len]
- Returns:
loss tensor
- Return type:
torch.Tensor
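As a rough sketch of how the chunked reduction could look, reusing the illustrative masked_chunk_ce and output_proj from above and assuming the final loss is normalized by the number of non-ignored tokens (an assumption, not the library code):

>>> hidden_chunks = outputs.chunk(8, dim=1)    # num_output_chunks pieces along seq_len
>>> target_chunks = targets.chunk(8, dim=1)
>>> total = sum(masked_chunk_ce(h, t, output_proj) for h, t in zip(hidden_chunks, target_chunks))
>>> loss = total / (targets != -100).sum()     # mean over non-ignored tokens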