
Functional Programming with TensorDict

Author: Vincent Moens

In this tutorial you will learn how to use TensorDict for functional-style programming with torch.nn modules, including parameter swapping, model ensembling with torch.vmap(), and functional calls with torch.func.functional_call().

TensorDict as a parameter container

from_module() extracts the parameters of a module into a nested TensorDict whose structure mirrors the module hierarchy.

import torch
import torch.nn as nn
from tensordict import TensorDict

module = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))
params = TensorDict.from_module(module)
print(params)
TensorDict(
    fields={
        0: TensorDict(
            fields={
                bias: Parameter(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                weight: Parameter(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        2: TensorDict(
            fields={
                bias: Parameter(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                weight: Parameter(shape=torch.Size([1, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

The resulting TensorDict holds the same Parameter objects as the module. We can manipulate them as a batch – for example, zeroing all parameters at once:

params_zero = params.detach().clone().zero_()
print("All zeros:", (params_zero == 0).all())
All zeros: True

Swapping parameters with a context manager

to_module() writes the values of a TensorDict into a module. Used as a context manager, the swap is temporary: the original parameters are restored on exit.

x = torch.randn(5, 3)

with params_zero.to_module(module):
    y_zero = module(x)

print("Output with zeroed params:", y_zero)
assert (y_zero == 0).all()

y_original = module(x)
print("Output with original params:", y_original)
assert not (y_original == 0).all()
Output with zeroed params: tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.]])
Output with original params: tensor([[-0.0777],
        [ 0.1591],
        [ 0.4036],
        [-0.1975],
        [-0.0056]], grad_fn=<AddmmBackward0>)

Model ensembling with torch.vmap

Because TensorDict supports batching and stacking, we can stack multiple parameter configurations and use vmap() to run the model across all of them in a single vectorized call.

params_ones = params.detach().clone().apply_(lambda t: t.fill_(1.0))
params_stack = torch.stack([params_zero, params_ones, params])

print("Stacked params batch_size:", params_stack.batch_size)


def call(x, td):
    with td.to_module(module):
        return module(x)


x = torch.randn(3, 5, 3)
y = torch.vmap(call)(x, params_stack)
print("Output shape:", y.shape)

assert (y[0] == 0).all()
Stacked params batch_size: torch.Size([3])
Output shape: torch.Size([3, 5, 1])

Functional calls with torch.func

functional_call() expects a state-dict-style mapping from dotted names to tensors. Because from_module() returns a TensorDict that mirrors the module hierarchy, we can flatten its nested keys into dotted names (e.g. "0.weight") and pass the result as a plain dict.

from torch.func import functional_call

flat_params = params.flatten_keys(".")
state_dict = dict(flat_params.items())
x = torch.randn(5, 3)
y = functional_call(module, state_dict, x)
print("functional_call output:", y.shape)
functional_call output: torch.Size([5, 1])
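The per-sample-gradient use case mentioned below follows the standard torch.func recipe: compose grad() with vmap() around functional_call(). A minimal sketch (loss_fn and the model are our own illustrative choices):

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap

model = nn.Linear(3, 1)
# Detach so the functional transforms, not autograd, track the parameters.
params = {k: v.detach() for k, v in model.named_parameters()}


def loss_fn(p, x, y):
    # x is a single sample inside vmap; add a batch dim for the forward pass.
    pred = functional_call(model, p, (x.unsqueeze(0),))
    return ((pred - y) ** 2).mean()


x = torch.randn(8, 3)
y = torch.randn(8, 1)

# grad() differentiates w.r.t. the first argument (the parameter dict);
# vmap() maps over the sample dimension of x and y.
per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(params, x, y)
print(per_sample_grads["weight"].shape)  # torch.Size([8, 1, 3])
```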

The combination of from_module(), to_module(), and vmap() makes it straightforward to do things like compute per-sample gradients, run model ensembles, or implement meta-learning inner loops – all without leaving the standard PyTorch API.

Total running time of the script: (0 minutes 0.006 seconds)
