DeviceCastTransform
- class torchrl.envs.transforms.DeviceCastTransform(device, orig_device=None, *, in_keys=None, out_keys=None, in_keys_inv=None, out_keys_inv=None)
Moves data from one device to another.
- Parameters:
device (torch.device or equivalent) – the destination device (outside the environment or buffer).
orig_device (torch.device or equivalent) – the origin device (inside the environment or buffer). If not specified and a parent environment exists, it is retrieved from it. In all other cases, it remains unspecified.
- Keyword Arguments:
in_keys (list of NestedKey) – the list of entries to map to a different device. Defaults to None.
out_keys (list of NestedKey) – the output names of the entries mapped onto a device. Defaults to the values of in_keys.
in_keys_inv (list of NestedKey) – the list of entries to map to a different device. in_keys_inv are the names expected by the base environment. Defaults to None.
out_keys_inv (list of NestedKey) – the output names of the entries mapped onto a device. out_keys_inv are the names of the keys as seen from outside the transformed env. Defaults to the values of in_keys_inv.
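How the keyword defaults pair up, and which way each pair moves data, can be sketched with plain dictionaries. Everything here (resolve_keys, cast_forward, cast_inverse, and the (value, device) tuples standing in for tensors) is hypothetical and only models the documented behaviour: forward pairs read base-env names and write outside names on device, while inverse pairs read outside names and write base-env names on orig_device.

```python
from typing import Dict, List, Optional

# Toy model: keys are plain strings and each entry is a (value, device) tuple.
Entry = tuple
Data = Dict[str, Entry]


def resolve_keys(
    in_keys: Optional[List[str]] = None,
    out_keys: Optional[List[str]] = None,
    in_keys_inv: Optional[List[str]] = None,
    out_keys_inv: Optional[List[str]] = None,
) -> Dict[str, Optional[List[str]]]:
    """Apply the documented defaults (hypothetical helper)."""
    if out_keys is None:
        out_keys = in_keys  # out_keys defaults to the values of in_keys
    if out_keys_inv is None:
        out_keys_inv = in_keys_inv  # out_keys_inv defaults to in_keys_inv
    return {
        "in_keys": in_keys,
        "out_keys": out_keys,
        "in_keys_inv": in_keys_inv,
        "out_keys_inv": out_keys_inv,
    }


def cast_forward(data: Data, keys: Dict, device: str) -> Data:
    """Data leaving the base env: read in_keys, write out_keys on `device`."""
    out = dict(data)
    # In this toy, None selects nothing; the real transform may treat None differently.
    for in_key, out_key in zip(keys["in_keys"] or (), keys["out_keys"] or ()):
        value, _ = data[in_key]
        out[out_key] = (value, device)
    return out


def cast_inverse(data: Data, keys: Dict, orig_device: str) -> Data:
    """Data entering the base env: read out_keys_inv, write in_keys_inv on `orig_device`."""
    out = dict(data)
    for in_key, out_key in zip(keys["in_keys_inv"] or (), keys["out_keys_inv"] or ()):
        value, _ = data[out_key]
        out[in_key] = (value, orig_device)
    return out


keys = resolve_keys(in_keys=["obs"], in_keys_inv=["action"])
print(cast_forward({"obs": (1.0, "cuda:0")}, keys, device="cpu"))
# {'obs': (1.0, 'cpu')}
print(cast_inverse({"action": (0.5, "cpu")}, keys, orig_device="cuda:0"))
# {'action': (0.5, 'cuda:0')}
```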
Examples
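>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.transforms import DeviceCastTransform
>>> td = TensorDict(
...     {'obs': torch.ones(1, dtype=torch.double)},
...     [],
...     device="cpu:0",
... )
>>> transform = DeviceCastTransform(device=torch.device("cpu:2"))
>>> td = transform(td)
>>> print(td.device)
cpu:2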
- forward(tensordict: TensorDictBase = None) → TensorDictBase
Reads the input tensordict, and for the selected keys, applies the transform.
By default, this method calls _apply_transform() directly and does not call _step() or _call().
This method is not called within env.step at any point. However, it is called within ReplayBuffer.sample().
Note: forward also works with regular keyword arguments, using dispatch (tensordict.nn.dispatch) to cast the argument names to the keys.
Examples
>>> from tensordict import TensorDict, TensorDictBase
>>> from torchrl.envs.transforms import Transform
>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         # Store the measured byte count (not the builtin `bytes`).
...         tensordict["bytes"] = bytes_in_td
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t)  # works within envs (assumes an existing `env`)
>>> t(TensorDict(a=0))  # Works offline too.
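The split between forward and the env.step hooks can be modelled with a toy, dict-based class. ToyTransform and its calls bookkeeping list are hypothetical; the loop over paired keys is a simplified assumption about what a per-entry _apply_transform path looks like, not torchrl's source.

```python
class ToyTransform:
    """Hypothetical, dict-based mimic of the forward/_call split (not torchrl code)."""

    def __init__(self, in_keys, out_keys):
        self.in_keys = list(in_keys)
        self.out_keys = list(out_keys)
        self.calls = []  # records which entry point ran, for illustration only

    def _apply_transform(self, value):
        # Per-entry work; this toy simply doubles the value.
        return value * 2

    def _transform_entries(self, data):
        out = dict(data)
        for in_key, out_key in zip(self.in_keys, self.out_keys):
            if in_key in data:  # entries that are absent are skipped
                out[out_key] = self._apply_transform(data[in_key])
        return out

    def _call(self, data):
        # In this toy model, the hook an environment step would go through.
        self.calls.append("_call")
        return self._transform_entries(data)

    def forward(self, data):
        # Offline path: goes straight to _apply_transform, bypassing _call/_step.
        self.calls.append("forward")
        return self._transform_entries(data)


t = ToyTransform(in_keys=["x"], out_keys=["y"])
print(t.forward({"x": 3}))  # {'x': 3, 'y': 6}
print(t.calls)              # ['forward']
```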
- transform_action_spec(full_action_spec: Composite) → Composite
Transforms the action spec such that the resulting spec matches the transform mapping.
- Parameters:
full_action_spec (Composite) – spec before the transform
- Returns:
expected spec after the transform
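One way to picture what the transformed action spec describes is a toy model in which each leaf spec only records a shape and a device. ToySpec and toy_transform_action_spec are hypothetical; the sketch assumes that, because the spec describes data as seen from outside the env, each entry selected by in_keys_inv/out_keys_inv is reported on the destination device.

```python
from dataclasses import dataclass, replace
from typing import Dict, List


@dataclass(frozen=True)
class ToySpec:
    """Hypothetical leaf spec holding only the fields needed here."""

    shape: tuple
    device: str


def toy_transform_action_spec(
    full_action_spec: Dict[str, ToySpec],
    in_keys_inv: List[str],
    out_keys_inv: List[str],
    device: str,
) -> Dict[str, ToySpec]:
    """Describe the selected action entries as seen from outside the env."""
    spec = dict(full_action_spec)  # shallow copy: the input mapping is not mutated
    for in_key, out_key in zip(in_keys_inv, out_keys_inv):
        if in_key in spec:
            # Outside the env, this entry is expected under `out_key` on `device`.
            spec[out_key] = replace(spec[in_key], device=device)
    return spec


before = {"action": ToySpec(shape=(1,), device="cuda:0")}
after = toy_transform_action_spec(before, ["action"], ["action"], device="cpu")
print(after["action"])   # ToySpec(shape=(1,), device='cpu')
print(before["action"])  # ToySpec(shape=(1,), device='cuda:0')
```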
- transform_done_spec(full_done_spec: Composite) → Composite
Transforms the done spec such that the resulting spec matches the transform mapping.
- Parameters:
full_done_spec (Composite) – spec before the transform
- Returns:
expected spec after the transform
- transform_input_spec(input_spec: Composite) → Composite
Transforms the input spec such that the resulting spec matches the transform mapping.
- Parameters:
input_spec (Composite) – spec before the transform
- Returns:
expected spec after the transform
- transform_observation_spec(observation_spec: Composite) → Composite
Transforms the observation spec such that the resulting spec matches the transform mapping.
- Parameters:
observation_spec (Composite) – spec before the transform
- Returns:
expected spec after the transform
- transform_output_spec(output_spec: Composite) → Composite
Transforms the output spec such that the resulting spec matches the transform mapping.
This method should generally be left untouched. Changes should be implemented using transform_observation_spec(), transform_reward_spec() and transform_done_spec().
- Parameters:
output_spec (Composite) – spec before the transform
- Returns:
expected spec after the transform
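The advice to leave transform_output_spec alone follows from a delegation pattern, sketched here with plain dictionaries. ToySpecTransform and MarkObservations are hypothetical, and the sub-spec names full_observation_spec, full_reward_spec and full_done_spec are assumed to be the keys of the output spec.

```python
from typing import Dict

Spec = Dict[str, object]  # toy stand-in for a Composite spec


class ToySpecTransform:
    """Hypothetical sketch: the output spec is rebuilt from its sub-specs."""

    def transform_observation_spec(self, spec: Spec) -> Spec:
        return spec

    def transform_reward_spec(self, spec: Spec) -> Spec:
        return spec

    def transform_done_spec(self, spec: Spec) -> Spec:
        return spec

    def transform_output_spec(self, output_spec: Dict[str, Spec]) -> Dict[str, Spec]:
        output_spec = dict(output_spec)  # shallow copy of the top-level mapping
        handlers = {
            "full_observation_spec": self.transform_observation_spec,
            "full_reward_spec": self.transform_reward_spec,
            "full_done_spec": self.transform_done_spec,
        }
        # Delegate each sub-spec that is present to its dedicated method.
        for key, handler in handlers.items():
            if key in output_spec:
                output_spec[key] = handler(output_spec[key])
        return output_spec


class MarkObservations(ToySpecTransform):
    """Customises one piece only; transform_output_spec is inherited unchanged."""

    def transform_observation_spec(self, spec: Spec) -> Spec:
        return {**spec, "marked": True}


out = MarkObservations().transform_output_spec(
    {"full_observation_spec": {"obs": 1}, "full_reward_spec": {"reward": 2}}
)
print(out)  # {'full_observation_spec': {'obs': 1, 'marked': True}, 'full_reward_spec': {'reward': 2}}
```

Overriding only the sub-spec method keeps the composition logic in one place, which is the design choice the note above recommends.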
- transform_reward_spec(full_reward_spec: Composite) → Composite
Transforms the reward spec such that the resulting spec matches the transform mapping.
- Parameters:
full_reward_spec (Composite) – spec before the transform
- Returns:
expected spec after the transform
- transform_state_spec(full_state_spec: Composite) → Composite
Transforms the state spec such that the resulting spec matches the transform mapping.
- Parameters:
full_state_spec (Composite) – spec before the transform
- Returns:
expected spec after the transform