
padded_collate_packed

torchtune.data.padded_collate_packed(batch: list[dict[str, Union[torch.Tensor, list[int]]]]) → dict[str, torch.Tensor]

Collate packed sequences into a batch. Only the sequence lengths (seq_lens) are converted into a block causal mask for use with attention; tokens, labels, and input_pos are already padded to the same length within PackedDataset.
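To make the mask semantics concrete, here is a minimal sketch (not torchtune's implementation; block_causal_mask is a hypothetical helper) of how one pack's seq_lens map to a block causal mask, where each position attends causally within its own sample and never across sample boundaries:

>>> import torch
>>>
>>> def block_causal_mask(seq_lens: list[int]) -> torch.Tensor:
...     total = sum(seq_lens)
...     mask = torch.zeros(total, total, dtype=torch.bool)
...     start = 0
...     for n in seq_lens:
...         # lower-triangular (causal) block for one sample in the pack
...         mask[start:start + n, start:start + n] = torch.tril(
...             torch.ones(n, n, dtype=torch.bool)
...         )
...         start += n
...     return mask
>>> block_causal_mask([3, 2, 1]).int()  # matches the first mask in the example below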

Parameters:

batch (list[PACK_TYPE]) – A list of pack dictionaries containing the following keys:

- tokens: input token ids
- labels: label token ids
- input_pos: relative position ids for each sequence in the pack
- seq_lens: lengths of each sample within the pack

Returns:

Collated tokens, labels, input_pos, and mask tensors.

Return type:

dict[str, torch.Tensor]

Example

>>> token_pairs = [
...     {"tokens": [1, 2, 3, 4, 5, 6], "labels": [7, 8, 9, 10, 11, 12],
...      "input_pos": [0, 1, 2, 0, 1, 0], "seq_lens": [3, 2, 1]},
...     {"tokens": [13, 14, 15, 16, 17, 18], "labels": [19, 20, 21, 22, 23, 24],
...      "input_pos": [0, 1, 0, 1, 0, 1], "seq_lens": [2, 2, 2]},
... ]
>>> collated = padded_collate_packed(batch=token_pairs)
>>> collated["mask"]
tensor([[[1, 0, 0, 0, 0, 0],
         [1, 1, 0, 0, 0, 0],
         [1, 1, 1, 0, 0, 0],
         [0, 0, 0, 1, 0, 0],
         [0, 0, 0, 1, 1, 0],
         [0, 0, 0, 0, 0, 1]],
        [[1, 0, 0, 0, 0, 0],
         [1, 1, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0],
         [0, 0, 1, 1, 0, 0],
         [0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 1, 1]]])
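In training code, this function is typically passed as the collate_fn of a DataLoader wrapping a PackedDataset (packed_ds below is assumed to be such a dataset):

>>> from torch.utils.data import DataLoader
>>> dataloader = DataLoader(
...     packed_ds,  # assumed: a torchtune PackedDataset instance
...     batch_size=2,
...     collate_fn=padded_collate_packed,
... )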
