torchrl.trainers.algorithms.configs.modules.TensorDictModuleConfig
- class torchrl.trainers.algorithms.configs.modules.TensorDictModuleConfig(_partial_: bool = False, in_keys: Optional[Any] = None, out_keys: Optional[Any] = None, module: MLPConfig = '???', _target_: str = 'tensordict.nn.TensorDictModule')
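A configuration class for building a tensordict.nn.TensorDictModule. When passed to Hydra's instantiate, it builds the network described by module (an MLPConfig) and wraps it so that it reads its inputs from in_keys and writes its outputs to out_keys. The default '???' for module is OmegaConf's marker for a mandatory value, so module must always be supplied.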
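The _target_ and _partial_ fields in the signature follow Hydra's structured-config convention: _target_ is the dotted import path of the callable to build, nested configs are built first, and _partial_=True yields a functools.partial instead of an object. The sketch below is a toy, stdlib-only illustration of that resolution logic, not Hydra's actual implementation. The names toy_instantiate, locate, InnerConfig and WrapperConfig are hypothetical, and datetime.timedelta and types.SimpleNamespace are assumed stand-ins for the real MLP and TensorDictModule targets.

```python
import functools
import importlib
from dataclasses import dataclass, fields, is_dataclass
from typing import Any


def locate(path: str) -> Any:
    """Resolve a dotted path such as 'datetime.timedelta' to the object it names."""
    module_path, _, attr_name = path.rpartition(".")
    return getattr(importlib.import_module(module_path), attr_name)


def toy_instantiate(cfg: Any) -> Any:
    """Hypothetical, simplified stand-in for hydra.utils.instantiate.

    Nested configs are built first (depth-first); then the `_target_` callable
    is called with the remaining fields as keyword arguments. If `_partial_` is
    True, a functools.partial is returned instead of calling the target.
    """
    if not (is_dataclass(cfg) and hasattr(cfg, "_target_")):
        return cfg  # plain values (ints, strings, lists) pass through unchanged
    kwargs = {
        f.name: toy_instantiate(getattr(cfg, f.name))
        for f in fields(cfg)
        if f.name not in ("_target_", "_partial_")
    }
    target = locate(cfg._target_)
    if getattr(cfg, "_partial_", False):
        return functools.partial(target, **kwargs)
    return target(**kwargs)


@dataclass
class InnerConfig:
    # Hypothetical stand-in for MLPConfig: its target builds the wrapped object.
    days: int = 1
    hours: int = 2
    _target_: str = "datetime.timedelta"


@dataclass
class WrapperConfig:
    # Hypothetical stand-in for TensorDictModuleConfig: same field names, toy target.
    _partial_: bool = False
    in_keys: Any = None
    out_keys: Any = None
    module: Any = None
    _target_: str = "types.SimpleNamespace"


cfg = WrapperConfig(module=InnerConfig(), in_keys=["observation"], out_keys=["action"])
obj = toy_instantiate(cfg)
print(obj.module)   # 1 day, 2:00:00
print(obj.in_keys)  # ['observation']
```

With _partial_=True, the toy returns a factory that can be completed later, for example make = toy_instantiate(WrapperConfig(_partial_=True, module=InnerConfig())) followed by make(out_keys=["action"]).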
Example
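>>> import torch
>>> from hydra.utils import instantiate
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.trainers.algorithms.configs.modules import MLPConfig, TensorDictModuleConfig
>>> cfg = TensorDictModuleConfig(
...     module=MLPConfig(in_features=10, out_features=10, depth=2, num_cells=32),
...     in_keys=["observation"],
...     out_keys=["action"],
... )
>>> module = instantiate(cfg)
>>> assert isinstance(module, TensorDictModule)
>>> assert module(observation=torch.randn(10, 10)).shape == (10, 10)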
See also
tensordict.nn.TensorDictModule