NextStateReconstructor
- class torchrl.envs.transforms.NextStateReconstructor(keys: Sequence[NestedKey] = ('observation',), *, traj_key: NestedKey | None = ('collector', 'traj_ids'), done_key: NestedKey | None = ('next', 'done'), step_count_key: NestedKey | None = None, fill_value: float = nan, strict: bool = True)
Re-hydrate ("next", obs) keys at sampling time by shifting along the batch.

Pairs with SyncDataCollector configured with compact_obs=True (and the analogous flag on the multi-process collectors): the collector drops the observation and state keys from the ("next", ...) sub-tensordict before stacking, because those values are bit-for-bit identical to the root keys at t + 1 within the same trajectory. This transform rebuilds them on the consumer side.

Core rule. For each registered root key k and each position i of the flat sampled batch:
- if position i + 1 is in the batch and belongs to the same trajectory as position i, write data[("next", k)][i] = data[k][i + 1];
- otherwise, write data[("next", k)][i] = fill_value (NaN by default).
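The core rule can be illustrated with a framework-free sketch in plain Python. It assumes the observations and trajectory ids of the sampled batch are given as flat, aligned lists; the helper name reconstruct_next is hypothetical and is not part of TorchRL, whose transform operates on tensordicts.

```python
import math
from typing import List, Optional, Sequence


def reconstruct_next(
    obs: Sequence[float],
    traj_ids: Optional[Sequence[int]],
    fill_value: float = math.nan,
) -> List[float]:
    """Hypothetical helper: shift `obs` by one position along the batch.

    next[i] = obs[i + 1] when position i + 1 exists and shares the trajectory
    id of position i; otherwise next[i] = fill_value. Passing traj_ids=None
    treats the whole batch as a single trajectory.
    """
    n = len(obs)
    out = []
    for i in range(n):
        has_next = i + 1 < n
        # Short-circuit avoids indexing past the end of traj_ids.
        same_traj = traj_ids is None or (has_next and traj_ids[i + 1] == traj_ids[i])
        out.append(obs[i + 1] if has_next and same_traj else fill_value)
    return out


# Two slices of length 4, from trajectories 0 and 1.
obs = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
traj = [0, 0, 0, 0, 1, 1, 1, 1]
print(reconstruct_next(obs, traj))
# [1.0, 2.0, 3.0, nan, 5.0, 6.0, 7.0, nan]
```

The slice boundary (position 3) and the last position (7) have no in-batch successor from the same trajectory, so they receive the fill value.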
“Same trajectory” is decided from a trajectory-id key in the sample, by default ("collector", "traj_ids"), which SyncDataCollector populates when track_traj_ids=True (the default). The semantics fall out cleanly for every common sampler:
- SliceSampler with traj_key: positions inside a slice take the value at the following position; slice boundaries differ in trajectory id and become NaN.
- A full rollout sampled as one contiguous batch: every transition inside a trajectory is reconstructed; trajectory ends become NaN.
- RandomSampler and similar: adjacent batch positions almost never share a trajectory id, so the result is mostly NaN. This is correct (the next observation is genuinely not available in the sampled batch), and it makes misuse loud rather than silent.
The trajectory-id check alone is not enough: a sampler is allowed to place two slices of the same trajectory back-to-back in one batch (e.g. SliceSampler sampling with replacement when there are fewer trajectories than slices). In that case, the two positions on either side of the splice share a trajectory id without being consecutive in time. The transform therefore also consults ("next", "done") (if present): when done[i] is True, the trajectory ended at step i, so position i + 1 is never the next step of trajectory traj_id[i], regardless of its trajectory id.

An additional, stricter step_count_key cross-check is available for setups where neither traj_id nor done is bulletproof; see below.
- Parameters:
keys (sequence of NestedKey, optional) – the root keys whose ("next", k) counterparts should be reconstructed. Defaults to ("observation",). For environments with nested observation specs, pass the full leaf list, e.g. [("agents", "pos"), ("agents", "vel")].
- Keyword Arguments:
traj_key (NestedKey, optional) – key carrying the trajectory id used to detect boundaries. Defaults to ("collector", "traj_ids"). Set to None to skip the trajectory check and treat the entire sampled batch as one trajectory (only the very last position is then filled with fill_value).
done_key (NestedKey, optional) – key whose True entries indicate that the trajectory terminated at position i, so position i + 1 is not the next step. Defaults to ("next", "done"). Set to None to disable the check.
step_count_key (NestedKey, optional) – if not None, also require data[step_count_key][i + 1] == data[step_count_key][i] + 1 to consider position i + 1 as the canonical next step. The collector populates ("collector", "step_count") only when a StepCounter is in the env transform chain. Defaults to None.
fill_value (float, optional) – value written wherever the next observation is not available. Defaults to float("nan"). For integer-typed observation keys, NaN cannot be represented; pass an explicit integer (e.g. 0).
strict (bool, optional) – if True (default) and any configured marker key (traj_key, done_key, step_count_key) is missing from the sampled batch, raise an error. If False, silently drop that check.
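The trajectory-id, done and step-count checks can be combined into a single "is position i + 1 the next step of position i?" mask. The plain-Python sketch below assumes flat per-position lists; next_step_mask and fill_next are hypothetical names, not TorchRL functions. Each check is optional, mirroring how setting traj_key, done_key or step_count_key to None disables it.

```python
import math
from typing import List, Optional, Sequence


def next_step_mask(
    n: int,
    traj_ids: Optional[Sequence[int]] = None,
    done: Optional[Sequence[bool]] = None,
    step_count: Optional[Sequence[int]] = None,
) -> List[bool]:
    """Hypothetical helper: True where position i + 1 is the next step of i."""
    mask = []
    for i in range(n):
        ok = i + 1 < n  # the last position never has a successor in the batch
        if ok and traj_ids is not None:
            ok = traj_ids[i + 1] == traj_ids[i]
        if ok and done is not None:
            ok = not done[i]  # the trajectory ended at step i
        if ok and step_count is not None:
            ok = step_count[i + 1] == step_count[i] + 1
        mask.append(ok)
    return mask


def fill_next(obs: Sequence[float], mask: Sequence[bool], fill_value: float = math.nan) -> List[float]:
    """Copy obs[i + 1] where the mask allows it, else write fill_value."""
    return [obs[i + 1] if m else fill_value for i, m in enumerate(mask)]


# Splice 1: the same 4-step trajectory sampled twice; done marks its end.
traj = [0] * 8
done = [False, False, False, True] * 2
print(next_step_mask(8, traj))        # position 3 is wrongly linked
print(next_step_mask(8, traj, done))  # position 3 is cut by done

# Splice 2: slices at steps 5..8 and 1..4 of one long trajectory (no done).
steps = [5, 6, 7, 8, 1, 2, 3, 4]
print(next_step_mask(8, traj, [False] * 8, steps))  # cut by the step-count check
```

The second splice is the case the step-count cross-check exists for: the trajectory id matches and done is False at the splice, so only the non-consecutive step counts reveal the discontinuity.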
Example
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import ReplayBuffer, LazyTensorStorage
>>> from torchrl.data.replay_buffers.samplers import SliceSampler
>>> from torchrl.envs.transforms.rb_transforms import (
...     NextStateReconstructor,
... )
>>> rb = ReplayBuffer(
...     storage=LazyTensorStorage(100),
...     sampler=SliceSampler(
...         slice_len=4, traj_key=("collector", "traj_ids"),
...     ),
...     transform=NextStateReconstructor(),
...     batch_size=8,
... )
>>> # In practice, populate `rb` from a collector configured with `compact_obs=True`
>>> # so that ``("next", "observation")`` is absent from storage; equivalent data is built by hand here:
>>> data = TensorDict({
...     "observation": torch.arange(8, dtype=torch.float32).view(8, 1),
...     ("next", "reward"): torch.zeros(8, 1),
...     ("next", "done"): torch.tensor([[False]] * 3 + [[True]] + [[False]] * 3 + [[True]]),
...     ("collector", "traj_ids"): torch.tensor([0, 0, 0, 0, 1, 1, 1, 1]),
... }, batch_size=[8])
>>> rb.extend(data)
>>> sample = rb.sample()  # ('next', 'observation') is reconstructed
- forward(tensordict: TensorDictBase) → TensorDictBase
Reads the input tensordict and, for the selected keys, applies the transform.
By default, this method:
- calls _apply_transform() directly;
- does not call _step() or _call().
This method is not called within env.step at any point. However, it is called within sample().
Note
forward also works with regular keyword arguments, using dispatch to cast the argument names to the keys.
Examples
>>> from tensordict import TensorDict, TensorDictBase
>>> from torchrl.envs.transforms import Transform
>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes_in_td  # store the measured size, not the builtin `bytes`
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t)  # works within envs (assumes an existing `env`)
>>> t(TensorDict(a=0))  # Works offline too.