
get_env_transforms_from_module

torchrl.modules.get_env_transforms_from_module(module, init_key='is_init')

Return all TransformedEnv transforms needed for a recurrent module.

Composes an InitTracker (which writes is_init=True at episode resets) with a TensorDictPrimer (which initialises the hidden states). The result can be passed directly to TransformedEnv.

Parameters:
  • module (torch.nn.Module) – A module that may contain recurrent submodules (e.g. LSTMModule or GRUModule).

  • init_key (str, optional) – The key used by InitTracker to mark episode starts. It must match the is_init key expected by the recurrent module. Defaults to "is_init".

Returns:

A Compose of [InitTracker, TensorDictPrimer] when the module contains recurrent submodules, or a bare InitTracker otherwise.

Example

>>> from torchrl.modules import GRUModule
>>> from torchrl.modules.utils import get_env_transforms_from_module
>>> gru = GRUModule(
...     input_size=4, hidden_size=8, num_layers=1,
...     in_keys=["obs", "recurrent_state", "is_init"],
...     out_keys=["features", ("next", "recurrent_state")],
... )
>>> transforms = get_env_transforms_from_module(gru)
>>> # TransformedEnv(base_env, transforms)
