TensorClassModuleWrapper
- class tensordict.nn.TensorClassModuleWrapper(*args, **kwargs)
Wrapper class for TensorClassModuleBase objects.
This wrapper allows TensorClassModuleBase instances to be used in TensorDict-based workflows by handling the conversion between TensorDict and TensorClass representations. When called with a TensorDict, the wrapper converts it to a TensorClass, passes it through the wrapped module, and converts the output back to a TensorDict.
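The round trip described above can be illustrated with a minimal, framework-free sketch. It assumes that plain dictionaries stand in for TensorDict and dataclasses stand in for TensorClass. The names `DictRoundTripWrapper` and `plus_one`, and the explicit `input_cls` argument, are hypothetical. They do not reflect tensordict's actual implementation, which documents only a `module` parameter.

```python
from dataclasses import asdict, dataclass, fields
from typing import Any, Callable, Dict


@dataclass
class InputTC:
    # Stand-in for a TensorClass with a single field "x"
    x: float


@dataclass
class OutputTC:
    # Stand-in for a TensorClass with a single field "y"
    y: float


def plus_one(inp: InputTC) -> OutputTC:
    # Stand-in for the wrapped module's forward
    return OutputTC(y=inp.x + 1)


class DictRoundTripWrapper:
    """Hypothetical adapter: dict in -> dataclass -> module -> dataclass -> dict out."""

    def __init__(self, module: Callable[[Any], Any], input_cls: type) -> None:
        self.module = module
        self.input_cls = input_cls

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # 1. Build the structured input from the keys that match its fields
        names = {f.name for f in fields(self.input_cls)}
        structured_in = self.input_cls(**{k: v for k, v in data.items() if k in names})
        # 2. Run the wrapped module on the structured object
        structured_out = self.module(structured_in)
        # 3. Flatten the structured output back into a plain dict
        return asdict(structured_out)


wrapper = DictRoundTripWrapper(plus_one, InputTC)
print(wrapper({"x": 0.0}))  # {'y': 1.0}
```

Selecting only the keys that match the input fields is a design choice assumed here. It lets the dictionary carry unrelated entries without breaking the input constructor.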
- Parameters:
module (TensorClassModuleBase) – The TensorClassModuleBase instance to wrap.
Examples
>>> from tensordict import TensorDict
>>> from tensordict.tensorclass import TensorClass
>>> from tensordict.nn import TensorClassModuleBase
>>> import torch
>>>
>>> class InputTC(TensorClass):
...     x: torch.Tensor
...
>>> class OutputTC(TensorClass):
...     y: torch.Tensor
...
>>> class MyModule(TensorClassModuleBase[InputTC, OutputTC]):
...     def forward(self, input: InputTC) -> OutputTC:
...         return OutputTC(y=input.x + 1, batch_size=input.batch_size)
...
>>> module = MyModule()
>>> td_module = module.as_td_module()
>>> td = TensorDict({"x": torch.zeros(3)}, batch_size=[3])
>>> result = td_module(td)
>>> assert "y" in result
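The example parameterizes the base class as `TensorClassModuleBase[InputTC, OutputTC]`. A wrapper could plausibly learn which class to build from an incoming TensorDict by reading those type arguments at runtime. The sketch below assumes that approach and uses standard `typing` machinery (`Generic`, `get_args`, `__orig_bases__`). `TypedModuleBase` and `io_types` are hypothetical names, not tensordict APIs, and dataclasses again stand in for TensorClass.

```python
from dataclasses import dataclass
from typing import Generic, Tuple, TypeVar, get_args

In = TypeVar("In")
Out = TypeVar("Out")


class TypedModuleBase(Generic[In, Out]):
    """Hypothetical stand-in for a base class written as Base[In, Out]."""

    @classmethod
    def io_types(cls) -> Tuple[type, type]:
        # Inspect the parameterized bases recorded when the class was defined
        for base in getattr(cls, "__orig_bases__", ()):
            args = get_args(base)
            # Skip unresolved TypeVars, such as the base's own Generic[In, Out]
            if len(args) == 2 and not any(isinstance(a, TypeVar) for a in args):
                return args
        raise TypeError(f"{cls.__name__} does not declare concrete input/output types")


@dataclass
class InputTC:
    # Stand-in for a TensorClass with field "x"
    x: float


@dataclass
class OutputTC:
    # Stand-in for a TensorClass with field "y"
    y: float


class MyModule(TypedModuleBase[InputTC, OutputTC]):
    def forward(self, inp: InputTC) -> OutputTC:
        return OutputTC(y=inp.x + 1)


in_cls, out_cls = MyModule.io_types()
print(in_cls.__name__, out_cls.__name__)  # InputTC OutputTC
print(MyModule().forward(in_cls(x=0.0)))  # OutputTC(y=1.0)
```

Reading the types from the class definition keeps the wrapper's constructor down to the module alone, which matches the single `module` parameter documented above. Whether tensordict actually does this is not stated here.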
- forward(tensordict: TensorDict = None, args=None, **kwargs) → TensorDict
Forward pass converting TensorDict to TensorClass and back.
- Parameters:
tensordict (TensorDict) – Input tensordict.
*args – Additional positional arguments.
**kwargs – Additional keyword arguments.
- Returns:
Output tensordict.
- Return type: