DoubleToFloat
- class torchrl.envs.transforms.DoubleToFloat(in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None)
Casts float64 entries to float32 for selected keys, and float32 entries back to float64 on the inverse pass.
Depending on whether in_keys or in_keys_inv are provided during construction, the behaviour of the class changes:

- If the keys are provided, those entries and only those entries will be transformed from float64 to float32.
- If the keys are not provided and the object is within an environment register of transforms, the input and output specs that have a dtype set to float64 will be used as in_keys_inv / in_keys respectively.
- If the keys are not provided and the object is used without an environment, the forward / inverse pass will scan the input tensordict for all float64 values and map them to float32 tensors. For large data structures this can impact performance, since the scan is repeated on every call: the keys to be transformed are not cached. Note that, in this case, out_keys (resp. out_keys_inv) cannot be passed, as the order in which the keys are processed cannot be anticipated precisely.
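The difference between explicit keys and the key-less scan can be illustrated with plain torch and a nested dict. This is only a minimal sketch of the idea, not the TorchRL implementation: the helper cast_doubles and its keys argument are hypothetical names introduced for illustration.

```python
import torch


def cast_doubles(data, keys=None):
    """Return a copy of a (possibly nested) dict with float64 tensors cast to float32.

    Hypothetical helper, for illustration only.
    If `keys` is given, only those top-level entries are considered.
    If `keys` is None, every entry is inspected on every call (nothing is cached).
    """
    out = {}
    for key, value in data.items():
        if isinstance(value, dict):
            # Nested containers are only scanned in the key-less ("automatic") mode
            out[key] = cast_doubles(value) if keys is None else value
        elif (keys is None or key in keys) and value.dtype == torch.float64:
            out[key] = value.to(torch.float32)
        else:
            out[key] = value
    return out


data = {
    "obs": torch.ones(1, dtype=torch.double),
    "nested": {"state": torch.zeros(2, dtype=torch.double)},
    "done": torch.zeros(1, dtype=torch.bool),
}

explicit = cast_doubles(data, keys=["obs"])  # only "obs" is cast
automatic = cast_doubles(data)               # every float64 leaf is cast
```

With explicit keys the work is proportional to the number of listed keys, whereas the key-less call has to visit every leaf each time, which is the cost the description above warns about.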
- Parameters:
in_keys (sequence of NestedKey, optional) – list of double keys to be converted to float before being exposed to external objects and functions.
out_keys (sequence of NestedKey, optional) – list of destination keys. Defaults to in_keys if not provided.
in_keys_inv (sequence of NestedKey, optional) – list of float keys to be converted to double before being passed to the contained base_env or storage.
out_keys_inv (sequence of NestedKey, optional) – list of destination keys for inverse transform. Defaults to in_keys_inv if not provided.
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.transforms import DoubleToFloat
>>> td = TensorDict(
...     {'obs': torch.ones(1, dtype=torch.double),
...      'not_transformed': torch.ones(1, dtype=torch.double),
...     }, [])
>>> transform = DoubleToFloat(in_keys=["obs"])
>>> _ = transform(td)
>>> print(td.get("obs").dtype)
torch.float32
>>> print(td.get("not_transformed").dtype)
torch.float64
In “automatic” mode, all float64 entries are transformed:
Examples
>>> td = TensorDict(
...     {'obs': torch.ones(1, dtype=torch.double),
...      'not_transformed': torch.ones(1, dtype=torch.double),
...     }, [])
>>> transform = DoubleToFloat()
>>> _ = transform(td)
>>> print(td.get("obs").dtype)
torch.float32
>>> print(td.get("not_transformed").dtype)
torch.float32
The same behaviour applies when environments are constructed without specifying the transform keys:
Examples
>>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
>>> from torchrl.envs import EnvBase, TransformedEnv
>>> class MyEnv(EnvBase):
...     def __init__(self):
...         super().__init__()
...         self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec((), dtype=torch.float64))
...         self.action_spec = UnboundedContinuousTensorSpec((), dtype=torch.float64)
...         self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float64)
...         self.done_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.bool)
...     def _reset(self, data=None):
...         return TensorDict({"done": torch.zeros((1,), dtype=torch.bool), **self.observation_spec.rand()}, [])
...     def _step(self, data):
...         assert data["action"].dtype == torch.float64
...         reward = self.reward_spec.rand()
...         done = torch.zeros((1,), dtype=torch.bool)
...         obs = self.observation_spec.rand()
...         assert reward.dtype == torch.float64
...         assert obs["obs"].dtype == torch.float64
...         return obs.empty().set("next", obs.update({"reward": reward, "done": done}))
...     def _set_seed(self, seed):
...         pass
>>> env = TransformedEnv(MyEnv(), DoubleToFloat())
>>> assert env.action_spec.dtype == torch.float32
>>> assert env.observation_spec["obs"].dtype == torch.float32
>>> assert env.reward_spec.dtype == torch.float32, env.reward_spec.dtype
>>> print(env.rollout(2))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False),
        obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=cpu,
    is_shared=False)
>>> assert env.transform.in_keys == ["obs", "reward"]
>>> assert env.transform.in_keys_inv == ["action"]