MacroPrimitiveTransform
- class torchrl.envs.transforms.MacroPrimitiveTransform(*args: Any, execute: bool = False, multi_action_dim: int = 1, stack_rewards: bool = True, stack_observations: bool = False, **kwargs: Any)
Expand a high-level macro action into a low-level action sequence.
The base transform is deliberately agnostic to robots, grippers and MuJoCo models. Its inverse-action path reads one macro action from action_key, resolves a (start, target) pair of low-level actions, linearly interpolates between them over macro_steps (plus settle_steps held repeats), and writes the resulting (..., T, action_dim) sequence back under action_key. When execute=True the constructor returns Compose(MultiAction(...), self) so the sequence is executed by the parent environment in a single high-level step.
The policy-facing action accepted under action_key may be:
- a MacroAction / TargetMacroAction (or a plain TensorDict with the same mode / target / steps / settle_steps schema); or
- a raw tensor, treated as a direct low-level action target (MOVE).
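The interpolation described above can be sketched in plain Python. This is an illustration under stated assumptions, not the library's implementation: the helper name expand_macro is hypothetical, plain lists stand in for tensors, and the convention that the last interpolated step lands exactly on the target is assumed rather than taken from the source.

```python
from typing import List


def expand_macro(
    start: List[float],
    target: List[float],
    macro_steps: int,
    settle_steps: int = 0,
) -> List[List[float]]:
    """Hypothetical helper: interpolate start -> target, then hold the target."""
    if macro_steps < 1:
        raise ValueError("macro_steps must be at least 1")
    if len(start) != len(target):
        raise ValueError("start and target must have the same action_dim")

    sequence = []
    for k in range(1, macro_steps + 1):
        # Assumption: step k covers the fraction k / macro_steps of the move,
        # so the final interpolated action equals the target.
        alpha = k / macro_steps
        sequence.append([s + alpha * (t - s) for s, t in zip(start, target)])

    # settle_steps repeats of the final action are appended after the move.
    sequence.extend(list(target) for _ in range(settle_steps))
    return sequence


# T = macro_steps + settle_steps = 3 rows, each of length action_dim = 3.
seq = expand_macro([0.0, 0.0, 0.0], [1.0, 1.0, 1.0], macro_steps=2, settle_steps=1)
```

Under these assumptions the resulting sequence has T = macro_steps + settle_steps entries, which is consistent with the (..., T, action_dim) shape described above.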
Domain specializations override three hooks rather than configuring adapter, solver and library objects:
- _resolve() – map a macro action to (start, target, steps, settle_steps) low-level tensors;
- current_action() – read the low-level action used as the interpolation start (defaults to zeros or a tensor already at action_key);
- transform_input_spec() – advertise the policy-facing action spec.
- Parameters:
action_key – low-level action key consumed by the inner environment and also the key carrying the macro action on the way in.
macro_steps – number of interpolated low-level actions per primitive.
settle_steps – number of repeated final actions appended after each primitive.
action_dim – low-level action dimension. Required when it cannot be inferred from specs or from the macro action target.
execute – if True, return Compose(MultiAction(...), transform) so emitted action sequences are executed by the parent environment.
multi_action_dim – stack dimension consumed by MultiAction when execute=True.
stack_rewards – whether MultiAction returns each low-level reward.
stack_observations – whether MultiAction returns each low-level observation.
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.transforms import MacroPrimitiveTransform
>>> td = TensorDict({"action": torch.ones(1, 3)}, batch_size=[1])
>>> transform = MacroPrimitiveTransform(macro_steps=2, action_dim=3)
>>> transform.inv(td)["action"].shape
torch.Size([1, 2, 3])
- action_sequence(tensordict: TensorDictBase, mode: int | IntEnum | None = None, *, target: Tensor | None = None, target_qpos: Tensor | None = None, steps: int | None = None, settle_steps: int | None = None) → Tensor
Expand a macro action into its low-level sequence without executing it.
When mode / target are given, a primitive is built first; otherwise tensordict is expected to already carry a macro action under action_key.
- current_action(tensordict: TensorDictBase, batch_shape: Size, device: device, dtype: dtype, action_dim: int) → Tensor
Return the low-level action used as the interpolation start.
The base implementation starts every macro from the zero action: in the inverse path action_key carries the incoming macro action (the target), so it must not be read back here as the start. Subclasses that can read the controlled state from observations (e.g. joint positions) override this hook.
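The difference between the base behaviour and a state-reading override can be sketched as follows. This is a stand-in, not the actual hook: plain dicts and lists replace TensorDict and tensors, the function names are hypothetical, and the "joint_pos" key is an assumed observation name chosen only for illustration.

```python
from typing import Dict, List


def zero_start(observation: Dict[str, List[float]], action_dim: int) -> List[float]:
    """Stand-in for the base behaviour: every macro starts from the zero action."""
    # The entry under "action" is the incoming target, so it is not read here.
    return [0.0] * action_dim


def observed_start(
    observation: Dict[str, List[float]],
    action_dim: int,
    state_key: str = "joint_pos",  # hypothetical observation key
) -> List[float]:
    """Stand-in for an override that starts from the observed controlled state."""
    state = observation[state_key]
    if len(state) != action_dim:
        raise ValueError("observed state does not match action_dim")
    return list(state)


# "action" holds the macro target; "joint_pos" holds the assumed current state.
obs = {"action": [1.0, 1.0], "joint_pos": [0.2, -0.4]}
base_start = zero_start(obs, action_dim=2)
custom_start = observed_start(obs, action_dim=2)
```

In this sketch neither function reads the value under "action", which mirrors the constraint that the incoming macro target must not be reused as the interpolation start.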
- make_primitive(tensordict: TensorDictBase, mode: int | IntEnum = MacroPrimitive.MOVE, *, target: Tensor | None = None, target_qpos: Tensor | None = None, steps: int | None = None, settle_steps: int | None = None) → TensorDictBase
Return a copy of tensordict carrying one macro action.
This is a small scripting helper: it builds a TargetMacroAction and stores it under action_key so the result can be passed to action_sequence() or executed.
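The copy-and-store behaviour described here can be illustrated with a small stand-in. Everything in it is an assumption made for illustration: nested dicts replace TensorDict and TargetMacroAction, SketchPrimitive is a hypothetical enum whose MOVE value of 0 is not taken from the library, and make_primitive_sketch is not a TorchRL function. Only the mode / target / steps / settle_steps field names come from the schema described earlier.

```python
import copy
from enum import IntEnum
from typing import Any, Dict, List, Optional


class SketchPrimitive(IntEnum):
    """Hypothetical stand-in for MacroPrimitive; the value 0 is assumed."""

    MOVE = 0


def make_primitive_sketch(
    td: Dict[str, Any],
    mode: SketchPrimitive = SketchPrimitive.MOVE,
    *,
    target: List[float],
    steps: Optional[int] = None,
    settle_steps: Optional[int] = None,
    action_key: str = "action",
) -> Dict[str, Any]:
    """Return a copy of td carrying one macro action under action_key."""
    out = copy.deepcopy(td)  # the input is left untouched
    out[action_key] = {
        "mode": mode,
        "target": list(target),
        "steps": steps,
        "settle_steps": settle_steps,
    }
    return out


td = {"observation": [0.0]}
with_macro = make_primitive_sketch(td, target=[1.0, 1.0, 1.0], steps=2)
```

The returned dict is what a later expansion step would consume, while the original td still has no entry under the action key.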
- primitive_enum
alias of MacroPrimitive
- transform_input_spec(input_spec: Composite) → Composite
Transform the input spec so that the resulting spec matches the transform mapping.
- Parameters:
input_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform