Functional Programming with TensorDict
Author: Vincent Moens
In this tutorial, you will learn how to use TensorDict for
functional-style programming with nn.Module: swapping parameters
in and out of a module, ensembling models with vmap(), and making
functional calls with functional_call().
TensorDict as a parameter container
from_module() extracts the parameters of a module into a
nested TensorDict whose structure mirrors the module hierarchy.
TensorDict(
fields={
0: TensorDict(
fields={
bias: Parameter(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
weight: Parameter(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False),
2: TensorDict(
fields={
bias: Parameter(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
weight: Parameter(shape=torch.Size([1, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
The resulting TensorDict holds the same
Parameter objects as the module. We can manipulate them
as a batch – for example, zeroing all parameters at once:
params_zero = params.detach().clone().zero_()
print("All zeros:", (params_zero == 0).all())
All zeros: True
Swapping parameters with a context manager
to_module() temporarily replaces the parameters of a
module within a context manager. The original parameters are restored on
exit.
x = torch.randn(5, 3)
with params_zero.to_module(module):
y_zero = module(x)
print("Output with zeroed params:", y_zero)
assert (y_zero == 0).all()
y_original = module(x)
print("Output with original params:", y_original)
assert not (y_original == 0).all()
Output with zeroed params: tensor([[0.],
[0.],
[0.],
[0.],
[0.]])
Output with original params: tensor([[-0.0777],
[ 0.1591],
[ 0.4036],
[-0.1975],
[-0.0056]], grad_fn=<AddmmBackward0>)
Model ensembling with torch.vmap
Because TensorDict supports batching and stacking, we can stack
multiple parameter configurations and use vmap() to run the
model across all of them in a single vectorized call.
params_ones = params.detach().clone().apply_(lambda t: t.fill_(1.0))
params_stack = torch.stack([params_zero, params_ones, params])
print("Stacked params batch_size:", params_stack.batch_size)
def call(x, td):
with td.to_module(module):
return module(x)
x = torch.randn(3, 5, 3)
y = torch.vmap(call)(x, params_stack)
print("Output shape:", y.shape)
assert (y[0] == 0).all()
Stacked params batch_size: torch.Size([3])
Output shape: torch.Size([3, 5, 1])
Functional calls with torch.func
functional_call() expects a flat parameter dict keyed like a
state-dict. Flattening the nested TensorDict returned by
from_module() with flatten_keys() produces exactly that
layout, so we can convert it to a regular dict and pass it directly.
from torch.func import functional_call
flat_params = params.flatten_keys(".")
state_dict = dict(flat_params.items())
x = torch.randn(5, 3)
y = functional_call(module, state_dict, x)
print("functional_call output:", y.shape)
functional_call output: torch.Size([5, 1])
The combination of from_module(),
to_module(), and vmap() makes it
straightforward to do things like compute per-sample gradients, run
model ensembles, or implement meta-learning inner loops – all without
leaving the standard PyTorch API.
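To illustrate the per-sample-gradient use case mentioned above, here is a minimal sketch using torch.func directly (the module, loss function, and data are hypothetical, but the module layout matches the one used throughout):

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap

module = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))
params = {k: v.detach() for k, v in module.named_parameters()}

def loss_fn(p, x, y):
    # Compute the loss for a single sample with the given parameters.
    pred = functional_call(module, p, (x.unsqueeze(0),)).squeeze(0)
    return ((pred - y) ** 2).sum()

x = torch.randn(8, 3)
y = torch.randn(8, 1)

# grad() differentiates w.r.t. the parameter dict; vmap() maps over
# the sample dimension, yielding one gradient per sample.
per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(params, x, y)
print(per_sample_grads["0.weight"].shape)  # torch.Size([8, 4, 3])
```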
Total running time of the script: (0 minutes 0.006 seconds)