DecodeVideoTransform
- class torchrl.envs.transforms.DecodeVideoTransform(*, in_keys: Sequence[NestedKey], out_keys: Sequence[NestedKey] | None = None, device: Any = None, dtype: Any = None)
Decodes VideoClipRef leaves to dense frame tensors.
This is a forward (sample-path) transform: it reads the lazy video references found at in_keys and writes the decoded frames, uint8 by default, to out_keys. It is meant to be appended to a ReplayBuffer so that indexing the buffer stays cheap (the storage holds references, not materialized frames) while rb.sample() returns decoded frames aligned with the sampled steps. Because it is a read-side codec, no inverse is defined.
Decoding is delegated to VideoClipRef.decode(), which groups the sampled references by source file and uses ranged reads for contiguous indices. This is what lets the transform compose with SliceSampler: a contiguous window of sampled steps maps to consecutive frame indices and is decoded with a single ranged read per source.
- Keyword Arguments:
  - in_keys (sequence of NestedKey) – the keys holding the VideoClipRef leaves to decode.
  - out_keys (sequence of NestedKey, optional) – destination keys for the decoded frames. Defaults to in_keys (in-place replacement).
  - device (torch.device or str, optional) – device for the decoded frames. A CUDA device enables GPU (NVDEC) decoding. Defaults to None, in which case the reference's out_device is used if set, otherwise CPU.
  - dtype (torch.dtype, optional) – dtype for the decoded frames. Defaults to None, in which case the reference's out_dtype is used if set, otherwise uint8.
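The fallback rules for out_keys, device and dtype can be summarized in a small pure-Python sketch. This is not torchrl code. RefStub and resolve_decode_options are hypothetical names, and plain strings stand in for torch.device and torch.dtype so the logic runs without torch installed. The sketch assumes that a reference exposes optional out_device and out_dtype attributes, as the argument descriptions imply.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple


# Hypothetical stand-in for a VideoClipRef: it keeps only the two attributes
# that the keyword-argument defaults depend on. Strings replace torch types.
@dataclass(frozen=True)
class RefStub:
    out_device: Optional[str] = None
    out_dtype: Optional[str] = None


def resolve_decode_options(
    ref: RefStub,
    in_keys: Sequence[str],
    out_keys: Optional[Sequence[str]] = None,
    device: Optional[str] = None,
    dtype: Optional[str] = None,
) -> Tuple[List[str], str, str]:
    """Apply the documented fallbacks (a sketch, not the torchrl implementation)."""
    # out_keys=None means the decoded frames replace the references in place.
    resolved_out = list(in_keys) if out_keys is None else list(out_keys)
    # An explicit argument wins. Otherwise use the reference's own preference,
    # and fall back to CPU / uint8 when that is unset too.
    resolved_device = device or ref.out_device or "cpu"
    resolved_dtype = dtype or ref.out_dtype or "uint8"
    return resolved_out, resolved_device, resolved_dtype


ref = RefStub(out_device="cuda:0")
print(resolve_decode_options(ref, ["frame"]))
# (['frame'], 'cuda:0', 'uint8')
print(resolve_decode_options(ref, ["frame"], ["pixels"], device="cpu", dtype="float32"))
# (['pixels'], 'cpu', 'float32')
```

The precedence is explicit argument, then reference attribute, then global default. This lets per-clip preferences apply without overriding a device or dtype chosen on the transform.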
Note
This transform requires torchcodec. The lightweight VideoClipRef leaves stored in the buffer are picklable and hold no open decoder; decoders are opened lazily and cached per worker process.
Examples
>>> import tempfile, os, torch
>>> from torchcodec.encoders import VideoEncoder
>>> from tensordict import TensorDict
>>> from torchrl.data import (
...     LazyTensorStorage, ReplayBuffer, SliceSampler, VideoClipRef)
>>> from torchrl.envs.transforms import DecodeVideoTransform
>>> frames = torch.arange(20, dtype=torch.uint8).reshape(20, 1, 1, 1)
>>> frames = frames.expand(20, 3, 8, 8).contiguous()
>>> path = os.path.join(tempfile.mkdtemp(), "clip.mp4")
>>> VideoEncoder(frames=frames, frame_rate=10).to_file(path)
>>> ref = VideoClipRef.from_file(path)  # 20 frames, lazy
>>> data = TensorDict(
...     {"frame": ref, "episode": torch.zeros(20, dtype=torch.long)},
...     batch_size=[20],
... )
>>> rb = ReplayBuffer(
...     storage=LazyTensorStorage(20),
...     sampler=SliceSampler(slice_len=4, traj_key="episode"),
...     batch_size=8,
...     transform=DecodeVideoTransform(in_keys=["frame"], out_keys=["pixels"]),
... )
>>> _ = rb.extend(data)
>>> sample = rb.sample()
>>> sample["pixels"].shape  # decoded on sample
torch.Size([8, 3, 8, 8])
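In the example, SliceSampler with slice_len=4 and batch_size=8 draws two contiguous windows of four steps. Each window should therefore cost one ranged read. The planning step that VideoClipRef.decode() is described as performing groups references by source file and merges consecutive frame indices. It can be sketched in plain Python. The function plan_ranged_reads and the (path, frame_index) tuple layout are hypothetical, and the real method's internals may differ.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def plan_ranged_reads(
    refs: Sequence[Tuple[str, int]],
) -> Dict[str, List[Tuple[int, int]]]:
    """Group (source_path, frame_index) refs by source and merge runs of
    consecutive indices into half-open [start, stop) ranges."""
    by_source: Dict[str, List[int]] = defaultdict(list)
    for path, idx in refs:
        by_source[path].append(idx)

    plan: Dict[str, List[Tuple[int, int]]] = {}
    for path, indices in by_source.items():
        ranges: List[Tuple[int, int]] = []
        for idx in sorted(set(indices)):
            if ranges and idx == ranges[-1][1]:
                # idx directly follows the current run: extend it by one frame.
                ranges[-1] = (ranges[-1][0], idx + 1)
            else:
                ranges.append((idx, idx + 1))  # start a new run
        plan[path] = ranges
    return plan


# Two slices of length 4 from one 20-frame file, as SliceSampler might draw them.
sample = [("clip.mp4", i) for i in [2, 3, 4, 5, 12, 13, 14, 15]]
print(plan_ranged_reads(sample))
# {'clip.mp4': [(2, 6), (12, 16)]}
```

A complete decoder would also scatter each decoded frame back to its position in the sampled batch. That step is omitted here. Two windows that happen to touch collapse into a single read, which is the behavior that makes contiguous sampling cheap.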
See also