Reward2GoTransform
- class torchrl.envs.transforms.Reward2GoTransform(gamma: float | torch.Tensor | None = 1.0, in_keys: collections.abc.Sequence[tensordict._nestedkey.NestedKey] | None = None, out_keys: collections.abc.Sequence[tensordict._nestedkey.NestedKey] | None = None, done_key: tensordict._nestedkey.NestedKey | None = 'done')
Calculates the reward to go based on the episode reward and a discount factor.
As the Reward2GoTransform is only an inverse transform, the `in_keys` will be used directly for the `in_keys_inv`. The reward-to-go can only be computed once the episode is finished; therefore, the transform should be applied to the replay buffer and not to the collector or within an environment.
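For reference, with discount factor $\gamma$ and per-step rewards $r_t$ in an episode ending at step $T$, the reward-to-go at step $t$ is

$$\hat{R}_t = \sum_{k=t}^{T} \gamma^{\,k-t}\, r_k = r_t + \gamma\, \hat{R}_{t+1},$$

computed backward from the episode end; this matches the values printed in the first example below.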
- Parameters:
  - gamma (`float` or `torch.Tensor`) – the discount factor. Defaults to 1.0.
  - in_keys (sequence of NestedKey) – the reward entries to read. Defaults to `("next", "reward")` if none is provided.
  - out_keys (sequence of NestedKey) – the entries to write the reward-to-go to. Defaults to the values of `in_keys` if none is provided.
  - done_key (NestedKey) – the done entry. Defaults to `"done"`.
  - truncated_key (NestedKey) – the truncated entry. Defaults to `"truncated"`. If no truncated entry is found, only the `"done"` entry will be used.
Examples
>>> # Using this transform as part of a replay buffer
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import ReplayBuffer, LazyTensorStorage
>>> from torchrl.envs.transforms import Reward2GoTransform
>>> torch.manual_seed(0)
>>> r2g = Reward2GoTransform(gamma=0.99, out_keys=["reward_to_go"])
>>> rb = ReplayBuffer(storage=LazyTensorStorage(100), transform=r2g)
>>> batch, timesteps = 4, 5
>>> done = torch.zeros(batch, timesteps, 1, dtype=torch.bool)
>>> for i in range(batch):
...     while not done[i].any():
...         done[i] = done[i].bernoulli_(0.1)
>>> reward = torch.ones(batch, timesteps, 1)
>>> td = TensorDict(
...     {"next": {"done": done, "reward": reward}},
...     [batch, timesteps],
... )
>>> rb.extend(td)
>>> sample = rb.sample(1)
>>> print(sample["next", "reward"])
tensor([[[1.],
         [1.],
         [1.],
         [1.],
         [1.]]])
>>> print(sample["reward_to_go"])
tensor([[[4.9010],
         [3.9404],
         [2.9701],
         [1.9900],
         [1.0000]]])
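For intuition, the printed reward-to-go values follow the backward recursion above; a minimal standalone check, independent of the transform, using the constant reward of 1 from the example:

>>> gamma, T = 0.99, 5
>>> r2g = [0.0] * T
>>> running = 0.0
>>> for t in reversed(range(T)):
...     running = 1.0 + gamma * running  # reward is 1 at every step
...     r2g[t] = round(running, 4)
>>> r2g
[4.901, 3.9404, 2.9701, 1.99, 1.0]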
One can also use this transform directly with a collector: make sure to pass the transform's inv method as the collector's postproc, as in the example below.
Examples
>>> from torchrl.envs.utils import RandomPolicy
>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs.libs.gym import GymEnv
>>> t = Reward2GoTransform(gamma=0.99, out_keys=["reward_to_go"])
>>> env = GymEnv("Pendulum-v1")
>>> collector = SyncDataCollector(
...     env,
...     RandomPolicy(env.action_spec),
...     frames_per_batch=200,
...     total_frames=-1,
...     postproc=t.inv
... )
>>> for data in collector:
...     break
>>> print(data)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        reward_to_go: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([200]),
    device=cpu,
    is_shared=False)
Using this transform as part of an env will raise an exception, since the reward-to-go requires complete episodes, which are not available at step time:
Examples
>>> t = Reward2GoTransform(gamma=0.99)
>>> TransformedEnv(GymEnv("Pendulum-v1"), t)  # crashes
Note
In settings where multiple done entries are present, one should build a separate Reward2GoTransform for each done-reward pair, for instance by composing several instances as sketched below.
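A minimal sketch, assuming hypothetical per-agent `("agent0", ...)` and `("agent1", ...)` entries (the key names are illustrative, not part of the API):

>>> from torchrl.envs.transforms import Compose, Reward2GoTransform
>>> # one transform per done-reward pair; keys below are hypothetical
>>> t = Compose(
...     Reward2GoTransform(
...         gamma=0.99,
...         in_keys=[("next", "agent0", "reward")],
...         out_keys=[("agent0", "reward_to_go")],
...         done_key=("agent0", "done"),
...     ),
...     Reward2GoTransform(
...         gamma=0.99,
...         in_keys=[("next", "agent1", "reward")],
...         out_keys=[("agent1", "reward_to_go")],
...         done_key=("agent1", "done"),
...     ),
... )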
- forward(tensordict: TensorDictBase) → TensorDictBase
Reads the input tensordict, and for the selected keys, applies the transform.
By default, this method:
- calls `_apply_transform()` directly;
- does not call `_step()` or `_call()`.
This method is not called within `env.step` at any point. However, it is called within `sample()`.

Note
`forward` also works with regular keyword arguments, using `dispatch` to cast the argument names to the keys.

Examples
>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes_in_td
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t)  # works within envs
>>> t(TensorDict(a=0))  # works offline too