FlattenObservation
- class torchrl.envs.transforms.FlattenObservation(first_dim: int, last_dim: int, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, allow_positive_dim: bool = False)
Flatten adjacent dimensions of a tensor.
- Parameters:
first_dim (int) – first dimension of the dimensions to flatten.
last_dim (int) – last dimension of the dimensions to flatten.
in_keys (sequence of NestedKey, optional) – the entries to flatten. If none is provided, ["pixels"] is assumed.
out_keys (sequence of NestedKey, optional) – the flattened observation keys. If none is provided, in_keys is assumed.
allow_positive_dim (bool, optional) – if True, positive dimensions are accepted. FlattenObservation will map these to the n-th feature dimension (i.e. the n-th dimension after the batch size of the parent env) of the input tensor. Defaults to False, i.e. non-negative dimensions are not permitted.
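The dimension rules above reduce to shape arithmetic. The following is a minimal sketch under stated assumptions: `flattened_shape` is a hypothetical pure-Python helper, not part of TorchRL. It takes the full tensor shape and the number of batch dimensions of the parent env, counts negative dims from the end of the tensor, and offsets non-negative dims past the batch dimensions as described for `allow_positive_dim`. It only computes the resulting shape; the actual transform operates on tensors.

```python
from math import prod
from typing import Tuple


def flattened_shape(
    shape: Tuple[int, ...],
    first_dim: int,
    last_dim: int,
    batch_ndim: int,
    allow_positive_dim: bool = False,
) -> Tuple[int, ...]:
    """Hypothetical helper: output shape after flattening first_dim..last_dim.

    Negative dims count from the end of the tensor. Non-negative dims are
    only accepted with allow_positive_dim=True and are offset by the number
    of batch dimensions, so 0 means "first feature dimension".
    """
    ndim = len(shape)
    resolved = []
    for dim in (first_dim, last_dim):
        if dim >= 0:
            if not allow_positive_dim:
                raise ValueError("non-negative dims require allow_positive_dim=True")
            resolved.append(batch_ndim + dim)  # n-th dim after the batch dims
        else:
            resolved.append(ndim + dim)  # counted from the last dimension
    start, end = resolved
    if not 0 <= start <= end < ndim:
        raise ValueError(f"cannot flatten dims {first_dim}..{last_dim} of {shape}")
    merged = prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


# Negative dims give the same result whatever the batch size:
print(flattened_shape((4, 3, 84, 84), -3, -1, batch_ndim=1))     # (4, 21168)
print(flattened_shape((2, 5, 3, 84, 84), -3, -1, batch_ndim=2))  # (2, 5, 21168)
# Positive dims are shifted past the batch dims:
print(flattened_shape((2, 5, 3, 84, 84), 0, 2, batch_ndim=2, allow_positive_dim=True))  # (2, 5, 21168)
```

The first two calls show that the same `(-3, -1)` pair works regardless of how many batch dimensions the env has, whereas positive dims only make sense relative to the batch size; this is plausibly why non-negative dims are rejected by default.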
- forward(tensordict: TensorDictBase) → TensorDictBase
Reads the input tensordict and applies the transform to the selected keys.
For any operation that relates exclusively to the parent env (e.g. FrameSkip), modify the _step method instead.
_call() should only be overwritten if a modification of the input tensordict is needed. _call() will be called by TransformedEnv.step() and TransformedEnv.reset().
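To illustrate the "read the selected keys, write the transformed result" flow described above, here is a toy model. Everything in it is hypothetical: a plain dict stands in for the tensordict, each entry is a `(shape, flat_data)` pair instead of a tensor, `FlattenSketch` is not a TorchRL class, and only negative `first_dim`/`last_dim` are handled. It mirrors the documented defaults for `in_keys` and `out_keys`.

```python
from math import prod
from typing import Dict, List, Optional, Sequence, Tuple

# Hypothetical stand-in for a tensordict entry: (shape, flat row-major data).
Entry = Tuple[Tuple[int, ...], List[float]]


class FlattenSketch:
    """Toy model of forward(): read each in_key, write the result to its out_key.

    Not TorchRL code. For brevity, first_dim and last_dim must be negative here.
    """

    def __init__(
        self,
        first_dim: int,
        last_dim: int,
        in_keys: Optional[Sequence[str]] = None,
        out_keys: Optional[Sequence[str]] = None,
    ) -> None:
        self.first_dim = first_dim
        self.last_dim = last_dim
        # Same defaults as documented: ["pixels"] for in_keys, in_keys for out_keys.
        self.in_keys = list(in_keys) if in_keys is not None else ["pixels"]
        self.out_keys = list(out_keys) if out_keys is not None else list(self.in_keys)

    def _apply_transform(self, entry: Entry) -> Entry:
        shape, data = entry
        start, end = len(shape) + self.first_dim, len(shape) + self.last_dim
        new_shape = shape[:start] + (prod(shape[start:end + 1]),) + shape[end + 1:]
        # Flattening adjacent dims only changes the shape; the row-major data is reused as is.
        return new_shape, data

    def forward(self, td: Dict[str, Entry]) -> Dict[str, Entry]:
        for in_key, out_key in zip(self.in_keys, self.out_keys):
            if in_key in td:  # this sketch silently skips missing keys
                td[out_key] = self._apply_transform(td[in_key])
        return td


td = {"pixels": ((2, 3, 4), [float(i) for i in range(24)]), "reward": ((2, 1), [0.0, 1.0])}
out = FlattenSketch(-2, -1, out_keys=["pixels_flat"]).forward(td)
print(out["pixels"][0], out["pixels_flat"][0])  # (2, 3, 4) (2, 12)
```

Writing to a distinct out key leaves the original entry in place; with the default `out_keys`, the input entry is overwritten.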
- transform_observation_spec(observation_spec: TensorSpec) → TensorSpec
Transforms the observation spec so that the resulting spec matches the transform's mapping.
- Parameters:
observation_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
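The spec transform has to predict what the data transform will produce. A minimal sketch, assuming a hypothetical `BoxSpec` with flat, row-major `low`/`high` bounds (real TorchRL specs are different classes) and negative dims only: the expected spec is the input spec with dims `first_dim` through `last_dim` merged into one. `transform_spec_sketch` is likewise an illustrative name, not the library method.

```python
from dataclasses import dataclass, replace
from math import prod
from typing import List, Tuple


@dataclass(frozen=True)
class BoxSpec:
    """Hypothetical stand-in for a bounded observation spec (not a TorchRL class)."""

    shape: Tuple[int, ...]
    low: List[float]   # per-element lower bounds, row-major order
    high: List[float]  # per-element upper bounds, row-major order


def transform_spec_sketch(spec: BoxSpec, first_dim: int, last_dim: int) -> BoxSpec:
    """Return the spec expected after flattening first_dim..last_dim (negative dims only)."""
    ndim = len(spec.shape)
    start, end = ndim + first_dim, ndim + last_dim
    new_shape = spec.shape[:start] + (prod(spec.shape[start:end + 1]),) + spec.shape[end + 1:]
    # Flattening keeps row-major element order, so the flat bounds stay aligned
    # with the data and only the shape has to change.
    return replace(spec, shape=new_shape)


spec = BoxSpec(shape=(3, 2, 2), low=[0.0] * 12, high=[255.0] * 12)
print(transform_spec_sketch(spec, -3, -1).shape)  # (12,)
```

Because the bounds here are stored flat, only the shape changes; a spec whose bounds are tensors of the spec's shape would need those bounds flattened in the same way.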