
TensorClassModuleBase

class tensordict.nn.TensorClassModuleBase(*args: Any, **kwargs: Any)

A TensorClassModuleBase is a base class for modules that operate on TensorClass instances.

TensorClassModuleBase subclasses provide a type-safe way to define modules that work with TensorClass inputs and outputs. The class automatically extracts input and output type information from the generic type parameters.
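
The mechanism behind this extraction is not spelled out here. Below is a minimal pure-Python sketch of one common way to do it. It assumes the types are read at subclass-creation time from the parameterized base stored in `__orig_bases__`. TypedModuleBase, A, B and MyModule are hypothetical names, and the sketch leaves out the torch.nn.Module machinery, so it is not tensordict's actual implementation.

```python
from typing import Any, Generic, TypeVar, get_args, get_origin

InputClass = TypeVar("InputClass")
OutputClass = TypeVar("OutputClass")


class TypedModuleBase(Generic[InputClass, OutputClass]):
    """Hypothetical stand-in for TensorClassModuleBase (no torch dependency)."""

    input_type: type
    output_type: type

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # __orig_bases__ keeps the parameterized base, e.g. TypedModuleBase[A, B];
        # read it from cls.__dict__ so only bases declared on this class count.
        for base in cls.__dict__.get("__orig_bases__", ()):
            if get_origin(base) is TypedModuleBase:
                cls.input_type, cls.output_type = get_args(base)
                break

    def __call__(self, x: Any) -> Any:
        # Plain dispatch to forward, loosely mimicking nn.Module.__call__
        return self.forward(x)

    def forward(self, x: Any) -> Any:
        raise NotImplementedError


class A: ...


class B: ...


class MyModule(TypedModuleBase[A, B]):
    def forward(self, x: A) -> B:
        return B()


print(MyModule.input_type.__name__, MyModule.output_type.__name__)  # A B
```

Reading `__orig_bases__` from `cls.__dict__` instead of via `getattr` means that a further subclass such as `class Sub(MyModule)` does not re-parse its parent's bases; it simply inherits `input_type` and `output_type` as class attributes.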

The module can be converted to a TensorDictModule using the as_td_module() method, allowing it to be used in TensorDict-based workflows.

Type Parameters:

  • InputClass – The input type; must be a TensorClass or a Tensor.

  • OutputClass – The output type; must be a TensorClass or a Tensor.

Variables:
  • input_type (type[InputClass]) – The input type class.

  • output_type (type[OutputClass]) – The output type class.

Examples

>>> from tensordict.tensorclass import TensorClass
>>> from tensordict.nn import TensorClassModuleBase
>>> import torch
>>>
>>> class InputTC(TensorClass):
...     a: torch.Tensor
...     b: torch.Tensor
...
>>> class OutputTC(TensorClass):
...     result: torch.Tensor
...
>>> class AddModule(TensorClassModuleBase[InputTC, OutputTC]):
...     def forward(self, x: InputTC) -> OutputTC:
...         return OutputTC(
...             result=x.a + x.b,
...             batch_size=x.batch_size
...         )
...
>>> module = AddModule()
>>> input_tc = InputTC(a=torch.tensor([1.0]), b=torch.tensor([2.0]), batch_size=[1])
>>> output = module(input_tc)
>>> assert output.result == torch.tensor([3.0])

as_td_module() → TensorClassModuleWrapper

Convert this module to a TensorDictModule.

This method wraps the TensorClassModuleBase in a TensorClassModuleWrapper, allowing it to be used with TensorDict inputs and outputs.

Returns:

A wrapper that converts between TensorDict and TensorClass representations.

Return type:

TensorClassModuleWrapper

Raises:

ValueError – If either input_type or output_type is not a TensorClass.

Examples

>>> from tensordict import TensorDict
>>> from tensordict.tensorclass import TensorClass
>>> from tensordict.nn import TensorClassModuleBase
>>> import torch
>>>
>>> class InputTC(TensorClass):
...     x: torch.Tensor
...
>>> class OutputTC(TensorClass):
...     y: torch.Tensor
...
>>> class MyModule(TensorClassModuleBase[InputTC, OutputTC]):
...     def forward(self, input: InputTC) -> OutputTC:
...         return OutputTC(y=input.x * 2, batch_size=input.batch_size)
...
>>> module = MyModule()
>>> td_module = module.as_td_module()
>>> td = TensorDict({"x": torch.ones(3)}, batch_size=[3])
>>> result = td_module(td)
>>> assert (result["y"] == 2).all()
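
To show the round trip without a torch dependency, here is a minimal pure-Python analogue of the wrapping that as_td_module() performs. Dataclasses stand in for TensorClass and a plain dict stands in for TensorDict. DictModuleWrapper, InputDC, OutputDC and double are hypothetical names, not tensordict API. The choice to return a copy of the input mapping with the output fields added is an assumption of this sketch.

```python
from dataclasses import asdict, dataclass, fields, is_dataclass
from typing import Any, Callable, Dict


class DictModuleWrapper:
    """Hypothetical analogue of TensorClassModuleWrapper: dict in, dict out."""

    def __init__(self, module: Callable[[Any], Any], input_type: type, output_type: type) -> None:
        # Mirror the documented ValueError: both ends must be structured types
        if not (is_dataclass(input_type) and is_dataclass(output_type)):
            raise ValueError("input_type and output_type must both be dataclasses")
        self.module = module
        self.input_type = input_type
        self.output_type = output_type
        # Keys the wrapper reads from the mapping, derived from the input fields
        self.in_keys = [f.name for f in fields(input_type)]

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Build the typed input from the matching keys of the mapping
        typed_in = self.input_type(**{k: data[k] for k in self.in_keys})
        typed_out = self.module(typed_in)
        # Write the output fields back into a copy of the mapping
        result = dict(data)
        result.update(asdict(typed_out))
        return result


@dataclass
class InputDC:
    x: float


@dataclass
class OutputDC:
    y: float


def double(inp: InputDC) -> OutputDC:
    return OutputDC(y=inp.x * 2)


wrapped = DictModuleWrapper(double, InputDC, OutputDC)
print(wrapped({"x": 1.5}))  # {'x': 1.5, 'y': 3.0}
```

The input and output keys fall out of the field names of the two typed classes, which is why the wrapper can only be built when both ends are structured types. A bare value type such as float has no fields to map, so the constructor rejects it.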

abstract forward(x: InputClass) → OutputClass

Forward pass of the module.

Parameters:

x (InputClass) – Input instance.

Returns:

Output instance.

Return type:

OutputClass
