padded_collate_sft
- torchtune.data.padded_collate_sft(batch: list[dict[str, Any]], padding_idx: int = 0, ignore_idx: int = -100, pad_to_multiple_of: int = 1, stack_on_new_dim: bool = False) → dict[str, torch.Tensor]
Pad a batch of sequences to the longest sequence length in the batch, and convert integer lists to tensors.
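To make the padding rule concrete, here is a minimal, dependency-free sketch that uses plain Python lists instead of torch tensors, so it omits the tensor conversion. The helper name `pad_batch_sketch` is hypothetical and is not part of torchtune; its `padding_idx` and `ignore_idx` parameters mirror the documented arguments. The sketch also assumes that tokens and labels should be padded to one shared length.

```python
from typing import Any


def pad_batch_sketch(
    batch: list[dict[str, Any]],
    padding_idx: int = 0,
    ignore_idx: int = -100,
) -> dict[str, list[list[int]]]:
    """Hypothetical helper: right-pad tokens and labels to a shared length."""
    # Target length: the longest token or label list anywhere in the batch.
    max_len = max(
        max(len(sample["tokens"]), len(sample["labels"])) for sample in batch
    )
    # Tokens are filled with padding_idx, labels with ignore_idx.
    tokens = [
        sample["tokens"] + [padding_idx] * (max_len - len(sample["tokens"]))
        for sample in batch
    ]
    labels = [
        sample["labels"] + [ignore_idx] * (max_len - len(sample["labels"]))
        for sample in batch
    ]
    return {"tokens": tokens, "labels": labels}


batch = [
    {"tokens": [1, 2, 3], "labels": [4, 5, 6]},
    {"tokens": [7], "labels": [10]},
]
print(pad_batch_sketch(batch))
# {'tokens': [[1, 2, 3], [7, 0, 0]], 'labels': [[4, 5, 6], [10, -100, -100]]}
```

Labels are padded with a separate value because -100 is the default `ignore_index` of PyTorch's cross-entropy loss, so padded positions do not contribute to the loss.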
- Parameters:
batch (list[dict[str, Any]]) – A list of sample dictionaries, each containing at least "tokens" and "labels" lists of integer ids.
padding_idx (int) – Padding index for input ids. Defaults to 0.
ignore_idx (int) – Padding index for labels. Defaults to -100.
pad_to_multiple_of (int) – If > 1, pad the sequence length to a multiple of this number. This is useful for even sharding with, e.g., SequenceParallel. Defaults to 1.
stack_on_new_dim (bool) – If True, stack any encoder tensors on a new dimension. Defaults to False.
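The length arithmetic behind `pad_to_multiple_of` can be sketched as rounding the padded length up to a multiple, so the sequence dimension splits evenly into `pad_to_multiple_of`-sized chunks. This is a minimal sketch under the assumption that the target is the smallest multiple of `pad_to_multiple_of` that is at least the longest sequence length. The helper `padded_length_sketch` is hypothetical, and torchtune's exact behavior when the length is already a multiple may differ, so check the library source if that edge case matters.

```python
def padded_length_sketch(longest: int, pad_to_multiple_of: int = 1) -> int:
    """Hypothetical helper: sequence length after padding to a multiple.

    Assumes the result is the smallest multiple of pad_to_multiple_of
    that is >= longest; the default of 1 leaves the length unchanged.
    """
    if pad_to_multiple_of < 1:
        raise ValueError("pad_to_multiple_of must be >= 1")
    # Ceiling division, then scale back up to a whole multiple.
    return -(-longest // pad_to_multiple_of) * pad_to_multiple_of


print(padded_length_sketch(3))      # 3 (no extra padding)
print(padded_length_sketch(3, 8))   # 8
print(padded_length_sketch(10, 8))  # 16
```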
- Returns:
A dictionary of collated "tokens" and "labels" tensors.
- Return type:
dict[str, torch.Tensor]
Example
>>> from torchtune.data import padded_collate_sft
>>> token_pairs = [
...     {"tokens": [1, 2, 3], "labels": [4, 5, 6]},
...     {"tokens": [7,], "labels": [10,]},
... ]
>>> collated = padded_collate_sft(
...     batch=token_pairs,
...     padding_idx=0,
...     ignore_idx=-100,
... )
>>> collated["tokens"]
tensor([[1, 2, 3],
        [7, 0, 0]])
>>> collated["labels"]
tensor([[   4,    5,    6],
        [  10, -100, -100]])