
tensordict.nn package

The tensordict.nn package makes it possible to flexibly use TensorDict within ML pipelines.

Since TensorDict turns parts of one’s code into a key-based structure, it becomes possible to build complex graph structures using these keys as hooks. The basic building block is TensorDictModule, which wraps a torch.nn.Module instance with a list of input and output keys:

>>> from torch.nn import Transformer
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> import torch
>>> module = TensorDictModule(Transformer(), in_keys=["feature", "target"], out_keys=["prediction"])
>>> data = TensorDict({"feature": torch.randn(10, 11, 512), "target": torch.randn(10, 11, 512)}, [10, 11])
>>> data = module(data)
>>> print(data)
TensorDict(
    fields={
        feature: Tensor(torch.Size([10, 11, 512]), dtype=torch.float32),
        prediction: Tensor(torch.Size([10, 11, 512]), dtype=torch.float32),
        target: Tensor(torch.Size([10, 11, 512]), dtype=torch.float32)},
    batch_size=torch.Size([10, 11]),
    device=None,
    is_shared=False)

Using TensorDictModule is not strictly necessary: a custom torch.nn.Module with ordered lists of input and output keys (named module.in_keys and module.out_keys) will suffice.
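The key contract that TensorDictModule and such custom modules rely on can be sketched without PyTorch. In the pure-Python illustration below, a plain dict stands in for a TensorDict and a lambda stands in for a torch.nn.Module. The class name KeyedModule is hypothetical. It only mirrors the read, call and write pattern described above and is not tensordict's implementation.

```python
from typing import Any, Callable, Dict, List


class KeyedModule:
    """Hypothetical stand-in for a module exposing in_keys and out_keys."""

    def __init__(self, fn: Callable[..., Any], in_keys: List[str], out_keys: List[str]):
        self.fn = fn
        self.in_keys = in_keys  # read, in this order, as positional arguments
        self.out_keys = out_keys  # written, in this order, from the return value(s)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        args = [data[key] for key in self.in_keys]
        outputs = self.fn(*args)
        if len(self.out_keys) == 1:
            outputs = (outputs,)  # treat a single return value as a 1-tuple
        for key, value in zip(self.out_keys, outputs):
            data[key] = value  # results are written back into the same container
        return data


add_and_sub = KeyedModule(
    lambda a, b: (a + b, a - b), in_keys=["a", "b"], out_keys=["sum", "difference"]
)
data = add_and_sub({"a": 5, "b": 3})
print(data)  # {'a': 5, 'b': 3, 'sum': 8, 'difference': 2}
```

Because outputs are written into the container that was passed in, any later module can read them by key.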

A common pain point for PyTorch users is that nn.Sequential cannot handle modules with multiple inputs. Key-based graphs solve that problem easily, as each node in the sequence knows which data it needs to read and where to write its output.

For this purpose, we provide the TensorDictSequential class, which passes data through a sequence of TensorDictModules. Each module in the sequence takes its input from, and writes its output to, the original TensorDict. Modules in the sequence can therefore ignore the output of their predecessors or read additional input from the tensordict as necessary. Here’s an example:

>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential
>>> class Net(nn.Module):
...     def __init__(self, input_size=100, hidden_size=50, output_size=10):
...         super().__init__()
...         self.fc1 = nn.Linear(input_size, hidden_size)
...         self.fc2 = nn.Linear(hidden_size, output_size)
...
...     def forward(self, x):
...         x = torch.relu(self.fc1(x))
...         return self.fc2(x)
...
>>> class Masker(nn.Module):
...     def forward(self, x, mask):
...         return torch.softmax(x * mask, dim=1)
...
>>> net = TensorDictModule(
...     Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")]
... )
>>> masker = TensorDictModule(
...     Masker(),
...     in_keys=[("intermediate", "x"), ("input", "mask")],
...     out_keys=[("output", "probabilities")],
... )
>>> module = TensorDictSequential(net, masker)
>>>
>>> td = TensorDict(
...     {
...         "input": TensorDict(
...             {"x": torch.rand(32, 100), "mask": torch.randint(2, size=(32, 10))},
...             batch_size=[32],
...         )
...     },
...     batch_size=[32],
... )
>>> td = module(td)
>>> print(td)
TensorDict(
    fields={
        input: TensorDict(
            fields={
                mask: Tensor(torch.Size([32, 10]), dtype=torch.int64),
                x: Tensor(torch.Size([32, 100]), dtype=torch.float32)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False),
        intermediate: TensorDict(
            fields={
                x: Tensor(torch.Size([32, 10]), dtype=torch.float32)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False),
        output: TensorDict(
            fields={
                probabilities: Tensor(torch.Size([32, 10]), dtype=torch.float32)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([32]),
    device=None,
    is_shared=False)
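The shared-container behaviour of TensorDictSequential can be imitated with nested dicts and tuple keys. The pure-Python sketch below assumes nothing about tensordict internals. KeyedNode, KeyedSequential, get_nested and set_nested are hypothetical names. Each node reads from and writes to the same nested dict, so the second node can pull ("input", "mask") directly, bypassing the first node.

```python
from typing import Any, Callable, Dict, List, Tuple, Union

Key = Union[str, Tuple[str, ...]]


def get_nested(data: Dict[str, Any], key: Key) -> Any:
    # A tuple key such as ("input", "x") walks down nested dicts
    for part in (key,) if isinstance(key, str) else key:
        data = data[part]
    return data


def set_nested(data: Dict[str, Any], key: Key, value: Any) -> None:
    parts = (key,) if isinstance(key, str) else key
    for part in parts[:-1]:
        data = data.setdefault(part, {})  # create sub-dicts on demand
    data[parts[-1]] = value


class KeyedNode:
    def __init__(self, fn: Callable[..., Any], in_keys: List[Key], out_keys: List[Key]):
        self.fn, self.in_keys, self.out_keys = fn, in_keys, out_keys

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        outputs = self.fn(*(get_nested(data, k) for k in self.in_keys))
        if len(self.out_keys) == 1:
            outputs = (outputs,)
        for key, value in zip(self.out_keys, outputs):
            set_nested(data, key, value)
        return data


class KeyedSequential:
    def __init__(self, *nodes: KeyedNode):
        self.nodes = nodes

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        for node in self.nodes:
            data = node(data)  # every node reads from and writes to the shared dict
        return data


double = KeyedNode(lambda x: [2 * v for v in x], [("input", "x")], [("intermediate", "x")])
mask = KeyedNode(
    lambda x, m: [v * mv for v, mv in zip(x, m)],
    [("intermediate", "x"), ("input", "mask")],  # second input comes straight from "input"
    [("output", "masked")],
)
result = KeyedSequential(double, mask)({"input": {"x": [1, 2, 3], "mask": [1, 0, 1]}})
print(result["output"]["masked"])  # [2, 0, 6]
```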

We can also select sub-graphs easily through the select_subsequence() method:

>>> sub_module = module.select_subsequence(out_keys=[("intermediate", "x")])
>>> td = TensorDict(
...     {
...         "input": TensorDict(
...             {"x": torch.rand(32, 100), "mask": torch.randint(2, size=(32, 10))},
...             batch_size=[32],
...         )
...     },
...     batch_size=[32],
... )
>>> td = sub_module(td)
>>> print(td)  # the "output" has not been computed
TensorDict(
    fields={
        input: TensorDict(
            fields={
                mask: Tensor(torch.Size([32, 10]), dtype=torch.int64),
                x: Tensor(torch.Size([32, 100]), dtype=torch.float32)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False),
        intermediate: TensorDict(
            fields={
                x: Tensor(torch.Size([32, 10]), dtype=torch.float32)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([32]),
    device=None,
    is_shared=False)
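select_subsequence() has to work out which modules contribute to the requested keys. The pure-Python sketch below shows one plausible way to do this, assuming only the key lists matter. It walks the sequence backwards from the requested out_keys and keeps every module that produces a still-needed key. It then treats that module's outputs as satisfied and its inputs as newly needed. Node and select_by_out_keys are hypothetical names, and tensordict's own implementation may differ (it also accepts in_keys, which are not modelled here).

```python
from typing import List, Set, Tuple, Union

Key = Union[str, Tuple[str, ...]]


class Node:
    """Hypothetical module description: only its keys matter for selection."""

    def __init__(self, name: str, in_keys: List[Key], out_keys: List[Key]):
        self.name, self.in_keys, self.out_keys = name, in_keys, out_keys


def select_by_out_keys(nodes: List[Node], out_keys: List[Key]) -> List[Node]:
    needed: Set[Key] = set(out_keys)
    kept: List[Node] = []
    for node in reversed(nodes):
        if needed.intersection(node.out_keys):
            kept.append(node)
            needed.difference_update(node.out_keys)  # now produced by a kept node
            needed.update(node.in_keys)  # its inputs must be produced upstream
    return list(reversed(kept))  # restore execution order


net = Node("net", [("input", "x")], [("intermediate", "x")])
masker = Node(
    "masker", [("intermediate", "x"), ("input", "mask")], [("output", "probabilities")]
)
print([n.name for n in select_by_out_keys([net, masker], [("intermediate", "x")])])
# ['net']
```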

Finally, tensordict.nn comes with a ProbabilisticTensorDictModule that builds distributions from network outputs and retrieves summary statistics or samples from them (along with the distribution parameters):

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import (
...     ProbabilisticTensorDictModule,
...     ProbabilisticTensorDictSequential,
...     TensorDictModule,
... )
>>> from tensordict.nn.distributions import NormalParamExtractor
>>> from torch.distributions import Normal
>>> td = TensorDict(
...     {"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3]
... )
>>> # GRUCell takes two inputs, so it is wrapped on its own rather than inside nn.Sequential
>>> gru = TensorDictModule(
...     torch.nn.GRUCell(4, 8), in_keys=["input", "hidden"], out_keys=["hidden"]
... )
>>> # NormalParamExtractor splits the 8 hidden features into a 4-dim loc and a 4-dim scale
>>> extractor = TensorDictModule(
...     NormalParamExtractor(), in_keys=["hidden"], out_keys=["loc", "scale"]
... )
>>> prob_module = ProbabilisticTensorDictModule(
...     in_keys=["loc", "scale"],
...     out_keys=["sample"],
...     distribution_class=Normal,
...     return_log_prob=True,
... )
>>> td_module = ProbabilisticTensorDictSequential(gru, extractor, prob_module)
>>> td = td_module(td)
>>> print(td)
TensorDict(
    fields={
        hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32),
        input: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        loc: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        sample: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        sample_log_prob: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
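The last step of that pipeline reads "loc" and "scale", draws a "sample", and stores its log-density under "sample_log_prob". The pure-Python sketch below reproduces that step for scalar values using the closed-form Normal log-density. The mode strings "random" and "deterministic" are hypothetical and only loosely echo the interaction types listed further down. The deterministic branch returns the mean of the Normal distribution.

```python
import math
import random
from typing import Any, Dict, Optional


def normal_log_prob(x: float, loc: float, scale: float) -> float:
    # log N(x; loc, scale) = -(x - loc)^2 / (2 scale^2) - log(scale) - log(2 pi) / 2
    return (
        -((x - loc) ** 2) / (2 * scale**2)
        - math.log(scale)
        - 0.5 * math.log(2 * math.pi)
    )


def probabilistic_step(
    data: Dict[str, Any], mode: str = "random", rng: Optional[random.Random] = None
) -> Dict[str, Any]:
    loc, scale = data["loc"], data["scale"]
    if mode == "deterministic":
        sample = loc  # the mean of a Normal distribution
    else:
        sample = (rng or random.Random()).gauss(loc, scale)  # a random draw
    data["sample"] = sample
    data["sample_log_prob"] = normal_log_prob(sample, loc, scale)
    return data


out = probabilistic_step({"loc": 0.0, "scale": 1.0}, mode="deterministic")
print(out["sample"], round(out["sample_log_prob"], 4))  # 0.0 -0.9189
```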

Type-Safe TensorClass Modules

TensorClassModuleBase provides a type-safe way to define modules that work with TensorClass inputs and outputs. This enables static type checking and improves code clarity compared to working with string-based keys.

A TensorClassModuleBase subclass specifies its input and output types through generic type parameters. The module can be converted to work with TensorDict objects using the as_td_module() method, which returns a TensorClassModuleWrapper:

>>> import torch
>>> from tensordict.tensorclass import TensorClass
>>> from tensordict.nn import TensorClassModuleBase
>>> from tensordict import TensorDict
>>>
>>> # Define input and output TensorClass types
>>> class InputTC(TensorClass):
...     a: torch.Tensor
...     b: torch.Tensor
...
>>> class OutputTC(TensorClass):
...     sum: torch.Tensor
...     difference: torch.Tensor
...
>>> # Create a type-safe module
>>> class MyModule(TensorClassModuleBase[InputTC, OutputTC]):
...     def forward(self, x: InputTC) -> OutputTC:
...         return OutputTC(
...             sum=x.a + x.b,
...             difference=x.a - x.b,
...             batch_size=x.batch_size
...         )
...
>>> # Use with TensorClass
>>> module = MyModule()
>>> input_tc = InputTC(a=torch.tensor([1.0, 2.0]), b=torch.tensor([3.0, 4.0]), batch_size=[2])
>>> output = module(input_tc)
>>> print(output.sum)
tensor([4., 6.])
>>> print(output.difference)
tensor([-2., -2.])
>>>
>>> # Convert to TensorDictModule for use in TensorDict workflows
>>> td_module = module.as_td_module()
>>> td = TensorDict({"a": torch.tensor([1.0, 2.0]), "b": torch.tensor([3.0, 4.0])}, batch_size=[2])
>>> result = td_module(td)
>>> print(result)
TensorDict(
    fields={
        a: Tensor(torch.Size([2]), dtype=torch.float32),
        b: Tensor(torch.Size([2]), dtype=torch.float32),
        difference: Tensor(torch.Size([2]), dtype=torch.float32),
        sum: Tensor(torch.Size([2]), dtype=torch.float32)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

The type-safe approach offers several benefits:

  • Type checking: IDEs and type checkers can verify correct usage at development time

  • Self-documenting: The input and output structure is clear from the type signature

  • Refactoring: type checkers flag code that still uses a field after it has been renamed in a TensorClass definition

  • Nested structures: Support for nested TensorClass types with automatic key extraction

TensorClassModuleBase modules can be composed and used in TensorDictSequential after conversion via as_td_module().
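The "automatic key extraction" mentioned above amounts to deriving keys from typed field declarations. The sketch below uses standard-library dataclasses as a stand-in for TensorClass, and extract_keys is a hypothetical helper. Top-level fields become string keys, and fields of nested structures become tuple keys, matching the key forms used with TensorDictModule earlier on this page. The conversion actually performed by as_td_module() may differ.

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import List, Tuple, Union

Key = Union[str, Tuple[str, ...]]


def extract_keys(cls: type, prefix: Tuple[str, ...] = ()) -> List[Key]:
    keys: List[Key] = []
    for f in fields(cls):
        path = prefix + (f.name,)
        if is_dataclass(f.type):
            keys.extend(extract_keys(f.type, path))  # recurse into nested structures
        else:
            keys.append(path[0] if len(path) == 1 else path)  # flat key at top level
    return keys


@dataclass
class Obs:
    position: float
    velocity: float


@dataclass
class InputSpec:
    obs: Obs
    reward: float


print(extract_keys(InputSpec))  # [('obs', 'position'), ('obs', 'velocity'), 'reward']
```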

TensorDictModuleBase(*args, **kwargs)

Base class to TensorDict modules.

TensorDictModule(*args, **kwargs)

A TensorDictModule is a Python wrapper around an nn.Module that reads from and writes to a TensorDict.

TensorClassModuleBase(*args, **kwargs)

A TensorClassModuleBase is a base class for modules that operate on TensorClass instances.

TensorClassModuleWrapper(*args, **kwargs)

Wrapper class for TensorClassModuleBase objects.

ProbabilisticTensorDictModule(*args, **kwargs)

A probabilistic TD Module.

ProbabilisticTensorDictSequential(*args, ...)

A sequence of TensorDictModules containing at least one ProbabilisticTensorDictModule.

TensorDictSequential(*args, **kwargs)

A sequence of TensorDictModules.

TensorDictModuleWrapper(*args, **kwargs)

Wrapper class for TensorDictModule objects.

CudaGraphModule(module[, warmup, in_keys, ...])

A cudagraph wrapper for PyTorch callables.

WrapModule(*args, **kwargs)

A wrapper around any callable that processes TensorDict instances.

InteractionType(value)

A list of possible interaction types with a distribution.

set_interaction_type([type])

Sets the sampling type of all ProbabilisticTDModules to the desired type.

set_composite_lp_aggregate([mode])

Controls whether CompositeDistribution log-probabilities and entropies will be aggregated in a single tensor.

composite_lp_aggregate([nowarn])

Returns whether CompositeDistribution log-probabilities and entropies will be aggregated in a single tensor.

as_tensordict_module(*, in_keys, out_keys)

A decorator that converts a function into a TensorDictModule.
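The decorator pattern behind as_tensordict_module can be sketched in pure Python. In the sketch below, keyed is a hypothetical decorator. Its keyword-only in_keys and out_keys arguments mirror the signature listed above, but it returns a plain dict-to-dict callable rather than a TensorDictModule. The wrapped function keeps its name, and the key contract is exposed as attributes.

```python
import functools
from typing import Any, Callable, Dict, List


def keyed(*, in_keys: List[str], out_keys: List[str]) -> Callable:
    """Hypothetical decorator turning a plain function into a dict-to-dict callable."""

    def decorator(fn: Callable[..., Any]) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
        @functools.wraps(fn)
        def wrapper(data: Dict[str, Any]) -> Dict[str, Any]:
            result = fn(*(data[k] for k in in_keys))
            results = result if len(out_keys) > 1 else (result,)
            data.update(zip(out_keys, results))  # write outputs under their keys
            return data

        wrapper.in_keys, wrapper.out_keys = in_keys, out_keys  # expose the key contract
        return wrapper

    return decorator


@keyed(in_keys=["x", "y"], out_keys=["z"])
def add(x, y):
    return x + y


print(add({"x": 2, "y": 3}))  # {'x': 2, 'y': 3, 'z': 5}
print(add.in_keys, add.out_keys)  # ['x', 'y'] ['z']
```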

Ensembles

The functional approach enables a straightforward ensemble implementation. We can duplicate and reinitialize model copies using tensordict.nn.EnsembleModule:

>>> import torch
>>> from torch import nn
>>> from tensordict.nn import TensorDictModule
>>> from tensordict.nn import EnsembleModule
>>> from tensordict import TensorDict
>>> net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
>>> mod = TensorDictModule(net, in_keys=['a'], out_keys=['b'])
>>> ensemble = EnsembleModule(mod, num_copies=3)
>>> data = TensorDict({'a': torch.randn(10, 4)}, batch_size=[10])
>>> ensemble(data)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([3, 10, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 10]),
    device=None,
    is_shared=False)

EnsembleModule(*args, **kwargs)

Module that wraps a module and repeats it to form an ensemble.
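In the example above, the ensemble output gains a leading dimension of size num_copies (batch size [3, 10]). The pure-Python sketch below mimics only that outcome. It builds several independently initialized copies of a toy model and runs each copy on the same batch, with a plain loop standing in for the vectorized execution a tensor library would use. ToyLinear, make_ensemble and run_ensemble are hypothetical names.

```python
import random
from typing import List


class ToyLinear:
    """Hypothetical model computing y = w . x + b for one input vector."""

    def __init__(self, in_features: int, rng: random.Random):
        # Fresh random parameters for every instance, i.e. "reinitialized" copies
        self.w = [rng.uniform(-1.0, 1.0) for _ in range(in_features)]
        self.b = rng.uniform(-1.0, 1.0)

    def __call__(self, x: List[float]) -> float:
        return sum(wi * xi for wi, xi in zip(self.w, x)) + self.b


def make_ensemble(in_features: int, num_copies: int, seed: int = 0) -> List[ToyLinear]:
    rng = random.Random(seed)  # seeded so the example is reproducible
    return [ToyLinear(in_features, rng) for _ in range(num_copies)]


def run_ensemble(ensemble: List[ToyLinear], batch: List[List[float]]) -> List[List[float]]:
    # Every member sees the same batch, so outputs gain a leading ensemble dimension
    return [[member(x) for x in batch] for member in ensemble]


ensemble = make_ensemble(in_features=4, num_copies=3)
batch = [[1.0, 0.0, 0.0, 0.0] for _ in range(10)]
outputs = run_ensemble(ensemble, batch)
print(len(outputs), len(outputs[0]))  # 3 10
```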

Compiling TensorDictModules

Since v0.5, TensorDict components are compatible with torch.compile(). For instance, a TensorDictSequential module can be compiled with torch.compile and reach a runtime similar to that of a regular PyTorch module wrapped in a TensorDictModule.

Distributions

AddStateIndependentNormalScale([...])

An nn.Module that adds trainable state-independent scale parameters.

CompositeDistribution(params, ...[, ...])

A composite distribution that groups multiple distributions together using the TensorDict interface.

Delta(param[, atol, rtol, batch_shape, ...])

Delta distribution.

NormalParamExtractor([scale_mapping, scale_lb])

A non-parametric nn.Module that splits its input into loc and scale parameters.

OneHotCategorical([logits, probs])

One-hot categorical distribution.

TruncatedNormal(loc, scale, a, b[, ...])

Truncated Normal distribution.
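NormalParamExtractor splits its input into loc and scale. The pure-Python sketch below shows the idea on a flat list. It assumes the scale half is passed through a biased softplus with bias 1.0, constructed so that an input of 0 maps exactly to the bias, and is then clamped at a small lower bound. These mirror the biased_softplus and inv_softplus entries under Utils, but the defaults here are assumptions, and all function names are hypothetical re-implementations.

```python
import math
from typing import List, Tuple


def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def inv_softplus(y: float) -> float:
    # Solves softplus(x) = y for x (requires y > 0)
    return math.log(math.expm1(y))


def biased_softplus(x: float, bias: float = 1.0, min_val: float = 0.01) -> float:
    # Shifted so that an input of 0 maps to `bias`; outputs stay above min_val
    return softplus(x + inv_softplus(bias - min_val)) + min_val


def normal_param_extractor(
    values: List[float], scale_lb: float = 1e-4
) -> Tuple[List[float], List[float]]:
    half = len(values) // 2  # expects an even number of features
    loc = values[:half]
    scale = [max(biased_softplus(v), scale_lb) for v in values[half:]]
    return loc, scale


loc, scale = normal_param_extractor([0.5, -0.5, 0.0, -20.0])
print(loc)  # [0.5, -0.5]
print([round(s, 4) for s in scale])  # [1.0, 0.01]
```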

Utils

make_tensordict([input_dict, batch_size, ...])

Returns a TensorDict created from the keyword arguments or an input dictionary.

dispatch([separator, source, dest, ...])

Allows for a function expecting a TensorDict to be called using kwargs.

inv_softplus(bias)

Inverse softplus function.

biased_softplus(bias[, min_val])

A biased softplus module.

set_skip_existing([mode, in_key_attr, ...])

A context manager for skipping existing nodes in a TensorDict graph.

skip_existing()

Returns whether or not existing entries in a tensordict should be re-computed by a module.

add_custom_mapping(name, mapping)

Adds a custom mapping to be used in mapping classes.

mappings(key)

Given an input string, returns a surjective function f(x): R -> R^+.

rand_one_hot(values[, do_softmax])
