Llama4VisionProjectionHead
- class torchtune.models.llama4.Llama4VisionProjectionHead(output: Module, pixel_shuffle_scaling_factor: float = 0.5)
Projection transformer to adapt the output of a pretrained frozen encoder (CLIP) to a pretrained decoder model, e.g. nn.Sequential(CLIP(), Llama4VisionProjectionHead()).
Note: this module assumes the CLS token embedding is added at the end of the sequence.
- Parameters:
output (nn.Module) – output layer, typically an MLP.
pixel_shuffle_scaling_factor (float) – scaling factor for pixel shuffle.
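A minimal construction sketch (hedged: the dims and the MLP below are illustrative assumptions, not the values used by torchtune's Llama4 builders):

```python
import torch
from torch import nn
from torchtune.models.llama4 import Llama4VisionProjectionHead

# Assumed dims for illustration only.
encoder_dim, decoder_dim = 1408, 4096
scale = 0.5  # default pixel_shuffle_scaling_factor

# Pixel shuffle with factor 0.5 grows the embed dim by (1 / 0.5) ** 2 = 4,
# so the output layer must accept encoder_dim * 4 input features.
mlp = nn.Sequential(
    nn.Linear(int(encoder_dim / scale**2), decoder_dim),
    nn.GELU(),
    nn.Linear(decoder_dim, decoder_dim),
)

head = Llama4VisionProjectionHead(output=mlp, pixel_shuffle_scaling_factor=scale)
```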
- forward(x: Tensor) → Tensor
- Parameters:
x (torch.Tensor) – input tensor with shape [b, e, d]
- Returns:
output tensor of a sequence of embeddings with shape [b, s, d / pixel_shuffle_scaling_factor ** 2]
- Return type:
Tensor
- Notation used for tensor shapes:
b: batch size
e: number of embeds per tile (e.g. CLS embed + patch embeds)
s: sequence length after pixel shuffling, computed as (e - 1) * pixel_shuffle_scaling_factor ** 2
d: embed dim
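To make the shape bookkeeping concrete, here is a hedged sketch that checks the documented shapes with an identity output layer, so only the pixel shuffle changes the shape (the 577-embed input is an assumed 24 x 24 patch grid plus one trailing CLS embed):

```python
import torch
from torch import nn
from torchtune.models.llama4 import Llama4VisionProjectionHead

b, e, d = 2, 577, 1408  # 24 * 24 patch embeds + 1 CLS embed at the end
head = Llama4VisionProjectionHead(
    output=nn.Identity(), pixel_shuffle_scaling_factor=0.5
)

x = torch.randn(b, e, d)
y = head(x)

# Per the notation above: s = (e - 1) * 0.5 ** 2 = 144 and d / 0.5 ** 2 = 5632.
print(y.shape)  # expected torch.Size([2, 144, 5632])
```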