get_env_transforms_from_module
torchrl.modules.get_env_transforms_from_module(module, init_key='is_init')
Return all the TransformedEnv transforms needed for a recurrent module.

It composes InitTracker (which writes is_init=True at episode resets) with TensorDictPrimer (which initialises the hidden states). Pass the result directly to TransformedEnv.

- Parameters:
  - module (torch.nn.Module) – A module that may contain recurrent submodules (e.g. LSTMModule or GRUModule).
  - init_key (str, optional) – The key used by InitTracker to mark episode starts. It must match the is_init key expected by the recurrent module. Defaults to "is_init".
- Returns:
  A Compose of [InitTracker, TensorDictPrimer] when the module contains recurrent submodules, or a bare InitTracker otherwise.
Example
>>> from torchrl.modules import GRUModule
>>> from torchrl.modules.utils import get_env_transforms_from_module
>>> gru = GRUModule(
...     input_size=4, hidden_size=8, num_layers=1,
...     in_keys=["obs", "recurrent_state", "is_init"],
...     out_keys=["features", ("next", "recurrent_state")],
... )
>>> transforms = get_env_transforms_from_module(gru)
>>> # TransformedEnv(base_env, transforms)
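The role of the is_init flag can be illustrated without torchrl. The following is a minimal pure-Python sketch, not torchrl code. The helpers track_init and scan_recurrent are hypothetical. track_init stands in for what InitTracker records. scan_recurrent approximates how a recurrent module is assumed to use that flag when it processes a stored trajectory. The scalar update h = 0.5 * h + obs is a toy replacement for a GRU cell. The initial hidden state of 0.0 plays the part of the value a TensorDictPrimer would supply.

```python
from typing import List

# Hypothetical stand-ins for illustration only (not torchrl APIs).


def track_init(episode_lengths: List[int]) -> List[bool]:
    """Mimic InitTracker: one is_init flag per step, True on each episode's first step."""
    flags: List[bool] = []
    for length in episode_lengths:
        flags.extend([True] + [False] * (length - 1))
    return flags


def scan_recurrent(obs: List[float], is_init: List[bool], h0: float = 0.0) -> List[float]:
    """Toy recurrence h = 0.5 * h + obs that restarts from h0 wherever is_init is True.

    h0 plays the role of the primed initial hidden state (assumed to be zeros).
    """
    h = h0
    outputs: List[float] = []
    for x, first in zip(obs, is_init):
        if first:
            h = h0  # episode boundary: discard the previous episode's state
        h = 0.5 * h + x
        outputs.append(h)
    return outputs


obs = [1.0, 1.0, 1.0, 1.0]
flags = track_init([2, 2])         # two episodes of two steps each
print(flags)                       # [True, False, True, False]
print(scan_recurrent(obs, flags))  # [1.0, 1.5, 1.0, 1.5]
# Without the episode-start flags, the state leaks across the boundary:
print(scan_recurrent(obs, [True, False, False, False]))  # [1.0, 1.5, 1.75, 1.875]
```

Without the flags, the second episode starts from the first episode's final state instead of from the initial hidden state. Marking episode starts and supplying an initial hidden state, the two jobs described above, address exactly this.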