TensorClassModuleWrapper

class tensordict.nn.TensorClassModuleWrapper(*args, **kwargs)

Wrapper class for TensorClassModuleBase objects.

This wrapper allows TensorClassModuleBase instances to be used in TensorDict-based workflows by handling the conversion between TensorDict and TensorClass representations. When called with a TensorDict, the wrapper converts it to a TensorClass, passes it through the wrapped module, and converts the output back to a TensorDict.
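The round trip described above is an instance of the adapter pattern. The sketch below illustrates it with the standard library only. It assumes a plain dict stands in for a TensorDict and a dataclass stands in for a TensorClass. `DictToDataclassWrapper`, `my_module`, and the explicit `input_cls` argument are hypothetical and are not part of the tensordict API. The real wrapper also deals with tensors and batch sizes, which this sketch does not model.

```python
from dataclasses import asdict, dataclass, fields
from typing import Any, Callable, Dict, Type


# Hypothetical stand-ins: a dataclass plays the role of a TensorClass,
# and a plain dict plays the role of a TensorDict.
@dataclass
class InputTC:
    x: float


@dataclass
class OutputTC:
    y: float


class DictToDataclassWrapper:
    """Hypothetical adapter mirroring the dict -> dataclass -> dict round trip."""

    def __init__(self, module: Callable[[Any], Any], input_cls: Type) -> None:
        self.module = module
        # Assumption: the input class is passed explicitly in this sketch.
        self.input_cls = input_cls

    def __call__(self, td: Dict[str, Any]) -> Dict[str, Any]:
        # 1. Build the typed input from the dict keys that match its fields.
        names = {f.name for f in fields(self.input_cls)}
        tc_in = self.input_cls(**{k: v for k, v in td.items() if k in names})
        # 2. Run the wrapped module on the typed object.
        tc_out = self.module(tc_in)
        # 3. Convert the typed output back into a dict.
        return asdict(tc_out)


def my_module(inp: InputTC) -> OutputTC:
    return OutputTC(y=inp.x + 1)


wrapped = DictToDataclassWrapper(my_module, InputTC)
result = wrapped({"x": 0.0})
print(result)  # {'y': 1.0}
```

The wrapped callable only ever sees typed objects, while callers only ever see dicts, which is the same separation the wrapper provides between TensorClass-based modules and TensorDict-based workflows.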

Parameters:

module (TensorClassModuleBase) – The TensorClassModuleBase instance to wrap.

Examples

>>> from tensordict import TensorDict
>>> from tensordict.tensorclass import TensorClass
>>> from tensordict.nn import TensorClassModuleBase
>>> import torch
>>>
>>> class InputTC(TensorClass):
...     x: torch.Tensor
...
>>> class OutputTC(TensorClass):
...     y: torch.Tensor
...
>>> class MyModule(TensorClassModuleBase[InputTC, OutputTC]):
...     def forward(self, input: InputTC) -> OutputTC:
...         return OutputTC(y=input.x + 1, batch_size=input.batch_size)
...
>>> module = MyModule()
>>> td_module = module.as_td_module()
>>> td = TensorDict({"x": torch.zeros(3)}, batch_size=[3])
>>> result = td_module(td)
>>> assert "y" in result

forward(tensordict: TensorDict = None, args=None, **kwargs) → TensorDict

Forward pass converting TensorDict to TensorClass and back.

Parameters:
  • tensordict (TensorDict) – Input tensordict.

  • *args – Additional positional arguments.

  • **kwargs – Additional keyword arguments.

Returns:

Output tensordict.

Return type:

TensorDict
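The page documents that forward() accepts extra positional and keyword arguments but does not say what happens to them. The following sketch assumes they are passed unchanged to the wrapped module's forward; that is an assumption, not confirmed behavior. As before, a dict stands in for a TensorDict and a dataclass for a TensorClass. `ScaleModule` and `wrapped_forward` are hypothetical, and the `tensordict=None` default of the real signature is not modeled.

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict


@dataclass
class InputTC:
    x: float


@dataclass
class OutputTC:
    y: float


class ScaleModule:
    # Hypothetical module whose forward takes an extra argument.
    def forward(self, inp: InputTC, scale: float = 1.0) -> OutputTC:
        return OutputTC(y=inp.x * scale)


def wrapped_forward(
    module: ScaleModule, tensordict: Dict[str, Any], *args: Any, **kwargs: Any
) -> Dict[str, Any]:
    # TensorDict -> TensorClass (dict -> dataclass in this sketch).
    tc_in = InputTC(**tensordict)
    # Assumption: extra positional/keyword arguments are passed straight through.
    tc_out = module.forward(tc_in, *args, **kwargs)
    # TensorClass -> TensorDict (dataclass -> dict in this sketch).
    return asdict(tc_out)


m = ScaleModule()
print(wrapped_forward(m, {"x": 2.0}))             # {'y': 2.0}
print(wrapped_forward(m, {"x": 2.0}, 3.0))        # {'y': 6.0}
print(wrapped_forward(m, {"x": 2.0}, scale=0.5))  # {'y': 1.0}
```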