canonicalize_rnn_subset
- class torchrl.modules.canonicalize_rnn_subset(data: TensorDictBase, modules: Iterable[LSTMModule | GRUModule], *, inplace: bool = False)
Canonicalize only the union of RNN keys used by modules.
Convenience wrapper around LSTMModule.canonicalize() / GRUModule.canonicalize() for pipelines that feed several recurrent modules from the same TensorDict (e.g. a recurrent actor and a recurrent critic). The union of every module’s canonical_keys is collected and canonicalized once, so a key shared by several modules (such as "obs" in the example below) is processed a single time, and the result is merged back. All other leaves are left untouched.
- Parameters:
data – TensorDict to canonicalize.
modules – Iterable of LSTMModule / GRUModule instances whose canonical_keys define the subset to canonicalize.
inplace – When True, mutates data in place and returns it. Keyword-only. Defaults to False.
- Returns:
A TensorDict with canonical layout on the RNN-relevant leaves.
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.modules import LSTMModule, canonicalize_rnn_subset
>>> actor = LSTMModule(input_size=3, hidden_size=4, in_key="obs",
...                    out_key="actor_h")
>>> critic = LSTMModule(input_size=3, hidden_size=4, in_key="obs",
...                     out_key="critic_h")
>>> td = TensorDict({"obs": torch.zeros(2, 5, 3)}, batch_size=[2, 5])
>>> canonicalize_rnn_subset(td, [actor, critic])["obs"].is_contiguous()
True
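The collect-union, canonicalize-once, merge-back behavior described above can be illustrated without torch or torchrl. The following is a minimal, hypothetical sketch on a plain dict, not the library's implementation: FakeRNNModule, canonicalize_subset_sketch and the canonicalize_leaf callable are illustrative names, and silently skipping keys that are absent from the input is an assumption made only for this sketch.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, Tuple


# Hypothetical stand-in for an LSTMModule/GRUModule: it only carries the
# canonical_keys attribute that this sketch needs.
@dataclass
class FakeRNNModule:
    canonical_keys: Tuple[str, ...]


def canonicalize_subset_sketch(
    data: Dict[str, Any],
    modules: Iterable[FakeRNNModule],
    canonicalize_leaf: Callable[[Any], Any],
    *,
    inplace: bool = False,
) -> Dict[str, Any]:
    """Illustrative re-statement of the documented behavior on a plain dict."""
    # 1. Union of every module's keys: dict.fromkeys drops duplicates while
    #    keeping first-seen order, so a shared key is processed only once.
    keys = list(dict.fromkeys(k for m in modules for k in m.canonical_keys))
    # 2. Canonicalize only that subset (assumption: skip keys not in data).
    subset = {k: canonicalize_leaf(data[k]) for k in keys if k in data}
    # 3. Merge back, either into the input itself or into a shallow copy.
    out = data if inplace else dict(data)
    out.update(subset)
    return out


# Usage: "obs" is shared by both modules but is canonicalized only once;
# "reward" belongs to no module and passes through unchanged.
actor = FakeRNNModule(("obs", "actor_h"))
critic = FakeRNNModule(("obs", "critic_h"))
td = {"obs": 1, "actor_h": 2, "reward": 3}
print(canonicalize_subset_sketch(td, [actor, critic], lambda v: ("canonical", v)))
# {'obs': ('canonical', 1), 'actor_h': ('canonical', 2), 'reward': 3}
```

With inplace=False the sketch leaves the caller's dict untouched and returns a new one; with inplace=True it returns the very object it was given, mirroring the inplace parameter documented above.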