VLAWrapperBase
- class torchrl.modules.vla.VLAWrapperBase(*args, **kwargs)
Base class for Vision-Language-Action policies.
A VLA policy maps multimodal robot observations (one or more camera images, optional proprioceptive state, and a natural-language instruction) to a short action chunk. This base owns the TensorDict key contract and the forward()/get_dist() orchestration; concrete policies only implement the prediction hooks _predict_chunk() (continuous head) and _predict_logits() (discrete-token head).
Two action heads are supported via action_head:
- "continuous": forward() writes a continuous action chunk of shape [*B, chunk_size, action_dim] under action_chunk;
- "tokens": forward() writes discrete action tokens of shape [*B, chunk_size, action_dim] under action_tokens and their per-sample log-probabilities under log_probs (sequence-level, i.e. summed over both the chunk_size and action_dim token dimensions); get_dist() returns the token distribution for log-prob/entropy-based RL fine-tuning.
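The sequence-level log-probability written under log_probs can be sketched in plain PyTorch. The sketch assumes the prediction hook yields logits of shape [*B, chunk_size, action_dim, vocab_size]; the helper tokens_from_logits is hypothetical and not part of TorchRL, and it only mirrors the "greedy"/"sample" choice described for the mode keyword:

```python
import torch
from torch.distributions import Categorical


def tokens_from_logits(logits: torch.Tensor, mode: str = "greedy"):
    """Hypothetical helper: choose action tokens and their sequence-level log-prob.

    logits: [*B, chunk_size, action_dim, vocab_size] (assumed layout).
    Returns tokens [*B, chunk_size, action_dim] and log_probs [*B].
    """
    if mode == "greedy":
        tokens = logits.argmax(dim=-1)  # most likely bin per token
    elif mode == "sample":
        tokens = Categorical(logits=logits).sample()  # one draw per token
    else:
        raise ValueError(f"unknown mode {mode!r}")
    # Per-token log-probabilities, gathered at the chosen token ids.
    token_logp = torch.log_softmax(logits, dim=-1).gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
    # Sum over (chunk_size, action_dim): one log-probability per sample.
    log_probs = token_logp.sum(dim=(-2, -1))
    return tokens, log_probs


batch, chunk_size, action_dim, vocab_size = 4, 8, 7, 256
logits = torch.randn(batch, chunk_size, action_dim, vocab_size)
tokens, log_probs = tokens_from_logits(logits, mode="greedy")
print(tokens.shape, log_probs.shape)  # torch.Size([4, 8, 7]) torch.Size([4])
```

Summing log-probabilities is the same as multiplying the token probabilities, so log_probs scores the whole chunk as a single action.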
Keys are configurable through set_keys(). The wrapper is a TensorDictModuleBase, so it composes with the standard collectors, losses and transforms.
- Keyword Arguments:
  - action_dim (int) – the dimensionality of a single action.
  - chunk_size (int) – the action-chunk horizon H.
  - action_head (str) – "continuous" (default) or "tokens".
  - vocab_size (int, optional) – number of action-token bins per dimension (required for the "tokens" head).
  - use_state (bool) – whether to read the proprioceptive state. Defaults to True.
  - mode (str) – "greedy" (default, argmax) or "sample" token sampling for the "tokens" head (ignored by the continuous head).
Note
This base deliberately does not inherit from the text-generation LLMWrapperBase: a VLA policy emits robot actions, not text, so it carries only the small multimodal-to-action contract.
See also
TinyVLA (reference policy).
- forward(tensordict: TensorDictBase) → TensorDictBase
Run the policy on tensordict and write its action outputs: action_chunk for the "continuous" head, or action_tokens and log_probs for the "tokens" head. Concrete policies implement the _predict_chunk()/_predict_logits() hooks rather than overriding this method.
Note
Call the module instance (policy(tensordict)) rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() directly silently skips them.
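The class docstring describes a template-method split: the base's forward() dispatches on action_head and writes the outputs, while subclasses only supply prediction hooks. A minimal sketch of that split follows, using a plain dict in place of a TensorDict. ToyVLABase, ZeroPolicy and the hook signatures are hypothetical and are not TorchRL's implementation:

```python
import torch
from torch.distributions import Categorical, Independent


class ToyVLABase:
    """Hypothetical sketch of the base/hook split; not TorchRL's implementation."""

    def __init__(self, action_dim, chunk_size, action_head="continuous", vocab_size=None):
        if action_head not in ("continuous", "tokens"):
            raise ValueError(f"unknown action_head {action_head!r}")
        if action_head == "tokens" and vocab_size is None:
            raise ValueError("vocab_size is required for the 'tokens' head")
        self.action_dim = action_dim
        self.chunk_size = chunk_size
        self.action_head = action_head
        self.vocab_size = vocab_size

    def get_dist(self, td):
        # Categorical per token; Independent makes (chunk_size, action_dim) the event.
        return Independent(Categorical(logits=self._predict_logits(td)), 2)

    def forward(self, td):
        # A plain dict stands in for a TensorDict here.
        if self.action_head == "continuous":
            td["action_chunk"] = self._predict_chunk(td)  # [B, H, A]
        else:
            dist = self.get_dist(td)
            tokens = dist.base_dist.logits.argmax(-1)  # greedy mode only in this sketch
            td["action_tokens"] = tokens  # [B, H, A]
            td["log_probs"] = dist.log_prob(tokens)  # [B]
        return td

    # Hooks: concrete policies override these (signatures are assumptions).
    def _predict_chunk(self, td):
        raise NotImplementedError

    def _predict_logits(self, td):
        raise NotImplementedError


class ZeroPolicy(ToyVLABase):
    """Hypothetical subclass: implements only the prediction hooks."""

    def _predict_chunk(self, td):
        batch = td["image"].shape[0]
        return torch.zeros(batch, self.chunk_size, self.action_dim)

    def _predict_logits(self, td):
        batch = td["image"].shape[0]
        # Uniform logits, so every token has probability 1 / vocab_size.
        return torch.zeros(batch, self.chunk_size, self.action_dim, self.vocab_size)


obs = {"image": torch.rand(2, 3, 64, 64)}
out = ZeroPolicy(action_dim=7, chunk_size=4, action_head="tokens", vocab_size=11).forward(obs)
print(out["action_tokens"].shape, out["log_probs"].shape)  # torch.Size([2, 4, 7]) torch.Size([2])
```

Routing forward() through get_dist() keeps a single definition of the token distribution, so the log-probabilities written at rollout time and those recomputed by a loss agree. The documented base also handles key remapping, use_state and the "sample" mode, which this sketch omits.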
- get_dist(tensordict: TensorDictBase) → Independent
Return the action-token distribution.
Only defined for the "tokens" action head: a Categorical over the vocabulary, wrapped in Independent over the (chunk_size, action_dim) token dimensions, so log_prob returns one sequence-level log-probability per sample. This is the contract PPO-style objectives expect: RL fine-tuning of the token head works directly with ClipPPOLoss (pass critic_network=None and entropy_bonus=False, and remap the keys via set_keys()).
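Because get_dist() folds the chunk and action dimensions into the event, log_prob yields one value per sample, so a PPO-style importance ratio is one scalar per sample rather than one per token. A plain-PyTorch sketch of that arithmetic follows; the clip value 0.2 and the random logits and advantages are illustrative, and this is not ClipPPOLoss's implementation:

```python
import torch
from torch.distributions import Categorical, Independent

batch, chunk_size, action_dim, vocab_size = 4, 8, 7, 256


def token_dist(logits):
    # Categorical over the vocabulary; Independent folds (chunk_size, action_dim) into the event.
    return Independent(Categorical(logits=logits), reinterpreted_batch_ndims=2)


old_logits = torch.randn(batch, chunk_size, action_dim, vocab_size)
new_logits = old_logits + 0.01 * torch.randn_like(old_logits)  # slightly updated policy

old_dist = token_dist(old_logits)
tokens = old_dist.sample()  # [batch, chunk_size, action_dim]
old_logp = old_dist.log_prob(tokens)  # [batch]: one value per sample
new_logp = token_dist(new_logits).log_prob(tokens)

# The sequence-level log-prob equals the sum of the per-token log-probs.
per_token = Categorical(logits=new_logits).log_prob(tokens)  # [batch, chunk_size, action_dim]
assert torch.allclose(new_logp, per_token.sum(dim=(-2, -1)))

advantage = torch.randn(batch)
ratio = torch.exp(new_logp - old_logp)  # one importance ratio per sample
clipped = torch.clamp(ratio, 1 - 0.2, 1 + 0.2)
loss = -torch.min(ratio * advantage, clipped * advantage).mean()
print(old_logp.shape, ratio.shape)  # torch.Size([4]) torch.Size([4])
```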
- set_keys(**kwargs) → VLAWrapperBase
Set the tensordict key names used by the policy (see _AcceptedKeys).
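The return annotation suggests set_keys() hands back the policy, presumably so the call can be chained after construction. The accepted field names live in _AcceptedKeys, which this page does not list, so the pure-Python mock below only illustrates the pattern (a container of default names, keyword overrides, return the policy). AcceptedKeysMock and PolicyMock are hypothetical, their defaults are just the output keys named above, and the override targets are illustrative:

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class AcceptedKeysMock:
    """Hypothetical key container; the real _AcceptedKeys may hold more fields."""

    action_chunk: str = "action_chunk"
    action_tokens: str = "action_tokens"
    log_probs: str = "log_probs"


class PolicyMock:
    """Hypothetical stand-in for a policy that exposes set_keys()."""

    def __init__(self):
        self.keys = AcceptedKeysMock()

    def set_keys(self, **kwargs):
        # dataclasses.replace raises TypeError for unknown field names.
        self.keys = replace(self.keys, **kwargs)
        return self  # returning the policy allows chaining


policy = PolicyMock().set_keys(action_tokens="action", log_probs="logp")
print(policy.keys.action_tokens, policy.keys.log_probs)  # action logp
```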