PromptTensorDictTokenizer

class torchrl.data.PromptTensorDictTokenizer(tokenizer, max_length, key='prompt', padding='max_length', truncation=True, return_tensordict=True, device=None)
Tokenization recipe for prompt datasets.

Returns a tokenizer function, which reads an example containing a prompt and a label and tokenizes them.

Parameters:
- tokenizer (tokenizer from transformers library) – the tokenizer to use. 
- max_length (int) – maximum length of the sequence. 
- key (str, optional) – the key where the text can be found. Defaults to "prompt".
- padding (str, optional) – type of padding. Defaults to "max_length".
- truncation (bool, optional) – whether the sequences should be truncated to max_length. 
- return_tensordict (bool, optional) – if True, a TensorDict is returned. Otherwise, the original data will be returned. Defaults to True.
- device (torch.device, optional) – the device on which the data should be stored. This option is ignored if return_tensordict=False.
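
For instance, a dataset whose text lives under a non-default key can be handled through the key argument. The snippet below is a minimal sketch; the "text" key is an illustrative assumption, not a library default:

>>> from transformers import AutoTokenizer
>>> from torchrl.data import PromptTensorDictTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> tokenizer.pad_token = tokenizer.eos_token
>>> # Assumed layout: prompts stored under "text" rather than the default "prompt"
>>> fn = PromptTensorDictTokenizer(tokenizer, max_length=128, key="text")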
 
The __call__() method of this class will execute the following operations (see the sketch after this list):

- Read the prompt string concatenated with the label string and tokenize them. The results will be stored in the "input_ids" TensorDict entry.
- Write a "prompt_rindex" entry with the index of the last valid token from the prompt.
- Write a "valid_sample" entry which identifies which samples in the tensordict have enough tokens to meet the max_length criterion.
- Return a tensordict.TensorDict instance with tokenized inputs.
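
To make these steps concrete, the snippet below sketches how the "prompt_rindex" and "valid_sample" entries could be derived for a single prompt/label pair. This is an illustrative reconstruction of the semantics described above, not the actual torchrl implementation:

>>> import torch
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> tokenizer.pad_token = tokenizer.eos_token
>>> prompt, label, max_length = "This prompt is long enough to be tokenized.", "Indeed it is.", 50
>>> # Tokenizing the prompt alone gives the index of its last valid token
>>> prompt_rindex = len(tokenizer(prompt)["input_ids"]) - 1
>>> # Tokenize the prompt concatenated with the label, padded/truncated to max_length
>>> enc = tokenizer(prompt + label, padding="max_length", truncation=True,
...                 max_length=max_length, return_tensors="pt")
>>> input_ids = enc["input_ids"]  # what would land in the "input_ids" entry
>>> # A sample is "valid" when its un-padded length meets the max_length criterion
>>> valid_sample = enc["attention_mask"].sum() == max_length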
The tensordict batch-size will match the batch-size of the input.

Examples

>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> tokenizer.pad_token = tokenizer.eos_token
>>> example = {
...     "prompt": ["This prompt is long enough to be tokenized.", "this one too!"],
...     "label": ["Indeed it is.", 'It might as well be.'],
... }
>>> fn = PromptTensorDictTokenizer(tokenizer, 50)
>>> print(fn(example))
TensorDict(
    fields={
        attention_mask: Tensor(shape=torch.Size([2, 50]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids: Tensor(shape=torch.Size([2, 50]), device=cpu, dtype=torch.int64, is_shared=False),
        prompt_rindex: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
        valid_sample: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
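
Since the result is a regular tensordict.TensorDict, the boolean "valid_sample" entry can serve as a mask to keep only usable rows. A brief sketch continuing the example above (the variable names are illustrative):

>>> import torch
>>> td = fn(example)
>>> valid_td = td[td["valid_sample"]]  # boolean-mask indexing over the batch dimension
>>> # prompt_rindex locates the last prompt token in each kept row
>>> rows = torch.arange(valid_td.shape[0])
>>> last_prompt_tokens = valid_td["input_ids"][rows, valid_td["prompt_rindex"]]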