VLAWrapperBase

class torchrl.modules.vla.VLAWrapperBase(*args, **kwargs)

Base class for Vision-Language-Action policies.

A VLA policy maps multimodal robot observations – one or more camera images, optional proprioceptive state, and a natural-language instruction – to a short action chunk. This base owns the TensorDict key contract and the forward() / get_dist() orchestration; concrete policies only implement the prediction hooks _predict_chunk() (continuous head) and _predict_logits() (discrete-token head).

Two action heads are supported via action_head:

  • "continuous": forward() writes a continuous action chunk of shape [*B, chunk_size, action_dim] under action_chunk;

  • "tokens": forward() writes discrete action tokens [*B, chunk_size, action_dim] under action_tokens and their per-sample (sequence-level, summed over the chunk) log-probabilities under log_probs; get_dist() returns the token distribution for log-prob/entropy-based RL fine-tuning.

Keys are configurable through set_keys(). The wrapper is a TensorDictModuleBase, so it composes with the standard collectors, losses and transforms.

Keyword Arguments:
  • action_dim (int) – the dimensionality of a single action.

  • chunk_size (int) – the action-chunk horizon H.

  • action_head (str) – "continuous" (default) or "tokens".

  • vocab_size (int, optional) – number of action-token bins per dimension (required for the "tokens" head).

  • use_state (bool) – whether to read the proprioceptive state. Defaults to True.

  • mode (str) – "greedy" (default, argmax) or "sample" token sampling for the "tokens" head (ignored by the continuous head).
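The two values of mode can be illustrated with plain PyTorch. This is a minimal sketch, not the wrapper's actual implementation: the tensor sizes and the random logits are hypothetical stand-ins for what a "tokens" head might produce.

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
batch, chunk_size, action_dim, vocab_size = 2, 4, 3, 8

# Stand-in for the logits of a "tokens" head: one categorical over
# vocab_size bins per (chunk step, action dimension).
logits = torch.randn(batch, chunk_size, action_dim, vocab_size)

# "greedy": take the most likely bin for every token (deterministic).
greedy_tokens = logits.argmax(dim=-1)

# "sample": draw every token from its categorical distribution (stochastic).
sampled_tokens = Categorical(logits=logits).sample()

# Both selections have shape [batch, chunk_size, action_dim].
print(greedy_tokens.shape, sampled_tokens.shape)
```

Greedy selection is reproducible and is the usual choice for evaluation, while sampling provides the exploration that RL fine-tuning typically relies on.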

Note

This base deliberately does not inherit from the text-generation LLMWrapperBase: a VLA policy emits robot actions, not text, so it carries only the small multimodal-to-action contract.

See also

TinyVLA (reference policy).

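forward(tensordict: TensorDictBase) → TensorDictBase

Run the policy on the input tensordict and write the action outputs: action_chunk for the "continuous" head, or action_tokens and log_probs for the "tokens" head.

Concrete policies implement the prediction hooks _predict_chunk() and _predict_logits(); the base class owns the forward() orchestration.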

Note

Although the forward-pass logic is defined in this method, call the Module instance rather than forward() directly: calling the instance runs the registered hooks, whereas calling forward() silently skips them.
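The difference between calling the instance and calling forward() directly can be checked on any torch.nn.Module. The sketch below uses a hypothetical toy module (Doubler), not a VLA policy, and a forward hook that records each time it runs.

```python
import torch
from torch import nn


class Doubler(nn.Module):
    """Hypothetical toy module used only to demonstrate hook behaviour."""

    def forward(self, x):
        return 2 * x


module = Doubler()
calls = []

# The hook returns None (list.append), so the module output is left unchanged.
handle = module.register_forward_hook(
    lambda mod, inputs, output: calls.append("hook")
)

x = torch.ones(3)

module(x)                      # goes through __call__, so the hook runs
n_after_call = len(calls)

module.forward(x)              # bypasses __call__, so the hook is skipped
n_after_forward = len(calls)

handle.remove()
print(n_after_call, n_after_forward)
```

The hook fires once for the instance call and not at all for the direct forward() call, so the counter stays the same after the second invocation.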

get_dist(tensordict: TensorDictBase) → Independent

Return the action-token distribution.

Only defined for the "tokens" action head: a Categorical over the vocabulary, wrapped in Independent over the (chunk_size, action_dim) token dims, so log_prob returns one sequence-level log-probability per sample. This is the contract PPO-style objectives expect: token RL fine-tuning works directly with ClipPPOLoss (pass critic_network=None, entropy_bonus=False and remap the keys via set_keys).
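The distribution structure described above can be reproduced with torch.distributions alone. This is a sketch under assumed sizes and random logits, not the wrapper's own code; it shows that wrapping a Categorical in Independent over the last two dims gives one log-probability per sample, equal to the sum of the per-token log-probabilities.

```python
import torch
from torch.distributions import Categorical, Independent

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
batch, chunk_size, action_dim, vocab_size = 2, 4, 3, 8
logits = torch.randn(batch, chunk_size, action_dim, vocab_size)

# Categorical over the vocabulary; batch_shape is [batch, chunk_size, action_dim].
per_token = Categorical(logits=logits)

# Reinterpret the last two batch dims (chunk_size, action_dim) as event dims.
dist = Independent(per_token, 2)

tokens = dist.sample()                # shape [batch, chunk_size, action_dim]
seq_log_prob = dist.log_prob(tokens)  # shape [batch]: one value per sample

# Equivalent manual computation: sum per-token log-probs over the chunk
# and action dimensions.
manual = per_token.log_prob(tokens).sum(dim=(-2, -1))
assert torch.allclose(seq_log_prob, manual)
```

Because the event shape covers the whole chunk, a probability ratio formed from these log-probabilities compares entire action sequences, which matches the sequence-level contract stated above.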

set_keys(**kwargs) → VLAWrapperBase

Set the tensordict key names used by the policy (see _AcceptedKeys).