LinearCrossEntropyLoss

class torchtune.modules.loss.LinearCrossEntropyLoss(num_output_chunks: int = 8, ignore_index: int = -100)[source]

Memory-efficient cross-entropy loss that incrementally computes the loss over chunks of tokens: for each chunk it masks ignored tokens, calculates logits, and then applies cross-entropy loss. It combines the linear projection with the cross-entropy calculation for further memory savings.

Linear cross-entropy masks out ignored tokens before the projection layer to save memory. You therefore need to skip the final projection layer in your model and pass it to the loss instead. You can set up the loss with the model and compile it as shown below.

>>> model = Transformer(...)
>>> loss = LinearCrossEntropyLoss(...)
>>> loss.set_model_output(model)
>>> loss.apply_compile_strategy()

apply_compile_strategy(*args, **kwargs)[source]

Applies compile only to the compute_cross_entropy function. Compiling the cross-entropy and chunking operations together has a higher memory requirement.

compute_cross_entropy(hidden_chunk: Tensor, target_chunk: Tensor) → Tensor[source]

Computes cross-entropy by masking tokens, calculating logits and then applying cross-entropy loss.

Parameters:
  • hidden_chunk (torch.Tensor) – [batch_size, chunk_size, embed_dim]

  • target_chunk (torch.Tensor) – [batch_size, chunk_size]

Returns:

Sum of cross-entropy loss for non-ignored tokens in the chunk

Return type:

torch.Tensor

Raises:

AttributeError – if called before set_model_output
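
The per-chunk computation described above can be illustrated in plain PyTorch. This is a minimal sketch rather than the library's implementation: the function name chunk_cross_entropy and the projection argument (a stand-in for the model's final linear layer) are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F
from torch import nn


def chunk_cross_entropy(
    hidden_chunk: torch.Tensor,  # [batch_size, chunk_size, embed_dim]
    target_chunk: torch.Tensor,  # [batch_size, chunk_size]
    projection: nn.Linear,       # hypothetical stand-in for the model's output layer
    ignore_index: int = -100,
) -> torch.Tensor:
    # Keep only the positions whose label is not ignore_index.
    mask = target_chunk != ignore_index
    valid_hidden = hidden_chunk[mask]   # [num_valid, embed_dim]
    valid_targets = target_chunk[mask]  # [num_valid]
    # Project only the surviving rows to vocabulary logits.
    logits = projection(valid_hidden)   # [num_valid, vocab_size]
    # Return the sum (not the mean) so a caller can normalize across chunks.
    return F.cross_entropy(logits.float(), valid_targets, reduction="sum")
```

Because masking happens before the projection, the logits tensor only has one row per non-ignored token, which is where the saving described above comes from.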

forward(outputs: Tensor, targets: Tensor) → Tensor[source]

Parameters:
  • outputs (torch.Tensor) – Hidden state of the model, pre projection. Shape [bsz, seq_len, emb_dim]

  • targets (torch.Tensor) – Labels for the model. Shape [bsz, seq_len]

Returns:

loss tensor

Return type:

torch.Tensor
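
As an illustration of how a chunked forward pass could combine the per-chunk sums, the sketch below splits the hidden states and labels along the sequence dimension and normalizes once at the end. It is not the library's code: the function name, the projection argument, and the choice to return the mean over non-ignored tokens are assumptions, since the text above does not state how the final loss is normalized.

```python
import torch
import torch.nn.functional as F
from torch import nn


def chunked_linear_cross_entropy(
    outputs: torch.Tensor,   # [bsz, seq_len, emb_dim], pre-projection hidden states
    targets: torch.Tensor,   # [bsz, seq_len]
    projection: nn.Linear,   # hypothetical stand-in for the model's output layer
    num_output_chunks: int = 8,
    ignore_index: int = -100,
) -> torch.Tensor:
    # Split hidden states and labels into chunks along the sequence dimension.
    hidden_chunks = outputs.tensor_split(num_output_chunks, dim=1)
    target_chunks = targets.tensor_split(num_output_chunks, dim=1)

    total = torch.zeros((), dtype=torch.float32, device=outputs.device)
    for hidden_chunk, target_chunk in zip(hidden_chunks, target_chunks):
        # Mask first, then project, so logits exist only for non-ignored tokens.
        mask = target_chunk != ignore_index
        logits = projection(hidden_chunk[mask])
        total = total + F.cross_entropy(
            logits.float(), target_chunk[mask], reduction="sum"
        )

    # Assumed normalization: mean over non-ignored tokens (guarding against zero).
    num_valid = (targets != ignore_index).sum()
    return total / num_valid.clamp(min=1)
```

Under that normalization assumption, the result should agree with an unchunked F.cross_entropy over the full logits using the same ignore_index, which makes the sketch easy to check on small inputs.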

set_model_output(model: Module) → None[source]

Modify model output to match the expected input for the loss function.
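
One way such a hook might work is sketched below. Everything here is hypothetical: TinyModel, SketchLoss, and the attribute names output, skip_output_layer and linear_projection are illustrative assumptions, not a description of the library's internals.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical model; attribute names are illustrative only."""

    def __init__(self, vocab_size: int = 32, embed_dim: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.output = nn.Linear(embed_dim, vocab_size, bias=False)
        self.skip_output_layer = False

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        hidden = self.embed(tokens)
        if self.skip_output_layer:
            return hidden           # [bsz, seq_len, embed_dim]
        return self.output(hidden)  # [bsz, seq_len, vocab_size]


class SketchLoss(nn.Module):
    """Hypothetical loss that takes over the model's final projection."""

    def __init__(self):
        super().__init__()
        self.linear_projection = None

    def set_model_output(self, model: nn.Module) -> None:
        # Ask the model to return hidden states instead of logits ...
        model.skip_output_layer = True
        # ... and keep a reference to its projection for use inside the loss.
        self.linear_projection = model.output
```

After the call, the model in this sketch returns pre-projection hidden states, which matches the input that forward expects.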
