ActionTokenizerTransform
- class torchrl.envs.transforms.ActionTokenizerTransform(tokenizer: ActionTokenizerBase, *, in_key: NestedKey = 'action', out_key: NestedKey = 'action_tokens')
Encode and decode actions with an ActionTokenizerBase.
A bidirectional action <-> token codec that wraps an action tokenizer. The bins live in the tokenizer, so no environment is needed to construct the transform. Like any TorchRL transform, it can be attached to a replay buffer or to an environment interchangeably:
- forward (encode): maps the continuous action (or action chunk) at in_key to discrete token ids at out_key, e.g. to build the token training target for an autoregressive (RT-2 / OpenVLA-style) token VLA on the replay-buffer sample path.
- inverse (decode): maps token ids at out_key back to a continuous action at in_key, e.g. to decode the tokens emitted by a token-head policy on the environment action-input path, before the base env consumes them. On a replay buffer, the inverse is a no-op when the token entry is absent, so extending with raw (untokenized) data is safe. Attached to an environment, missing tokens on the step path raise an error instead.
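The encode/decode pair can be illustrated without TorchRL. The following is a minimal pure-Python sketch of a uniform binning codec. The class name `UniformBins`, the floor-and-clamp encoding rule, and the bin-center decoding rule are all assumptions made for illustration. They were chosen to be consistent with the example output on this page (256 bins over [-1, 1] mapping -1.0, 0.0, 1.0 to tokens 0, 128, 255) and are not the implementation of any TorchRL tokenizer.

```python
import math


class UniformBins:
    """Hypothetical stand-in for a uniform action tokenizer (not a TorchRL class)."""

    def __init__(self, num_bins, low, high):
        self.num_bins = num_bins
        self.low = low
        self.high = high
        # Width of one bin on the continuous action axis.
        self.width = (high - low) / num_bins

    def encode(self, action):
        """Map each continuous value to a bin index in [0, num_bins - 1]."""
        tokens = []
        for x in action:
            idx = math.floor((x - self.low) / self.width)
            # Clamp so that x == high (and out-of-range values) stay in range.
            tokens.append(min(max(idx, 0), self.num_bins - 1))
        return tokens

    def decode(self, tokens):
        """Map each bin index back to the center of its bin (assumed rule)."""
        return [self.low + (k + 0.5) * self.width for k in tokens]


codec = UniformBins(256, low=-1.0, high=1.0)
tokens = codec.encode([-1.0, 0.0, 1.0])
decoded = codec.decode(tokens)
```

Because encoding and decoding read the same `low`, `high`, and `num_bins`, sharing one codec object between the data path and the execution path keeps the round-trip error within half a bin width for in-range actions.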
When attached to an environment, the policy-facing action spec is rewritten to a Categorical over the tokenizer’s vocabulary, so the env advertises the token interface the policy is expected to produce (the base env consumes the decoded continuous action internally). Using the same tokenizer instance on the replay buffer (encode) and on the env (decode) guarantees that training targets and execution share exactly the same binning.
- Parameters:
tokenizer (ActionTokenizerBase) – the tokenizer to apply.
- Keyword Arguments:
in_key (NestedKey) – the key of the continuous action. Defaults to "action".
out_key (NestedKey) – the key of the discrete token ids. Defaults to "action_tokens".
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data.vla import UniformActionTokenizer
>>> from torchrl.envs.transforms import ActionTokenizerTransform
>>> tok = UniformActionTokenizer(256, low=-1.0, high=1.0)
>>> t = ActionTokenizerTransform(tok)
>>> td = t(TensorDict({"action": torch.tensor([[-1.0, 0.0, 1.0]])}, batch_size=[1]))
>>> td["action_tokens"]
tensor([[  0, 128, 255]])
>>> # the inverse decodes tokens back to a continuous action
>>> back = t.inv(TensorDict({"action_tokens": td["action_tokens"]}, batch_size=[1]))
>>> back["action"].shape
torch.Size([1, 3])
>>> # on a replay buffer: raw actions written through extend are stored
>>> # as-is and tokenized on the sample path
>>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(8),
...     transform=ActionTokenizerTransform(tok),
...     batch_size=2,
... )
>>> indices = rb.extend(
...     TensorDict({"action": torch.rand(8, 3) * 2 - 1}, batch_size=[8])
... )
>>> rb.sample()["action_tokens"].shape
torch.Size([2, 3])
>>> # on an environment: the policy-facing action spec becomes the token
>>> # interface, and emitted tokens are decoded before the base env
>>> # consumes them
>>> from torchrl.envs import GymEnv, TransformedEnv
>>> tok_env = UniformActionTokenizer(256, low=-2.0, high=2.0)  # Pendulum bounds
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), ActionTokenizerTransform(tok_env))
>>> env.full_action_spec["action_tokens"].shape
torch.Size([1])
>>> env.rollout(2)["action_tokens"].dtype
torch.int64
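The different handling of a missing token entry on a replay buffer versus an environment, described above for the inverse, can be sketched with plain dictionaries. The function `inverse`, its `on_env` flag, and the choice of `KeyError` are hypothetical and used for illustration only; how the real transform distinguishes the two cases, and which exception it raises, is not shown here.

```python
def inverse(data, decode, *, on_env, in_key="action", out_key="action_tokens"):
    """Hypothetical sketch of the decode path (not the TorchRL implementation)."""
    if out_key not in data:
        if on_env:
            # Assumed exception type: the env step path needs tokens to decode.
            raise KeyError(f"{out_key!r} is required on the env step path")
        # Replay-buffer path: raw (untokenized) data passes through untouched.
        return data
    out = dict(data)
    out[in_key] = decode(out[out_key])
    return out


def decode(tokens):
    # Bin-center decoding for 256 uniform bins over [-1, 1] (assumed rule).
    return [-1.0 + (k + 0.5) * 2.0 / 256 for k in tokens]


raw = {"action": [0.25, -0.5]}
unchanged = inverse(raw, decode, on_env=False)  # no tokens: no-op
decoded = inverse({"action_tokens": [0, 255]}, decode, on_env=False)

try:
    inverse(raw, decode, on_env=True)  # no tokens on the env path: error
    raised = False
except KeyError:
    raised = True
```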
See also
ActionDiscretizer – the env-only discretizer that derives its bins from the environment’s bounded action_spec (with configurable in-bin sampling strategies) so that a discrete-action policy can act on a continuous env. Use ActionDiscretizer when the binning should follow the env spec; use ActionTokenizerTransform when the binning is owned by a tokenizer (dataset statistics, FAST/DCT-style codecs) that must be shared between offline encoding and online decoding.
- forward(tensordict: TensorDictBase) -> TensorDictBase
Reads the input tensordict, and for the selected keys, applies the transform.
By default, this method:
- calls _apply_transform() directly.
- does not call _step() or _call().
This method is not called within env.step at any point. However, it is called within sample().
Note
forward also works with regular keyword arguments, using dispatch to map the argument names to the keys.
Examples
>>> from tensordict import TensorDict, TensorDictBase
>>> from torchrl.envs.transforms import Transform
>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         # store the measured size, not the builtin `bytes` type
...         tensordict["bytes"] = bytes_in_td
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t)  # works within envs
>>> t(TensorDict(a=0))  # works offline too
- transform_input_spec(input_spec: Composite) -> Composite
Transforms the input spec so that the resulting spec matches the transform mapping.
- Parameters:
input_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
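As an illustration of the spec rewrite that happens when the transform is attached to an environment, the sketch below uses plain dataclasses as stand-ins for specs. `BoundedSpec`, `CategoricalSpec`, and `rewrite_action_spec` are hypothetical names, not TorchRL classes. Two behaviours here are assumptions: that the continuous entry is replaced in the policy-facing spec, and that the token spec has one token per action dimension. The second matches the Pendulum example on this page (shape [1]) but need not hold for every tokenizer.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BoundedSpec:
    """Stand-in for a bounded continuous action spec (hypothetical)."""
    low: float
    high: float
    shape: tuple


@dataclass(frozen=True)
class CategoricalSpec:
    """Stand-in for a categorical spec over n token ids (hypothetical)."""
    n: int
    shape: tuple


def rewrite_action_spec(input_spec, vocab_size, in_key="action", out_key="action_tokens"):
    """Replace the continuous entry with a categorical token entry."""
    new_spec = {k: v for k, v in input_spec.items() if k != in_key}
    # Assumption: one token per action dimension, so the shape carries over.
    new_spec[out_key] = CategoricalSpec(n=vocab_size, shape=input_spec[in_key].shape)
    return new_spec


base_spec = {"action": BoundedSpec(low=-2.0, high=2.0, shape=(1,))}
policy_spec = rewrite_action_spec(base_spec, vocab_size=256)
```

The base spec is left unmodified, which mirrors the idea that the base env still consumes a continuous action while the policy sees only the token interface.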