FlattenAction
- class torchrl.envs.transforms.FlattenAction(first_dim: int = -2, last_dim: int = -1, in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, allow_positive_dim: bool = False, *, action_shape: Sequence[int] | None = None)
Flatten adjacent dimensions of an action.
Mirrors FlattenObservation, but applies to actions: the policy sees a flattened action space, and the original multi-dimensional shape is restored in the inv direction before the action is passed to the base environment.
In the inv direction (policy -> env), the flattened action dimension is unflattened back to the original (dim_first, ..., dim_last) span of the env action. In the forward direction (env action -> flattened, used inside replay buffers and Module chains), the adjacent dims [first_dim, last_dim] are flattened.
- Parameters:
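Both directions reduce to a flatten/unflatten pair over the span [first_dim, last_dim]. The following is a minimal sketch with plain torch tensors, assuming the (3, 5) action shape used in the Examples section and the default first_dim=-2, last_dim=-1. It calls Tensor.flatten and Tensor.unflatten directly and is only an illustration of the shape round trip, not the transform's actual implementation.

```python
import torch

# Env-side action with the (3, 5) shape from the Examples section.
env_action = torch.arange(15.0).reshape(3, 5)

# Forward direction (env action -> flattened): merge dims -2..-1 into one.
flat = env_action.flatten(-2, -1)
print(flat.shape)  # torch.Size([15])

# Inverse direction (policy -> env): split the trailing dim back into (3, 5).
restored = flat.unflatten(-1, (3, 5))
print(torch.equal(restored, env_action))  # True

# Leading batch dims are left alone because the span is addressed from the end.
batched_policy_output = torch.zeros(4, 15)
batched_restored = batched_policy_output.unflatten(-1, (3, 5))
print(batched_restored.shape)  # torch.Size([4, 3, 5])
```

Note that unflattening needs the pre-flatten sizes (3, 5); the flat size 15 alone does not determine them.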
first_dim (int) – first dimension to flatten. Must be negative unless allow_positive_dim is True.
last_dim (int) – last dimension to flatten (inclusive). Must be negative unless allow_positive_dim is True.
in_keys_inv (sequence of NestedKey, optional) – keys read during the inv direction (policy -> env). Defaults to ["action"]. Multiple keys are supported: the same flatten span is applied to each one, which is useful for dict-structured action spaces.
out_keys_inv (sequence of NestedKey, optional) – keys written during the inv direction. Defaults to in_keys_inv.
in_keys (sequence of NestedKey, optional) – keys read during the forward direction (env action -> flattened). Defaults to in_keys_inv.
out_keys (sequence of NestedKey, optional) – keys written during the forward direction. Defaults to in_keys.
allow_positive_dim (bool, optional) – if True, positive dimensions are accepted. Defaults to False so that the same transform works regardless of the parent environment's batch size.
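The reason negative dims are the default can be seen with plain shape arithmetic. In this sketch, span_sizes is a hypothetical helper (not a torchrl function) that returns the sizes covered by [first_dim, last_dim] for a given full shape; it assumes the env's batch dims precede the (3, 5) action dims.

```python
# Hypothetical helper (not part of torchrl): sizes covered by [first_dim, last_dim].
def span_sizes(full_shape, first_dim, last_dim):
    ndim = len(full_shape)
    # Negative dims are resolved by counting from the end of the shape.
    first = first_dim + ndim if first_dim < 0 else first_dim
    last = last_dim + ndim if last_dim < 0 else last_dim
    return tuple(full_shape[first : last + 1])

action_shape = (3, 5)
for batch_shape in [(), (4,), (2, 4)]:
    full_shape = batch_shape + action_shape
    # Column 2: negative span (-2, -1). Column 3: positive span (0, 1).
    print(full_shape, span_sizes(full_shape, -2, -1), span_sizes(full_shape, 0, 1))
```

The negative span selects (3, 5) for every batch shape, while the positive span (0, 1) selects (4, 3) or (2, 4) once batch dims are added, which is why positive dims are only allowed on request.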
- Keyword Arguments:
action_shape (sequence of int, optional) – explicit pre-flatten shape of the dimensions [first_dim, last_dim]. Useful when the transform is used outside a TransformedEnv (e.g. inside a replay buffer) and the original action shape cannot be derived from a parent env. The same span is applied to every entry of in_keys_inv. Defaults to None, in which case the shape is derived lazily from the parent env's action spec.
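Without a parent env, the inv direction has nothing to tell it how to split the flat dimension, since several pre-flatten shapes share the same number of elements. This sketch uses a hypothetical helper, unflattened_shape (not a torchrl function), and assumes last_dim == -1 so that the flattened span is the trailing dim of the policy's action.

```python
from math import prod

# Hypothetical helper (not part of torchrl). Assumes last_dim == -1, so the
# flattened span is the trailing dim of the policy's action.
def unflattened_shape(flat_shape, action_shape):
    if flat_shape[-1] != prod(action_shape):
        raise ValueError(
            f"trailing dim {flat_shape[-1]} does not match prod{tuple(action_shape)}"
        )
    # Keep leading (batch) dims, expand the trailing dim into action_shape.
    return tuple(flat_shape[:-1]) + tuple(action_shape)

# The same flat size fits several pre-flatten shapes, so it must be supplied.
print(unflattened_shape((15,), (3, 5)))    # (3, 5)
print(unflattened_shape((15,), (5, 3)))    # (5, 3)
print(unflattened_shape((4, 15), (3, 5)))  # (4, 3, 5)
```

Passing an action_shape whose product does not match the flat size, e.g. (4, 4) for 15, raises a ValueError in this sketch.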
Examples
>>> import torch
>>> from torchrl.data.tensor_specs import Bounded
>>> from torchrl.envs.transforms import FlattenAction, TransformedEnv
>>> from torchrl.testing.mocking_classes import ContinuousActionVecMockEnv
>>> base_env = ContinuousActionVecMockEnv(
...     action_spec=Bounded(low=-1.0, high=1.0, shape=(3, 5))
... )
>>> env = TransformedEnv(base_env, FlattenAction(first_dim=-2, last_dim=-1))
>>> env.action_spec.shape
torch.Size([15])
- transform_action_spec(action_spec: TensorSpec) → TensorSpec
Transforms the action spec such that the resulting spec matches the transform mapping.
- Parameters:
action_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
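The shape side of that mapping can be sketched in plain Python. Here flattened_shape is a hypothetical helper (not a torchrl function) that only does shape bookkeeping; a real TensorSpec also carries dtype, device and, for bounded specs, low/high tensors, which this sketch ignores.

```python
from math import prod

# Hypothetical helper (not part of torchrl): shape bookkeeping only.
def flattened_shape(shape, first_dim=-2, last_dim=-1):
    ndim = len(shape)
    first = first_dim + ndim if first_dim < 0 else first_dim
    last = last_dim + ndim if last_dim < 0 else last_dim
    # The span must lie inside the shape and be ordered.
    if not 0 <= first <= last < ndim:
        raise ValueError(f"invalid span [{first_dim}, {last_dim}] for shape {tuple(shape)}")
    # Replace the dims in [first, last] by their product.
    merged = prod(shape[first : last + 1])
    return tuple(shape[:first]) + (merged,) + tuple(shape[last + 1 :])

print(flattened_shape((3, 5)))             # (15,)
print(flattened_shape((4, 3, 5)))          # (4, 15)
print(flattened_shape((2, 3, 5), -3, -2))  # (6, 5)
```

The first line matches the torch.Size([15]) reported by env.action_spec.shape in the Examples section.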