Llama4VisionProjectionHead

class torchtune.models.llama4.Llama4VisionProjectionHead(output: Module, pixel_shuffle_scaling_factor: float = 0.5)[source]

Projection transformer to adapt the output of a pretrained frozen encoder (CLIP) to a pretrained decoder model. For example, nn.Sequential(CLIP(), Llama4VisionProjectionHead()).

Note: this module assumes the CLS token embedding is added at the end of the sequence.

Parameters:
  • output (nn.Module) – output layer, typically an MLP.

  • pixel_shuffle_scaling_factor (float) – scaling factor for pixel shuffle; a factor below 1 shortens the sequence and widens the embed dim correspondingly. Default: 0.5
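
The following is a minimal usage sketch. All dimensions are hypothetical (a 1408-dim CLIP-style encoder feeding a 4096-dim decoder); only the output and pixel_shuffle_scaling_factor arguments come from the signature above, and the 4x input width of the MLP assumes the default scaling factor of 0.5 widens the embed dim by (1 / 0.5) ** 2 = 4.

    import torch
    from torch import nn
    from torchtune.models.llama4 import Llama4VisionProjectionHead

    clip_embed_dim = 1408     # hypothetical encoder embed dim (d)
    decoder_embed_dim = 4096  # hypothetical decoder embed dim

    # Assumption: with the default scaling factor of 0.5, pixel shuffle
    # widens the embed dim by (1 / 0.5) ** 2 = 4, so the output MLP
    # consumes 4 * clip_embed_dim features.
    head = Llama4VisionProjectionHead(
        output=nn.Linear(4 * clip_embed_dim, decoder_embed_dim),
        pixel_shuffle_scaling_factor=0.5,
    )

    # One tile of 577 embeds: 576 patch embeds (24 x 24 grid) plus the
    # CLS embed, which this module expects at the end of the sequence.
    x = torch.randn(2, 577, clip_embed_dim)  # [b, e, d]
    out = head(x)                            # [b, s, decoder_embed_dim]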

forward(x: Tensor) → Tensor[source]
Parameters:
  • x (torch.Tensor) – input tensor with shape [b, e, d]

Returns:
  output tensor of a sequence of embeddings with shape [b, s, d * pixel_shuffle_factor ** 2]

Return type:
  Tensor

Notation used for tensor shapes:
  • b: batch size

  • t: number of tiles

  • e: number of embeds per tile (e.g. CLS embed + patch embeds)

  • s: sequence length, computed as t * (e - 1) // (pixel_shuffle_factor ** 2)

  • d: embed dim
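
The shape arithmetic is easiest to see concretely. Below is an illustrative pixel shuffle (not torchtune's implementation) over a single tile whose CLS embed has already been dropped; the 24 x 24 patch grid and embed dim 1408 are assumptions. It also reflects one reading of the formulas above, in which pixel_shuffle_factor is the inverse of pixel_shuffle_scaling_factor (i.e. 2 for the default 0.5), so that s = 576 // 4 = 144 and the embed dim grows to d * 4.

    import torch

    def pixel_shuffle_sketch(x: torch.Tensor, scale: float = 0.5) -> torch.Tensor:
        # x: [b, p, d], where the p patch embeds form a square grid.
        b, p, d = x.shape
        side = int(p ** 0.5)
        x = x.reshape(b, side, side, d)
        # Fold 1 / scale adjacent columns into the channel dim ...
        x = x.reshape(b, side, int(side * scale), int(d / scale))
        x = x.permute(0, 2, 1, 3)
        # ... then do the same for rows.
        x = x.reshape(b, int(side * scale), int(side * scale), int(d / scale ** 2))
        x = x.permute(0, 2, 1, 3)
        # [b, p * scale ** 2, d / scale ** 2]
        return x.reshape(b, int(p * scale ** 2), int(d / scale ** 2))

    x = torch.randn(2, 576, 1408)         # 24 x 24 grid of patch embeds, CLS dropped
    print(pixel_shuffle_sketch(x).shape)  # torch.Size([2, 144, 5632])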
