padded_collate
- torchtune.data.padded_collate(batch: list[dict[str, list[int]]], *, pad_direction: str, keys_to_pad: list[str], padding_idx: Union[int, dict[str, int]], pad_to_multiple_of: int = 1, stack_on_new_dim: bool = False)
A generic padding collation function which pads keys_to_pad entries in a batch of sequences from the given pad_direction to the maximum sequence length for each entry in the batch.
Note
This function assumes all batch elements which are not in keys_to_pad do not require any collation (see example below).
- Parameters:
batch (list[dict[str, list[int]]]) – A list of dictionaries containing inputs.
pad_direction (str) – Whether to pad entries from the left or the right. If pad_direction="right", we use torch.nn.utils.rnn.pad_sequence(); otherwise, if pad_direction="left", we use torchtune.data.left_pad_sequence().
keys_to_pad (list[str]) – Batch element keys to apply padding to. Should be a subset of keys in the batch.
padding_idx (Union[int, dict[str, int]]) – Either a single integer padding value to apply to all keys_to_pad elements, or a mapping with keys identical to keys_to_pad with per-key padding values.
pad_to_multiple_of (int) – If > 1, pad the sequence to a multiple of this number.
stack_on_new_dim (bool) – If True, stack any encoder tensors on a new dimension. Default is False.
- Returns:
The padded tensor of input ids with shape [batch_size, max_seq_len].
- Return type:
- Raises:
ValueError – If pad_direction is not one of "left" or "right", or if keys_to_pad is empty or is not a list, or if keys_to_pad is not a subset of keys in the batch, or if padding_idx is provided as a dictionary but its keys are not identical to keys_to_pad, or if pad_direction is "left" and pad_to_multiple_of is > 1.
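The error conditions listed above can be read as a sequence of argument checks. The following is a minimal sketch of those checks, not torchtune's actual implementation: the helper name check_padded_collate_args and the error messages are hypothetical, and it assumes every batch element has the same keys as the first one.

```python
from typing import Union


def check_padded_collate_args(
    batch: list[dict],
    pad_direction: str,
    keys_to_pad: list[str],
    padding_idx: Union[int, dict[str, int]],
    pad_to_multiple_of: int = 1,
) -> None:
    """Hypothetical helper mirroring the documented ValueError conditions."""
    if pad_direction not in ("left", "right"):
        raise ValueError(f"pad_direction must be 'left' or 'right', got {pad_direction!r}")
    if not isinstance(keys_to_pad, list) or not keys_to_pad:
        raise ValueError("keys_to_pad must be a non-empty list")
    # Assumption: all batch elements share the keys of the first element.
    batch_keys = set(batch[0].keys())
    if not set(keys_to_pad) <= batch_keys:
        raise ValueError("keys_to_pad must be a subset of the keys in the batch")
    # A dict padding_idx must name exactly the keys being padded.
    if isinstance(padding_idx, dict) and set(padding_idx) != set(keys_to_pad):
        raise ValueError("padding_idx keys must be identical to keys_to_pad")
    if pad_direction == "left" and pad_to_multiple_of > 1:
        raise ValueError("pad_to_multiple_of > 1 cannot be combined with left padding")
```

The order of the checks here is an illustrative choice; the documentation does not say which condition is reported first when several apply.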
Example
>>> a = [1, 2, 3]
>>> b = [4, 5, 6, 7]
>>> c = [8, 9, 10, 11, 12]
>>> batch = [
>>>     {"tokens": a, "labels": 1},
>>>     {"tokens": b, "labels": 3},
>>>     {"tokens": c, "labels": 0},
>>> ]
>>> padded_collate(
>>>     batch,
>>>     pad_direction="left",
>>>     keys_to_pad=["tokens"],
>>>     padding_idx=-10
>>> )
{
    'labels': tensor([1, 3, 0]),
    'tokens': tensor([[-10, -10,   1,   2,   3],
                      [-10,   4,   5,   6,   7],
                      [  8,   9,  10,  11,  12]])
}
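To make the padding semantics concrete without requiring torch, here is a dependency-free sketch of the behavior described in the parameter list. It is an illustration only: pad_batch_sketch is a hypothetical name, it returns nested Python lists rather than tensors, it skips argument validation, and the way it rounds the target length up for pad_to_multiple_of is an assumption based on the parameter description.

```python
from typing import Any, Union


def pad_batch_sketch(
    batch: list[dict[str, Any]],
    *,
    pad_direction: str,
    keys_to_pad: list[str],
    padding_idx: Union[int, dict[str, int]],
    pad_to_multiple_of: int = 1,
) -> dict[str, list]:
    """Hypothetical pure-Python illustration; returns nested lists, not tensors."""
    out: dict[str, list] = {}
    for key in batch[0]:
        values = [sample[key] for sample in batch]
        if key not in keys_to_pad:
            # Keys outside keys_to_pad are gathered as-is, with no padding.
            out[key] = values
            continue
        # Per-key padding value if a mapping was given, else the shared value.
        pad_value = padding_idx[key] if isinstance(padding_idx, dict) else padding_idx
        target_len = max(len(v) for v in values)
        # Assumption: round the target length up to the next multiple
        # (a no-op when pad_to_multiple_of is 1).
        target_len = -(-target_len // pad_to_multiple_of) * pad_to_multiple_of
        padded = []
        for v in values:
            fill = [pad_value] * (target_len - len(v))
            padded.append(fill + list(v) if pad_direction == "left" else list(v) + fill)
        out[key] = padded
    return out


batch = [
    {"tokens": [1, 2, 3], "mask": [1, 1, 1], "labels": 1},
    {"tokens": [4, 5, 6, 7, 8], "mask": [1, 1, 1, 1, 1], "labels": 3},
]
result = pad_batch_sketch(
    batch,
    pad_direction="right",
    keys_to_pad=["tokens", "mask"],
    padding_idx={"tokens": -10, "mask": 0},  # per-key padding values
    pad_to_multiple_of=4,
)
```

In this usage, the longest sequence has length 5, so each padded key is extended to length 8, with -10 filling "tokens" and 0 filling "mask", while "labels" is passed through unpadded.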