TinyVLA
- class torchrl.modules.vla.TinyVLA(*args, **kwargs)
A tiny, dependency-free reference VLA policy for CI and tutorials.
TinyVLA fuses a small convolutional image encoder, an optional proprioceptive-state MLP, and a hashed language-instruction embedding into a trunk that feeds either a continuous action-chunk head or a discrete action-token head (see VLAWrapperBase). It is intentionally small and CPU-friendly: a stand-in for exercising the VLA data pipeline, losses, and collectors end-to-end, not a competitive policy.
The language instruction is embedded by hashing the instruction string to an embedding-table index (a deterministic, tokenizer-free stand-in), so the policy is genuinely language-conditioned without any external dependency.
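The hashed instruction embedding described above can be sketched roughly as follows. This is a minimal illustration, not TinyVLA's actual implementation: the class name `HashedInstructionEmbedding`, its `index` helper, and the choice of `zlib.crc32` as the hash are all assumptions made for the example. Only the `text_vocab` and `text_dim` names come from the keyword arguments documented on this page.

```python
import zlib

import torch
from torch import nn


class HashedInstructionEmbedding(nn.Module):
    """Hypothetical stand-in: embeds whole instruction strings via hashing."""

    def __init__(self, text_vocab: int = 256, text_dim: int = 32):
        super().__init__()
        self.text_vocab = text_vocab
        self.table = nn.Embedding(text_vocab, text_dim)

    def index(self, instruction: str) -> int:
        # Assumption: CRC32 as the hash. It is stable across processes,
        # unlike the built-in hash(), which is salted per run for strings.
        return zlib.crc32(instruction.encode("utf-8")) % self.text_vocab

    def forward(self, instructions: list[str]) -> torch.Tensor:
        idx = torch.tensor(
            [self.index(s) for s in instructions],
            dtype=torch.long,
            device=self.table.weight.device,
        )
        return self.table(idx)  # shape [B, text_dim]


embed = HashedInstructionEmbedding(text_vocab=256, text_dim=32)
features = embed(["pick", "place", "pick"])  # identical strings share a row
```

With any scheme of this kind, two different instructions can hash to the same table slot and therefore share an embedding; a larger `text_vocab` makes such collisions less likely.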
Note
TinyVLA expects observations with a single leading batch dimension (image shaped [B, C, H, W]). When training on chunked windows, flatten the time dimension into the batch first.
- Keyword Arguments:
action_dim (int) – the dimensionality of a single action.
chunk_size (int) – the action-chunk horizon H.
action_head (str) – "continuous" (default) or "tokens".
vocab_size (int) – action-token bins per dimension (token head). Defaults to 256.
use_state (bool) – whether to read the proprioceptive state. Defaults to True.
hidden_dim (int) – width of the fused trunk. Defaults to 128.
text_vocab (int) – size of the hashed instruction embedding table. Defaults to 256.
text_dim (int) – instruction-embedding dimension. Defaults to 32.
mode (str) – "greedy" or "sample" (token head). Defaults to "greedy".
device (DEVICE_TYPING, optional) – device to move the parameters to.
Examples
>>> import torch
>>> from tensordict import NonTensorStack, TensorDict
>>> from torchrl.modules.vla import TinyVLA
>>> policy = TinyVLA(action_dim=7, chunk_size=4)
>>> td = TensorDict(
...     {
...         "observation": {
...             "image": torch.zeros(2, 3, 16, 16, dtype=torch.uint8),
...             "state": torch.zeros(2, 5),
...         },
...         "language_instruction": NonTensorStack("pick", "place"),
...     },
...     batch_size=[2],
... )
>>> policy(td)["action_chunk"].shape
torch.Size([2, 4, 7])
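The note above asks for the time dimension to be flattened into the batch when training on chunked windows. A minimal sketch of that reshaping with plain tensors is shown below. It assumes the windows are laid out as [B, T, ...], and the sizes, variable names, and the placeholder output are illustrative rather than taken from TorchRL.

```python
import torch

B, T = 2, 5  # assumed layout: B windows, each T steps long
images = torch.zeros(B, T, 3, 16, 16, dtype=torch.uint8)
states = torch.zeros(B, T, 5)

# Merge batch and time so that every step becomes its own batch entry.
flat_images = images.flatten(0, 1)  # [B * T, 3, 16, 16]
flat_states = states.flatten(0, 1)  # [B * T, 5]

# One instruction per window must be repeated once per step so that
# entry b * T + t still carries the instruction of window b.
instructions = ["pick", "place"]
flat_instructions = [s for s in instructions for _ in range(T)]

# A model output of shape [B * T, ...] can be split back into windows.
dummy_output = torch.zeros(B * T, 4, 7)  # placeholder for an action chunk
per_step = dummy_output.unflatten(0, (B, T))  # [B, T, 4, 7]
```

The flattened tensors match the single leading batch dimension that the note requires, and `unflatten` restores the window structure afterwards if a per-step loss needs it.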