Llama4VisionEncoder
- class torchtune.models.llama4.Llama4VisionEncoder(clip: Module, projection_head: Module)[source]
Vision encoder model for Llama 4. This combines a pretrained vision encoder with a learnable projection head. The projection head is converted to a fusion module and supports fusion utils.
- Parameters:
clip (nn.Module) – CLIP encoder vision model
projection_head (nn.Module) – projection head that takes embeddings with dimension encoder_dim as input and outputs embeddings of size decoder_dim. See torchtune.models.llama4.llama4_vision_projection_head() as an example.
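The contract between the two modules is easier to see in code. Below is a minimal sketch using toy stand-in modules: ToyClip, ToyProjectionHead, and all dimensions are illustrative assumptions, and it presumes the encoder simply composes clip and projection_head as the description suggests. In practice you would pass a pretrained CLIP vision model and a head such as torchtune.models.llama4.llama4_vision_projection_head().

```python
import torch
from torch import nn
from torchtune.models.llama4 import Llama4VisionEncoder

# Illustrative sizes only; not the real Llama 4 configuration.
encoder_dim, decoder_dim = 1408, 4096

class ToyClip(nn.Module):
    """Stand-in for a CLIP vision model: [b, c, w, h] -> [b, s, encoder_dim]."""
    def __init__(self, patch: int = 16, dim: int = encoder_dim):
        super().__init__()
        self.proj = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        x = self.proj(images)                # [b, dim, w/patch, h/patch]
        return x.flatten(2).transpose(1, 2)  # [b, s, dim]

class ToyProjectionHead(nn.Module):
    """Stand-in head mapping encoder_dim embeddings to decoder_dim embeddings."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(encoder_dim, decoder_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

encoder = Llama4VisionEncoder(clip=ToyClip(), projection_head=ToyProjectionHead())
```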
- forward(images: Tensor) → Tensor [source]
- Parameters:
images (torch.Tensor) – Image tensor with shape [b x c x w x h]
- Returns:
- output tensor of a sequence of embeddings [b x s x d], where sequence length (s) is (num_imgs * num_tiles) + num_embeds
- Return type:
Tensor
- Notation used for tensor shapes:
  - b: batch size, equal to flatten(batch x images x tiles)
  - c: number of image channels (e.g. rgb = 3)
  - w: image width
  - h: image height
  - s: sequence length, computed as i * t * clip_embeds_per_tile (i = number of images, t = number of tiles)
  - d: embed dim
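Continuing the sketch above, a forward call maps an image batch to a sequence of decoder-dimension embeddings. The expected output shape below assumes the toy modules and illustrative sizes from that sketch, and that the encoder simply composes clip and the projection head.

```python
images = torch.randn(2, 3, 448, 448)  # [b x c x w x h]
out = encoder(images)                 # [b x s x d]
# Expected: torch.Size([2, 784, 4096]); for the toy clip, s = (448 / 16) ** 2 = 784.
print(out.shape)
```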