TerminateTransform
- class torchrl.envs.transforms.TerminateTransform(stop: Callable[[TensorDictBase], Any], *, write_done: bool = True)
Terminate a rollout when a user-supplied predicate becomes true.
After each environment step, stop(next_tensordict) is evaluated and its boolean result is OR-ed into the environment’s terminated (and, by default, done) entries. Combined with rollout(..., break_when_any_done=True) (the default), this ends the rollout as soon as the goal condition is reached, without writing a bespoke stepping loop. It is the natural companion of the rollout() actions keyword for scripted, goal-terminated replays.
- Parameters:
stop (callable) – a callable taking the post-step ("next") TensorDict and returning a boolean scalar or a boolean tensor broadcastable to the environment’s done entries.
- Keyword Arguments:
write_done (bool, optional) – if True (default), also OR the flag into the done entries so break_when_any_done halts the rollout. Set to False to write only terminated entries.
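The OR-ing behaviour described above can be pictured with a small, library-free sketch. The helper name apply_stop is hypothetical, and plain dicts with Python booleans stand in for a TensorDict with boolean tensors; this illustrates the documented semantics under those assumptions and is not torchrl's implementation.

```python
from typing import Any, Callable, Dict


def apply_stop(
    next_td: Dict[str, Any],
    stop: Callable[[Dict[str, Any]], bool],
    *,
    write_done: bool = True,
) -> Dict[str, Any]:
    """Hypothetical stand-in for the post-step update (not the library code)."""
    flag = bool(stop(next_td))
    out = dict(next_td)  # shallow copy so the input is left untouched
    # OR the flag in, so a termination raised by the env itself is never cleared.
    out["terminated"] = out.get("terminated", False) or flag
    if write_done:
        out["done"] = out.get("done", False) or flag
    return out


def reached_goal(td: Dict[str, Any]) -> bool:
    # Illustrative predicate on a scalar observation.
    return td["observation"] > 0.99


# Predicate true, write_done=True: both entries are set.
hit = apply_stop(
    {"observation": 1.0, "terminated": False, "done": False}, reached_goal
)
# Predicate true, write_done=False: only "terminated" is set.
only_terminated = apply_stop(
    {"observation": 1.0, "terminated": False, "done": False},
    reached_goal,
    write_done=False,
)
# Predicate false: flags already set by the env are preserved.
kept = apply_stop(
    {"observation": 0.0, "terminated": True, "done": True}, reached_goal
)
```

With write_done=False in this sketch, a loop that only watches done would keep stepping even though terminated is set.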
Examples
>>> import torch
>>> from torchrl.envs import GymEnv, TransformedEnv
>>> from torchrl.envs.transforms import TerminateTransform
>>> env = TransformedEnv(
...     GymEnv("Pendulum-v1"),
...     TerminateTransform(lambda td: td["observation"][..., 0] > 0.99),
... )
>>> rollout = env.rollout(200, break_when_any_done=True)
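For comparison, a bespoke stepping loop of the kind this transform is meant to replace might look like the sketch below. ToyEnv and manual_rollout are hypothetical names, and the toy environment works on plain floats rather than TensorDicts, so this only illustrates the control flow.

```python
from typing import Callable, List


class ToyEnv:
    """Hypothetical 1-D environment: each step adds `delta` to the state."""

    def __init__(self, delta: float = 0.25) -> None:
        self.delta = delta
        self.state = 0.0

    def reset(self) -> float:
        self.state = 0.0
        return self.state

    def step(self) -> float:
        self.state += self.delta
        return self.state


def manual_rollout(
    env: ToyEnv, max_steps: int, stop: Callable[[float], bool]
) -> List[float]:
    """Step until `stop` is true on the post-step observation or max_steps is hit."""
    env.reset()
    trajectory: List[float] = []
    for _ in range(max_steps):
        obs = env.step()
        trajectory.append(obs)
        if stop(obs):  # goal reached: end the rollout early
            break
    return trajectory


trajectory = manual_rollout(
    ToyEnv(delta=0.25), max_steps=200, stop=lambda obs: obs > 0.99
)
```

The transform moves the `if stop(obs): break` check into the environment's own done signalling, so the standard rollout call can be used unchanged.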
- forward(tensordict: TensorDictBase) → TensorDictBase
Reads the input tensordict, and for the selected keys, applies the transform.
By default, this method:
- calls _apply_transform() directly.
- does not call _step() or _call().
This method is not called within env.step at any point. However, it is called within sample().
Note
forward also works with regular keyword arguments, using dispatch to map the argument names to the keys.
Examples
>>> from tensordict import TensorDict, TensorDictBase
>>> from torchrl.envs.transforms import Transform
>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         # Store the measured size, not the built-in `bytes` type.
...         tensordict["bytes"] = bytes_in_td
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t) # works within envs
>>> t(TensorDict(a=0)) # Works offline too.
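The default behaviour described for forward (apply the transform to the selected keys by calling _apply_transform, without going through the step path) can be sketched without the library. ToyDoubling is a hypothetical class operating on plain dicts; it mimics the documented behaviour under that assumption and is not the torchrl base class.

```python
from typing import Any, Dict, List


class ToyDoubling:
    """Hypothetical transform: doubles each selected entry of a dict."""

    def __init__(self, in_keys: List[str], out_keys: List[str]) -> None:
        self.in_keys = in_keys
        self.out_keys = out_keys

    def _apply_transform(self, value: Any) -> Any:
        # Per-entry operation; subclasses would override this.
        return value * 2

    def forward(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Read each selected key, transform it, write it under the paired out key.
        for in_key, out_key in zip(self.in_keys, self.out_keys):
            if in_key in data:
                data[out_key] = self._apply_transform(data[in_key])
        return data


t = ToyDoubling(in_keys=["a"], out_keys=["a_doubled"])
out = t.forward({"a": 2, "b": 5})  # "b" is not selected and passes through
```

Because nothing here depends on an environment, the same call works on data sampled offline, which matches the note that forward is used on the sampling path rather than inside env.step.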