cat_frames
- class torchrl.envs.transforms.functional.cat_frames(tensor: Tensor, N: int, dim: int, *, padding: Literal['same', 'constant'] = 'same', padding_value: float = 0.0, time_dim: int = -1, done_mask: Tensor | None = None)
Stacks a sliding window of N successive frames along dim.

This is the pure, stateless core of the CatFrames transform (the PyTorch F.x / nn.X split): CatFrames delegates its offline / replay-buffer (contiguous trajectory slice) windowing to this function so that the two stay byte-for-byte identical.

For every position t along time_dim, the N frames [t - N + 1, ..., t] are concatenated along dim. The first N - 1 positions of a trajectory have fewer than N real frames; the missing frames are filled according to padding. This matches the offline behavior of CatFrames; see the “Examples” of that class for the online (stateful, per-step) usage.

Frame stacking was first proposed in “Playing Atari with Deep Reinforcement Learning” (https://arxiv.org/abs/1312.5602).
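The windowing rule above can be pictured with a small dependency-free sketch. It is not the library's implementation: window_reference is a hypothetical helper, plain nested lists stand in for tensors, the time axis is the outer list, and the concatenation axis is the inner list. It only illustrates which source frame fills each slot of the window for a single trajectory.

```python
from typing import List, Sequence

Frame = List[float]


def window_reference(
    frames: Sequence[Frame],
    n: int,
    padding: str = "same",
    padding_value: float = 0.0,
) -> List[Frame]:
    """Hypothetical reference for a single trajectory (lists, not tensors).

    Row t of the result is frames[t - n + 1], ..., frames[t] joined end to end,
    oldest frame first.
    """
    if padding not in ("same", "constant"):
        raise ValueError(f"unknown padding: {padding!r}")
    out: List[Frame] = []
    for t in range(len(frames)):
        row: Frame = []
        for src in range(t - n + 1, t + 1):
            if src >= 0:
                row.extend(frames[src])  # a real frame of the trajectory
            elif padding == "same":
                row.extend(frames[0])  # repeat the first real frame
            else:
                row.extend([padding_value] * len(frames[0]))  # constant fill
        out.append(row)
    return out


frames = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]
constant = window_reference(frames, n=3, padding="constant")
same = window_reference(frames, n=3, padding="same")
```

With these four frames and n=3, only the first two rows differ between the two padding modes; from the third row on, every slot holds a real frame.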
- Parameters:
tensor (torch.Tensor) – the frames to stack. One of its dimensions (time_dim) is the time axis along which the sliding window moves; dim is the (channel/feature) axis along which the N frames are concatenated.
N (int) – number of successive frames to concatenate.
dim (int) – the dimension along which the frames are concatenated. Must be negative so that it is invariant to leading batch dimensions. The size of tensor along dim is multiplied by N in the output.
- Keyword Arguments:
padding (str, optional) – the padding method, one of "same" or "constant". With "same" (default) the first real frame of the trajectory is repeated; with "constant" the missing frames are filled with padding_value.
padding_value (float, optional) – the value used to pad when padding="constant". Defaults to 0.
time_dim (int, optional) – the dimension of tensor that holds the time axis. Must be negative. Defaults to -1.
done_mask (torch.Tensor, optional) – an optional boolean mask flagging, for each sliding window, which of its N positions reach across a trajectory boundary (and must therefore be padded). Its shape is (*batch, time, N) where time matches the size of tensor along time_dim. When None (default), the input is treated as a single trajectory and only the leading N - 1 start-of-sequence frames are padded. CatFrames builds this mask from the environment done signal.
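The (time, N) layout of done_mask can be illustrated with a small sketch. How CatFrames actually derives the mask is not spelled out here, so this rests on two loudly stated assumptions: done[s] marks the last frame of a trajectory, and mask slot j corresponds to source frame t - N + 1 + j (oldest first). boundary_mask is a hypothetical helper and plain lists stand in for boolean tensors.

```python
from typing import List, Sequence


def boundary_mask(done: Sequence[bool], n: int) -> List[List[bool]]:
    """Hypothetical helper: mask[t][j] is True when slot j of window t needs padding.

    ASSUMPTIONS (not confirmed by the CatFrames docs):
      * done[s] is True on the LAST frame of a trajectory;
      * slot j maps to source frame t - n + 1 + j (oldest first).
    """
    mask: List[List[bool]] = []
    for t in range(len(done)):
        row: List[bool] = []
        for j in range(n):
            src = t - n + 1 + j
            before_start = src < 0  # window reaches before the first frame
            # A done flag anywhere in [src, t) puts src in an earlier trajectory.
            crosses = any(done[max(src, 0):t])
            row.append(before_start or crosses)
        mask.append(row)
    return mask


# Trajectory A covers frames 0-2, trajectory B covers frames 3-4.
done = [False, False, True, False, False]
mask = boundary_mask(done, n=3)
```

Under these assumptions the window at t=3 (the first frame of trajectory B) has its two oldest slots flagged, the same pattern as the start-of-sequence window at t=0.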
- Returns:
a tensor with the same shape as tensor except that its size along dim is multiplied by N (the concatenated window). The dtype and device of tensor are preserved.
- Return type:
Examples
>>> import torch
>>> from torchrl.envs.transforms.functional import cat_frames
>>> # a single trajectory of 4 frames, each a length-2 feature vector,
>>> # stacked over a window of N=3 along the feature dim (-1).
>>> frames = torch.arange(8.0).view(4, 2)
>>> frames
tensor([[0., 1.],
        [2., 3.],
        [4., 5.],
        [6., 7.]])
>>> out = cat_frames(frames, N=3, dim=-1, time_dim=-2, padding="constant")
>>> out.shape
torch.Size([4, 6])
>>> out
tensor([[0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 2., 3.],
        [0., 1., 2., 3., 4., 5.],
        [2., 3., 4., 5., 6., 7.]])
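For contrast with the offline windowing in the example above, the online per-step case has to remember earlier frames between calls. The sketch below is a rough illustration of that statefulness, not the CatFrames implementation: RollingStacker is a hypothetical class, plain lists stand in for tensors, and only "same" padding is modelled.

```python
from collections import deque
from typing import Deque, List

Frame = List[float]


class RollingStacker:
    """Hypothetical per-step frame stacker; NOT the CatFrames implementation."""

    def __init__(self, n: int) -> None:
        self.n = n
        self.buffer: Deque[Frame] = deque(maxlen=n)

    def reset(self) -> None:
        """Forget earlier frames, e.g. at a trajectory boundary."""
        self.buffer.clear()

    def step(self, frame: Frame) -> Frame:
        if not self.buffer:
            # "same" padding: pre-fill the window with the first real frame
            self.buffer.extend([frame] * self.n)
        else:
            self.buffer.append(frame)  # maxlen silently drops the oldest frame
        # Join the buffered frames end to end, oldest first
        return [x for f in self.buffer for x in f]


stacker = RollingStacker(n=3)
stream = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]
online = [stacker.step(f) for f in stream]
```

Fed one frame at a time, this sketch yields the same rows an offline "same"-padded window would produce over the whole trajectory, but only because the buffer carries state from one call to the next.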
Note
This functional covers the offline (contiguous trajectory slice) windowing used by CatFrames. The transform’s online path (per-step() buffer accumulation) is inherently stateful and is not expressed as a pure function.
See also