padded_collate_packed
- torchtune.data.padded_collate_packed(batch: list[dict[str, Union[torch.Tensor, list[int]]]]) → dict[str, torch.Tensor]
Collate packed sequences into a batch. Only the sequence lengths are converted into a block mask for use with attention; tokens, labels, and input_pos are already padded to the same length within PackedDataset.
- Parameters:
batch (list[PACK_TYPE]) – A list of pack dictionaries containing the following keys:
  - tokens: input token ids
  - labels: label token ids
  - input_pos: relative position ids for each sequence in the pack
  - seq_lens: lengths of each sample within the pack
- Returns:
Collated tokens, labels, input_pos, and mask tensors.
- Return type:
dict[str, torch.Tensor]
Example
>>> token_pairs = [
...     {"tokens": [1, 2, 3, 4, 5, 6], "labels": [7, 8, 9, 10, 11, 12],
...      "input_pos": [0, 1, 2, 0, 1, 0], "seq_lens": [3, 2, 1]},
...     {"tokens": [13, 14, 15, 16, 17, 18], "labels": [19, 20, 21, 22, 23, 24],
...      "input_pos": [0, 1, 0, 1, 0, 1], "seq_lens": [2, 2, 2]},
... ]
>>> collated = padded_collate_packed(batch=token_pairs)
>>> collated["mask"]
tensor([[[1, 0, 0, 0, 0, 0],
         [1, 1, 0, 0, 0, 0],
         [1, 1, 1, 0, 0, 0],
         [0, 0, 0, 1, 0, 0],
         [0, 0, 0, 1, 1, 0],
         [0, 0, 0, 0, 0, 1]],
        [[1, 0, 0, 0, 0, 0],
         [1, 1, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0],
         [0, 0, 1, 1, 0, 0],
         [0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 1, 1]]])
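For intuition, here is a minimal sketch of how such a block-diagonal causal mask can be built from a pack's seq_lens. The helper name block_causal_mask is hypothetical and not part of torchtune; it simply reproduces the masking pattern shown above.

import torch

def block_causal_mask(seq_lens: list[int]) -> torch.Tensor:
    # Hypothetical helper: block-diagonal causal mask for one pack.
    total = sum(seq_lens)
    mask = torch.zeros(total, total, dtype=torch.int)
    start = 0
    for n in seq_lens:
        # Each sample attends causally within its own block only,
        # so packed samples cannot attend to one another.
        mask[start:start + n, start:start + n] = torch.tril(
            torch.ones(n, n, dtype=torch.int)
        )
        start += n
    return mask

# Reproduces the first mask in the example above:
print(block_causal_mask([3, 2, 1]))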