LinearCrossEntropyLoss
- class torchtune.modules.loss.LinearCrossEntropyLoss(num_output_chunks: int = 8, ignore_index: int = -100, tp_enabled: bool = False, mask_ignored_tokens: bool = True)
Memory-efficient cross-entropy loss that computes the loss incrementally over chunks of tokens: it masks ignored tokens, calculates logits, and then applies cross-entropy loss. Fusing the linear projection with the cross-entropy calculation saves additional memory.
Linear cross-entropy masks out ignored tokens before the projection layer to save memory. You therefore need to skip the final projection layer in your model and pass it to the loss instead. You can set up the loss with the model and compile it as shown below.
>>> model = Transformer(...)
>>> loss = LinearCrossEntropyLoss(...)
>>> loss.set_model_output(model)
>>> loss.apply_compile_strategy()
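The computation described above can be sketched in plain PyTorch. This is a minimal illustration of the arithmetic only, not torchtune's implementation: the function name `chunked_linear_cross_entropy`, the `output_proj` argument, and the final normalization by the number of non-ignored tokens are assumptions made for this example.

```python
import torch
import torch.nn.functional as F


def chunked_linear_cross_entropy(hidden, targets, output_proj, num_chunks=8, ignore_index=-100):
    """Hypothetical sketch: mean cross-entropy over non-ignored tokens."""
    # Flatten batch and sequence dimensions: [bsz * seq_len, emb_dim] and [bsz * seq_len].
    hidden = hidden.reshape(-1, hidden.shape[-1])
    targets = targets.reshape(-1)

    # Drop ignored positions before projecting, so no logits are computed for them.
    keep = targets != ignore_index
    hidden, targets = hidden[keep], targets[keep]
    if targets.numel() == 0:
        # Multiplying by 0 keeps the result attached to the autograd graph.
        return hidden.sum() * 0.0

    total = hidden.new_zeros((), dtype=torch.float32)
    for h, t in zip(hidden.chunk(num_chunks), targets.chunk(num_chunks)):
        logits = output_proj(h)  # logits for this chunk only: [chunk_tokens, vocab_size]
        total = total + F.cross_entropy(logits.float(), t, reduction="sum")

    # Assumed normalization: average over the tokens that were kept.
    return total / targets.numel()


proj = torch.nn.Linear(16, 32, bias=False)  # stands in for the model's final projection
hidden = torch.randn(2, 10, 16)             # [bsz, seq_len, emb_dim]
targets = torch.randint(0, 32, (2, 10))     # [bsz, seq_len]
targets[0, :3] = -100                       # mark a few tokens as ignored
loss = chunked_linear_cross_entropy(hidden, targets, proj, num_chunks=4)
```

With `num_chunks=1` the sketch reduces to a single projection followed by a standard cross-entropy call; a larger value means each loop iteration projects fewer tokens at once.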
- apply_compile_strategy(*args, **kwargs)
Applies compile only to the compute_cross_entropy function. Compiling the cross-entropy and chunking operations together results in a higher memory requirement.
- compute_cross_entropy(hidden_chunk: Tensor, target_chunk: Tensor) → Tensor
Computes cross-entropy by masking tokens, calculating logits and then applying cross-entropy loss.
- Parameters:
hidden_chunk (torch.Tensor) – [batch_size, chunk_size, embed_dim]
target_chunk (torch.Tensor) – [batch_size, chunk_size]
- Returns:
Sum of cross-entropy loss for non-ignored tokens in the chunk
- Return type:
torch.Tensor
- Raises:
AttributeError – if called before set_model_output
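Returning a per-chunk sum rather than a per-chunk mean matters when chunks contain different numbers of non-ignored tokens. The sketch below uses hand-built logits and `torch.nn.functional.cross_entropy` directly (not the torchtune method) to show that summing per chunk and normalizing once reproduces the global mean, while averaging per-chunk means does not. The tensor values are made up for illustration.

```python
import torch
import torch.nn.functional as F

targets = torch.tensor([1, -100, -100, 3, 0, 2])  # -100 marks ignored tokens
logits = torch.zeros(6, 5)                        # 6 tokens, vocabulary of 5
for i in (3, 4, 5):
    logits[i, targets[i]] = 10.0                  # confident, correct predictions

# Reference: mean over the 4 non-ignored tokens, computed in one call.
reference = F.cross_entropy(logits, targets, ignore_index=-100)

# Two chunks of 3 tokens: the first has 1 valid token, the second has 3.
chunks = list(zip(logits.chunk(2), targets.chunk(2)))

# Sum per chunk and normalize once at the end: matches the reference.
chunk_sums = [F.cross_entropy(l, t, ignore_index=-100, reduction="sum") for l, t in chunks]
from_sums = torch.stack(chunk_sums).sum() / (targets != -100).sum()

# Mean per chunk, then average the means: over-weights the 1-token chunk.
chunk_means = [F.cross_entropy(l, t, ignore_index=-100) for l, t in chunks]
from_means = torch.stack(chunk_means).mean()
```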
- forward(outputs: Tensor, targets: Tensor) → Tensor
- Parameters:
outputs (torch.Tensor) – Hidden state of the model, pre projection. Shape [bsz, seq_len, emb_dim]
targets (torch.Tensor) – Labels for the model. Shape [bsz, seq_len]
- Returns:
loss tensor
- Return type:
torch.Tensor
- mask_inputs(hidden: Tensor, target: Tensor) → tuple[torch.Tensor, torch.Tensor]
- Parameters:
hidden (torch.Tensor) – Hidden state of the model, pre projection. Shape [bsz*seq_len, emb_dim]
target (torch.Tensor) – Labels for the model. Shape [bsz*seq_len]
- Returns:
A tuple of:
- The indexed hidden states
- The indexed targets
- Return type:
tuple[torch.Tensor, torch.Tensor]
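The indexing step can be pictured as follows. This is a small sketch under assumptions, not the library's code: `mask_ignored` is a hypothetical helper, and selecting rows with `torch.nonzero` plus `index_select` is one of several equivalent ways to keep only the positions whose label differs from `ignore_index`.

```python
import torch


def mask_ignored(hidden, target, ignore_index=-100):
    """Hypothetical helper: keep only rows whose label is not `ignore_index`."""
    # Indices of the positions that contribute to the loss.
    keep_idx = torch.nonzero(target != ignore_index, as_tuple=True)[0]
    return hidden.index_select(0, keep_idx), target.index_select(0, keep_idx)


hidden = torch.arange(12, dtype=torch.float32).reshape(4, 3)  # [bsz*seq_len=4, emb_dim=3]
target = torch.tensor([5, -100, 2, -100])                     # [bsz*seq_len=4]
kept_hidden, kept_target = mask_ignored(hidden, target)
print(kept_hidden.shape, kept_target.tolist())  # torch.Size([2, 3]) [5, 2]
```

Only the two rows with real labels remain, so a projection applied afterwards would produce logits for two tokens instead of four.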
- patch_tp_plan(tp_plan) → dict
Whether the loss function supports loss parallel. Defaults to a noop.
- set_model_output(model: Module) → None
Modify model output to match the expected input for the loss function.
- property tp_requires_loss_parallel_ctx_manager: bool
Whether to use the loss parallel context manager for loss parallelism. See https://docs.pytorch.org/docs/stable/distributed.tensor.parallel.html#torch.distributed.tensor.parallel.loss_parallel