generate

torchtune.generation.generate(model: TransformerDecoder, prompt: Tensor, *, max_generated_tokens: int, pad_id: int = 0, temperature: float = 1.0, top_k: Optional[int] = None, stop_tokens: Optional[list[int]] = None, rng: Optional[Generator] = None, compiled_generate_next_token: Optional[Callable] = None) → tuple[torch.Tensor, torch.Tensor]

Generates tokens from a model conditioned on a prompt, and also returns logits for the generations.

Parameters:
  • model (TransformerDecoder) – model used for generation

  • prompt (torch.Tensor) – tensor with the token IDs associated with the given prompt, with shape either [seq_length] or [bsz x seq_length].

  • max_generated_tokens (int) – number of tokens to be generated

  • pad_id (int) – token ID to use for padding, default 0.

  • temperature (float) – value to scale the predicted logits by, default 1.0.

  • top_k (Optional[int]) – If specified, we prune the sampling to only token ids within the top_k probabilities, default None.

  • stop_tokens (Optional[list[int]]) – If specified, generation is stopped when any of these tokens are generated, default None.

  • rng (Optional[torch.Generator]) – random number generator, default None.

  • compiled_generate_next_token (Optional[Callable]) – Typically a reference to a compiled version of the generate_next_token() function. During autoregressive decoding, this function is called instead of the default generate_next_token() to accelerate generation. The uncompiled generate_next_token() is still used for the first token generation, i.e. the "prefill" pass. Default is None. See the sketch after this parameter list for one way to construct and pass it.

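One option for the compiled_generate_next_token path is to pass a torch.compile-d version of generate_next_token(). The snippet below is a minimal sketch, not a prescribed recipe: it assumes a model and a prompt tensor (named prompt_tokens here for illustration) prepared as in the Examples section further down, and the torch.compile options shown are illustrative rather than required.

>>> import torch
>>> from torchtune.generation import generate, generate_next_token
>>> # Compile the per-step decoding function once and reuse it across steps;
>>> # the uncompiled generate_next_token() still handles the prefill pass.
>>> compiled_next_token = torch.compile(generate_next_token, mode="reduce-overhead")
>>> output, logits = generate(
...     model,
...     prompt_tokens,  # a [seq_length] or [bsz x seq_length] tensor of token IDs
...     max_generated_tokens=100,
...     pad_id=0,
...     compiled_generate_next_token=compiled_next_token,
... )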
Note

This function has only been tested with decoder-only models.

Examples

>>> import torch
>>> from torchtune.models.llama3 import llama3_tokenizer
>>> from torchtune.models.llama3 import llama3_8b
>>> from torchtune.generation import generate
>>> from torchtune.training.checkpointing import FullModelHFCheckpointer
>>> from torchtune.data import Message
>>> model = llama3_8b().cuda()
>>> checkpointer = FullModelHFCheckpointer(
...     checkpoint_dir="/tmp/Meta-Llama-3-8B-Instruct",
...     checkpoint_files=[
...         "model-00001-of-00004.safetensors",
...         "model-00002-of-00004.safetensors",
...         "model-00003-of-00004.safetensors",
...         "model-00004-of-00004.safetensors",
...     ],
...     model_type="LLAMA3",
...     output_dir="/tmp/torchtune/llama3_8b",
... )
>>> checkpoint = checkpointer.load_checkpoint()
>>> model.load_state_dict(checkpoint["model"])
>>> tokenizer = llama3_tokenizer("/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model")
>>> messages = [
...     Message(role="assistant", content="Hi my name is"),
... ]
>>> prompt = tokenizer({"messages": messages}, inference=True)
>>> output, logits = generate(model, torch.tensor(prompt["tokens"], device='cuda'), max_generated_tokens=100, pad_id=0)
>>> print(tokenizer.decode(output[0].tolist()))
Hi my name is Marley. Nice to meet you, Marley! How are you doing today?... [truncated]
Returns:

tuple of two tensors:
  • tokens (torch.Tensor): tensor with the generated tokens, with shape [bsz x seq_len + num_generated_tokens], where num_generated_tokens may be less than max_generated_tokens if stop_tokens are provided.

  • logits (torch.Tensor): tensor with the logits associated with the generated tokens, with shape [bsz x num_generated_tokens x vocab_size].

Return type:

tuple[torch.Tensor, torch.Tensor]
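
As a minimal sketch of how the sampling knobs and stop_tokens fit together, and how to separate newly generated tokens from the echoed prompt: this assumes the model, tokenizer, and prompt from the Examples above, and that the tokenizer exposes a stop_tokens list of stop token IDs (an assumption here; substitute your own IDs otherwise).

>>> import torch
>>> from torchtune.generation import generate
>>> prompt_tokens = torch.tensor(prompt["tokens"], device="cuda")
>>> tokens, logits = generate(
...     model,
...     prompt_tokens,
...     max_generated_tokens=100,
...     pad_id=0,
...     temperature=0.8,  # scales the logits before sampling
...     top_k=50,         # restrict sampling to the 50 most likely tokens
...     stop_tokens=tokenizer.stop_tokens,  # assumed attribute; halts generation early
... )
>>> # The returned tokens include the prompt; slice it off to keep only new tokens.
>>> new_tokens = tokens[:, prompt_tokens.shape[-1]:]
>>> print(tokenizer.decode(new_tokens[0].tolist()))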
