ExcludeTransform
- class torchrl.envs.transforms.ExcludeTransform(*excluded_keys, inverse: bool = False)
Excludes keys from the data.
- Parameters:
*excluded_keys (iterable of NestedKey) – The names of the keys to exclude. If a key is not present, it is simply ignored.
inverse (bool, optional) – if True, the exclusion occurs during the inv call. Defaults to False.
Examples
>>> import gymnasium
>>> from torchrl.envs import GymWrapper, TransformedEnv
>>> from torchrl.envs.transforms import ExcludeTransform
>>> env = TransformedEnv(
...     GymWrapper(gymnasium.make("Pendulum-v1")),
...     ExcludeTransform("truncated")
... )
>>> env.rollout(3)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=cpu,
    is_shared=False)
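The exclusion semantics described in the parameter list (nested keys given as tuples, absent keys silently ignored, and `inverse` moving the exclusion to the inverse pass) can be illustrated with a minimal sketch on plain nested dicts. `exclude_keys` and `ExcludeSketch` are hypothetical names used only for illustration; this is not TorchRL's implementation, which operates on TensorDicts.

```python
from typing import Any, Dict, Tuple, Union

NestedKey = Union[str, Tuple[str, ...]]


def _copy_tree(data: Dict[str, Any]) -> Dict[str, Any]:
    # Copy the dict structure (not the leaf values) so the input is left untouched.
    return {k: _copy_tree(v) if isinstance(v, dict) else v for k, v in data.items()}


def exclude_keys(data: Dict[str, Any], *excluded_keys: NestedKey) -> Dict[str, Any]:
    """Hypothetical helper: return a copy of `data` without `excluded_keys`."""
    result = _copy_tree(data)
    for key in excluded_keys:
        path = (key,) if isinstance(key, str) else tuple(key)
        node = result
        for part in path[:-1]:
            node = node.get(part)
            if not isinstance(node, dict):
                break  # intermediate key missing: nothing to exclude
        else:
            node.pop(path[-1], None)  # absent leaf keys are simply ignored
    return result


class ExcludeSketch:
    """Hypothetical stand-in for an exclude transform with an `inverse` flag."""

    def __init__(self, *excluded_keys: NestedKey, inverse: bool = False) -> None:
        self.excluded_keys = excluded_keys
        self.inverse = inverse

    def forward(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Forward pass excludes keys unless inverse=True.
        return data if self.inverse else exclude_keys(data, *self.excluded_keys)

    def inv(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Inverse pass excludes keys only when inverse=True.
        return exclude_keys(data, *self.excluded_keys) if self.inverse else data


step = {
    "observation": [0.1, 0.2, 0.3],
    "truncated": False,
    "next": {"observation": [0.2, 0.3, 0.4], "reward": -1.0, "truncated": False},
}
t = ExcludeSketch("truncated", ("next", "truncated"), "not_a_key")
out = t.forward(step)
print(out)
# {'observation': [0.1, 0.2, 0.3], 'next': {'observation': [0.2, 0.3, 0.4], 'reward': -1.0}}
```

Returning a copy keeps the sketch side-effect free; whether the real transform copies or mutates in place is not stated here.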
- forward(next_tensordict: TensorDictBase) → TensorDictBase
Reads the input tensordict, and for the selected keys, applies the transform.
By default, this method:
- calls _apply_transform() directly;
- does not call _step() or _call().
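The default behavior listed above (apply `_apply_transform()` to each input key and write the result under the matching output key, without going through `_call()` or `_step()`) can be sketched with plain dicts. `MiniTransform` and `Double` are hypothetical classes for illustration only; unlike this sketch, the real base class may raise on missing keys rather than skip them.

```python
from typing import Any, Dict, List


class MiniTransform:
    """Hypothetical, simplified stand-in for a transform base class."""

    def __init__(self, in_keys: List[str], out_keys: List[str]) -> None:
        self.in_keys = in_keys
        self.out_keys = out_keys

    def _apply_transform(self, value: Any) -> Any:
        raise NotImplementedError

    def forward(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Calls _apply_transform directly for each (in_key, out_key) pair;
        # no _call() or _step() is involved.
        for in_key, out_key in zip(self.in_keys, self.out_keys):
            if in_key in data:  # this sketch skips missing keys
                data[out_key] = self._apply_transform(data[in_key])
        return data


class Double(MiniTransform):
    def _apply_transform(self, value: Any) -> Any:
        return value * 2


t = Double(in_keys=["x"], out_keys=["x_doubled"])
out = t.forward({"x": 3, "y": 1})
print(out)  # {'x': 3, 'y': 1, 'x_doubled': 6}
```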
forward() is not called within env.step at any point. However, it is called within the replay buffer's sample() method.
Note
forward also works with regular keyword arguments, using dispatch to cast the argument names to the keys.
Examples
>>> from tensordict import TensorDict, TensorDictBase
>>> from torchrl.envs.transforms import Transform
>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes_in_td  # store the measured size, not the builtin `bytes`
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t)  # works within envs
>>> t(TensorDict(a=0))  # Works offline too.
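The idea in the note, letting a dict-based forward be called with keyword arguments whose names are cast to keys, can be sketched in plain Python. `dispatch_kwargs` and `AddOne` are hypothetical; this is not the `dispatch` decorator from tensordict, whose exact return conventions are not described here.

```python
import functools
from typing import Any, Callable, Dict, Optional, Sequence


def dispatch_kwargs(in_keys: Sequence[str], out_keys: Sequence[str]):
    """Hypothetical decorator: map keyword-argument names onto dict keys."""

    def decorator(forward: Callable[..., Dict[str, Any]]):
        @functools.wraps(forward)
        def wrapper(self, data: Optional[Dict[str, Any]] = None, **kwargs: Any):
            if data is not None:
                return forward(self, data)  # regular dict-in, dict-out call
            # Cast keyword-argument names to keys, run forward, unpack outputs.
            data = {key: kwargs[key] for key in in_keys}
            result = forward(self, data)
            values = tuple(result[key] for key in out_keys)
            return values[0] if len(values) == 1 else values

        return wrapper

    return decorator


class AddOne:
    @dispatch_kwargs(in_keys=["x"], out_keys=["y"])
    def forward(self, data: Dict[str, Any]) -> Dict[str, Any]:
        data["y"] = data["x"] + 1
        return data


t = AddOne()
print(t.forward({"x": 1}))  # {'x': 1, 'y': 2}
print(t.forward(x=1))  # 2
```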
- transform_output_spec(output_spec: Composite) → Composite
Transforms the output spec such that the resulting spec matches the transform mapping.
This method should generally be left untouched. Changes should be implemented using transform_observation_spec(), transform_reward_spec() and transform_full_done_spec().
- Parameters:
output_spec (TensorSpec) – spec before the transform.
- Returns:
expected spec after the transform.
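The division of labor described above (a generic output-spec method that routes each sub-spec to a dedicated hook, with customization living in the hooks) can be sketched with plain dicts. `ExcludeSpecSketch` is hypothetical, the string leaf values stand in for real spec objects, and the sub-spec names `full_observation_spec`, `full_reward_spec` and `full_done_spec` are assumed here to mirror TorchRL's naming.

```python
from typing import Any, Dict

Spec = Dict[str, Any]  # leaf values stand in for real spec objects


class ExcludeSpecSketch:
    """Hypothetical transform whose spec hooks drop excluded keys."""

    def __init__(self, *excluded_keys: str) -> None:
        self.excluded_keys = set(excluded_keys)

    def _drop(self, spec: Spec) -> Spec:
        return {k: v for k, v in spec.items() if k not in self.excluded_keys}

    # Customization belongs in these per-category hooks.
    def transform_observation_spec(self, spec: Spec) -> Spec:
        return self._drop(spec)

    def transform_reward_spec(self, spec: Spec) -> Spec:
        return spec  # rewards are left untouched in this sketch

    def transform_full_done_spec(self, spec: Spec) -> Spec:
        return self._drop(spec)

    def transform_output_spec(self, output_spec: Dict[str, Spec]) -> Dict[str, Spec]:
        # Generic router: send each sub-spec to its hook; no custom logic here.
        return {
            "full_observation_spec": self.transform_observation_spec(
                output_spec["full_observation_spec"]
            ),
            "full_reward_spec": self.transform_reward_spec(output_spec["full_reward_spec"]),
            "full_done_spec": self.transform_full_done_spec(output_spec["full_done_spec"]),
        }


output_spec = {
    "full_observation_spec": {"observation": "float32[3]"},
    "full_reward_spec": {"reward": "float32[1]"},
    "full_done_spec": {"done": "bool[1]", "terminated": "bool[1]", "truncated": "bool[1]"},
}
new_spec = ExcludeSpecSketch("truncated").transform_output_spec(output_spec)
print(new_spec["full_done_spec"])  # {'done': 'bool[1]', 'terminated': 'bool[1]'}
```

Keeping the router generic means a subclass only overrides the hook for the category it changes.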