TensorDictModule¶
In this tutorial you will learn how to use TensorDictModule and
TensorDictSequential to create generic and reusable modules that can accept
TensorDict as input.
To use the TensorDict class conveniently with nn.Module,
tensordict provides an interface between the two, named TensorDictModule.
The TensorDictModule class is an nn.Module that takes a
TensorDict as input when called.
It is up to the user to define the keys to be read as input and output.
TensorDictModule by examples¶
import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
Example 1: Simple usage¶
We have a TensorDict with 2 entries, "a" and "b", but only the
value associated with "a" needs to be read by the network.
tensordict = TensorDict(
{"a": torch.randn(5, 3), "b": torch.zeros(5, 4, 3)},
batch_size=[5],
)
linear = TensorDictModule(nn.Linear(3, 10), in_keys=["a"], out_keys=["a_out"])
linear(tensordict)
assert (tensordict.get("b") == 0).all()
print(tensordict)
TensorDict(
fields={
a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
a_out: Tensor(shape=torch.Size([5, 10]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([5, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([5]),
device=None,
is_shared=False)
Example 2: Multiple inputs¶
Suppose we have a slightly more complex network that takes 2 entries and
averages them into a single output tensor. To make a TensorDictModule
instance read multiple input values, one must register them in the
in_keys keyword argument of the constructor.
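The MergeLinear module used below is not defined earlier on this page; a minimal sketch consistent with the averaging behaviour and shapes described above could look like this:
class MergeLinear(nn.Module):
    def __init__(self, in_1, in_2, out):
        super().__init__()
        self.linear_1 = nn.Linear(in_1, out)
        self.linear_2 = nn.Linear(in_2, out)

    def forward(self, x_1, x_2):
        # Project both inputs to the same output size and average them.
        return (self.linear_1(x_1) + self.linear_2(x_2)) / 2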
tensordict = TensorDict(
{
"a": torch.randn(5, 3),
"b": torch.randn(5, 4),
},
batch_size=[5],
)
mergelinear = TensorDictModule(
MergeLinear(3, 4, 10), in_keys=["a", "b"], out_keys=["output"]
)
mergelinear(tensordict)
TensorDict(
fields={
a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
output: Tensor(shape=torch.Size([5, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([5]),
device=None,
is_shared=False)
Example 3: Multiple outputs¶
Similarly, TensorDictModule not only supports multiple inputs but also
multiple outputs. To make a TensorDictModule instance write to multiple
output values, one must register them in the out_keys keyword argument
of the constructor.
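As above, the MultiHeadLinear module used below is not defined on this page; a minimal sketch matching the output shapes shown below would be:
class MultiHeadLinear(nn.Module):
    def __init__(self, in_1, out_1, out_2):
        super().__init__()
        self.linear_1 = nn.Linear(in_1, out_1)
        self.linear_2 = nn.Linear(in_1, out_2)

    def forward(self, x):
        # Two independent linear heads reading the same input.
        return self.linear_1(x), self.linear_2(x)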
tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])
splitlinear = TensorDictModule(
MultiHeadLinear(3, 4, 10),
in_keys=["a"],
out_keys=["output_1", "output_2"],
)
splitlinear(tensordict)
TensorDict(
fields={
a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
output_1: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
output_2: Tensor(shape=torch.Size([5, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([5]),
device=None,
is_shared=False)
When using multiple input and output keys, make sure their order matches the order of the wrapped module's arguments and outputs.
TensorDictModule can work with TensorDict instances that contain
more tensors than what the in_keys attribute indicates.
Unless a vmap operator is used, the TensorDict is modified in-place.
Ignoring some outputs
Note that it is possible to avoid writing some of the tensors to the
TensorDict output, using "_" in out_keys.
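For instance, reusing the MultiHeadLinear sketch above, the first output can be discarded as follows:
splitlinear_partial = TensorDictModule(
    MultiHeadLinear(3, 4, 10),
    in_keys=["a"],
    out_keys=["_", "output_2"],  # the first head's output is not written to the tensordict
)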
Example 4: Combining multiple TensorDictModule with TensorDictSequential¶
To combine multiple TensorDictModule instances, we can use
TensorDictSequential. We create a list where each TensorDictModule must
be executed sequentially. TensorDictSequential will read and write keys to the
tensordict following the sequence of modules provided.
We can also gather the inputs needed by TensorDictSequential through its
in_keys property, while the output keys are found in the out_keys attribute (see the check after the example below).
tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])
splitlinear = TensorDictModule(
MultiHeadLinear(3, 4, 10),
in_keys=["a"],
out_keys=["output_1", "output_2"],
)
mergelinear = TensorDictModule(
MergeLinear(4, 10, 13),
in_keys=["output_1", "output_2"],
out_keys=["output"],
)
split_and_merge_linear = TensorDictSequential(splitlinear, mergelinear)
assert split_and_merge_linear(tensordict)["output"].shape == torch.Size([5, 13])
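As a quick check of the in_keys and out_keys attributes mentioned above (the exact ordering of the keys may vary):
print(split_and_merge_linear.in_keys)   # keys the sequence needs to read, e.g. ['a']
print(split_and_merge_linear.out_keys)  # keys it writes, e.g. ['output_1', 'output_2', 'output']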
Do’s and don’ts with TensorDictModule¶
Don’t use nn.Sequential to chain TensorDictModule instances: like a plain nn.Module wrapper, it would break features
such as functorch compatibility. Do use TensorDictSequential instead.
Don’t assign the output tensordict to a new variable, as the output tensordict is just the input modified in-place:
tensordict = module(tensordict) # ok!
tensordict_out = module(tensordict) # don’t!
ProbabilisticTensorDictModule¶
ProbabilisticTensorDictModule is a non-parametric module representing a
probability distribution. The distribution parameters are read from the input
tensordict, and the output is written to an output tensordict. The output is
sampled according to a rule specified by the default_interaction_type
argument and the exploration_mode() global function. If they conflict,
the context manager takes precedence.
It can be wired together with a TensorDictModule that returns
a tensordict updated with the distribution parameters using
ProbabilisticTensorDictSequential. This is a special case of
TensorDictSequential that terminates in a
ProbabilisticTensorDictModule.
ProbabilisticTensorDictModule is responsible for constructing the
distribution (through the get_dist() method) and/or sampling from this
distribution (through a regular __call__() to the module). The same
get_dist() method is exposed on ProbabilisticTensorDictSequential.
The distribution parameters can be found in the output tensordict, along with the log probability if needed.
from tensordict.nn import (
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
)
from tensordict.nn.distributions import NormalParamExtractor
from torch import distributions as dist
td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3])
net = torch.nn.GRUCell(4, 8)
net = TensorDictModule(net, in_keys=["input", "hidden"], out_keys=["hidden"])
extractor = NormalParamExtractor()
extractor = TensorDictModule(extractor, in_keys=["hidden"], out_keys=["loc", "scale"])
td_module = ProbabilisticTensorDictSequential(
net,
extractor,
ProbabilisticTensorDictModule(
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=dist.Normal,
return_log_prob=True,
),
)
print(f"TensorDict before going through module: {td}")
td_module(td)
print(f"TensorDict after going through module, now with keys action, loc and scale: {td}")
TensorDict before going through module: TensorDict(
fields={
hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
TensorDict after going through module, now with keys action, loc and scale: TensorDict(
fields={
action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
sample_log_prob: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
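To illustrate the get_dist() method mentioned earlier, the distribution can also be built explicitly without sampling; a minimal sketch:
# Runs the deterministic part of the sequence and returns the
# torch.distributions.Normal parametrized by "loc" and "scale",
# without writing a sample to the tensordict.
dist_inst = td_module.get_dist(td)
print(dist_inst)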
Showcase: Implementing a transformer using TensorDictModule¶
To demonstrate the flexibility of TensorDictModule, we are going to
implement a transformer that reads TensorDict objects using TensorDictModule.
We follow the classical transformer architecture (Vaswani et al., 2017) and
leave the positional encoding aside for simplicity.
Let’s rewrite the classical transformer blocks:
class TokensToQKV(nn.Module):
def __init__(self, to_dim, from_dim, latent_dim):
super().__init__()
self.q = nn.Linear(to_dim, latent_dim)
self.k = nn.Linear(from_dim, latent_dim)
self.v = nn.Linear(from_dim, latent_dim)
def forward(self, X_to, X_from):
Q = self.q(X_to)
K = self.k(X_from)
V = self.v(X_from)
return Q, K, V
class SplitHeads(nn.Module):
def __init__(self, num_heads):
super().__init__()
self.num_heads = num_heads
def forward(self, Q, K, V):
batch_size, to_num, latent_dim = Q.shape
_, from_num, _ = K.shape
d_tensor = latent_dim // self.num_heads
Q = Q.reshape(batch_size, to_num, self.num_heads, d_tensor).transpose(1, 2)
K = K.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
V = V.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
return Q, K, V
class Attention(nn.Module):
def __init__(self, latent_dim, to_dim):
super().__init__()
self.softmax = nn.Softmax(dim=-1)
self.out = nn.Linear(latent_dim, to_dim)
def forward(self, Q, K, V):
batch_size, n_heads, to_num, d_in = Q.shape
attn = self.softmax(Q @ K.transpose(2, 3) / d_in)
out = attn @ V
out = self.out(out.transpose(1, 2).reshape(batch_size, to_num, n_heads * d_in))
return out, attn
class SkipLayerNorm(nn.Module):
def __init__(self, to_len, to_dim):
super().__init__()
self.layer_norm = nn.LayerNorm((to_len, to_dim))
def forward(self, x_0, x_1):
return self.layer_norm(x_0 + x_1)
class FFN(nn.Module):
def __init__(self, to_dim, hidden_dim, dropout_rate=0.2):
super().__init__()
self.FFN = nn.Sequential(
nn.Linear(to_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, to_dim),
nn.Dropout(dropout_rate),
)
def forward(self, X):
return self.FFN(X)
class AttentionBlock(nn.Module):
def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads):
super().__init__()
self.tokens_to_qkv = TokensToQKV(to_dim, from_dim, latent_dim)
self.split_heads = SplitHeads(num_heads)
self.attention = Attention(latent_dim, to_dim)
self.skip = SkipLayerNorm(to_len, to_dim)
def forward(self, X_to, X_from):
Q, K, V = self.tokens_to_qkv(X_to, X_from)
Q, K, V = self.split_heads(Q, K, V)
out, attention = self.attention(Q, K, V)
out = self.skip(X_to, out)
return out
class EncoderTransformerBlock(nn.Module):
def __init__(self, to_dim, to_len, latent_dim, num_heads):
super().__init__()
self.attention_block = AttentionBlock(
to_dim, to_len, to_dim, latent_dim, num_heads
)
self.FFN = FFN(to_dim, 4 * to_dim)
self.skip = SkipLayerNorm(to_len, to_dim)
def forward(self, X_to):
X_to = self.attention_block(X_to, X_to)
X_out = self.FFN(X_to)
return self.skip(X_out, X_to)
class DecoderTransformerBlock(nn.Module):
def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads):
super().__init__()
self.attention_block = AttentionBlock(
to_dim, to_len, from_dim, latent_dim, num_heads
)
self.encoder_block = EncoderTransformerBlock(
to_dim, to_len, latent_dim, num_heads
)
def forward(self, X_to, X_from):
X_to = self.attention_block(X_to, X_from)
X_to = self.encoder_block(X_to)
return X_to
class TransformerEncoder(nn.Module):
def __init__(self, num_blocks, to_dim, to_len, latent_dim, num_heads):
super().__init__()
self.encoder = nn.ModuleList(
[
EncoderTransformerBlock(to_dim, to_len, latent_dim, num_heads)
for i in range(num_blocks)
]
)
def forward(self, X_to):
for i in range(len(self.encoder)):
X_to = self.encoder[i](X_to)
return X_to
class TransformerDecoder(nn.Module):
def __init__(self, num_blocks, to_dim, to_len, from_dim, latent_dim, num_heads):
super().__init__()
self.decoder = nn.ModuleList(
[
DecoderTransformerBlock(to_dim, to_len, from_dim, latent_dim, num_heads)
for i in range(num_blocks)
]
)
def forward(self, X_to, X_from):
for i in range(len(self.decoder)):
X_to = self.decoder[i](X_to, X_from)
return X_to
class Transformer(nn.Module):
def __init__(
self, num_blocks, to_dim, to_len, from_dim, from_len, latent_dim, num_heads
):
super().__init__()
self.encoder = TransformerEncoder(
num_blocks, to_dim, to_len, latent_dim, num_heads
)
self.decoder = TransformerDecoder(
num_blocks, from_dim, from_len, to_dim, latent_dim, num_heads
)
def forward(self, X_to, X_from):
X_to = self.encoder(X_to)
X_out = self.decoder(X_from, X_to)
return X_out
We first create AttentionBlockTensorDict, the attention block built with
TensorDictModule and TensorDictSequential.
The wiring operation that connects the modules to each other requires us
to indicate which keys each of them must read and write. Unlike
nn.Sequential, a TensorDictSequential can read/write more than one
input/output. Moreover, a component's inputs need not be identical to the
previous layer's outputs, which allows us to code complicated neural architectures.
class AttentionBlockTensorDict(TensorDictSequential):
def __init__(
self,
to_name,
from_name,
to_dim,
to_len,
from_dim,
latent_dim,
num_heads,
):
super().__init__(
TensorDictModule(
TokensToQKV(to_dim, from_dim, latent_dim),
in_keys=[to_name, from_name],
out_keys=["Q", "K", "V"],
),
TensorDictModule(
SplitHeads(num_heads),
in_keys=["Q", "K", "V"],
out_keys=["Q", "K", "V"],
),
TensorDictModule(
Attention(latent_dim, to_dim),
in_keys=["Q", "K", "V"],
out_keys=["X_out", "Attn"],
),
TensorDictModule(
SkipLayerNorm(to_len, to_dim),
in_keys=[to_name, "X_out"],
out_keys=[to_name],
),
)
We build the encoder and decoder blocks that will be part of the transformer
thanks to TensorDictModule.
class TransformerBlockEncoderTensorDict(TensorDictSequential):
def __init__(
self,
to_name,
from_name,
to_dim,
to_len,
from_dim,
latent_dim,
num_heads,
):
super().__init__(
AttentionBlockTensorDict(
to_name,
from_name,
to_dim,
to_len,
from_dim,
latent_dim,
num_heads,
),
TensorDictModule(
FFN(to_dim, 4 * to_dim),
in_keys=[to_name],
out_keys=["X_out"],
),
TensorDictModule(
SkipLayerNorm(to_len, to_dim),
in_keys=[to_name, "X_out"],
out_keys=[to_name],
),
)
class TransformerBlockDecoderTensorDict(TensorDictSequential):
def __init__(
self,
to_name,
from_name,
to_dim,
to_len,
from_dim,
latent_dim,
num_heads,
):
super().__init__(
AttentionBlockTensorDict(
to_name,
to_name,
to_dim,
to_len,
to_dim,
latent_dim,
num_heads,
),
TransformerBlockEncoderTensorDict(
to_name,
from_name,
to_dim,
to_len,
from_dim,
latent_dim,
num_heads,
),
)
We create the transformer encoder and decoder.
For the encoder, we just need to use the same tokens for the queries, keys and values.
For the decoder, we can now extract information from X_from into X_to:
X_to will map to the queries whereas X_from will map to the keys and values.
class TransformerEncoderTensorDict(TensorDictSequential):
def __init__(
self,
num_blocks,
to_name,
from_name,
to_dim,
to_len,
from_dim,
latent_dim,
num_heads,
):
super().__init__(
*[
TransformerBlockEncoderTensorDict(
to_name,
from_name,
to_dim,
to_len,
from_dim,
latent_dim,
num_heads,
)
for _ in range(num_blocks)
]
)
class TransformerDecoderTensorDict(TensorDictSequential):
def __init__(
self,
num_blocks,
to_name,
from_name,
to_dim,
to_len,
from_dim,
latent_dim,
num_heads,
):
super().__init__(
*[
TransformerBlockDecoderTensorDict(
to_name,
from_name,
to_dim,
to_len,
from_dim,
latent_dim,
num_heads,
)
for _ in range(num_blocks)
]
)
class TransformerTensorDict(TensorDictSequential):
def __init__(
self,
num_blocks,
to_name,
from_name,
to_dim,
to_len,
from_dim,
from_len,
latent_dim,
num_heads,
):
super().__init__(
TransformerEncoderTensorDict(
num_blocks,
to_name,
to_name,
to_dim,
to_len,
to_dim,
latent_dim,
num_heads,
),
TransformerDecoderTensorDict(
num_blocks,
from_name,
to_name,
from_dim,
from_len,
to_dim,
latent_dim,
num_heads,
),
)
We now test our new TransformerTensorDict.
to_dim = 5
from_dim = 6
latent_dim = 10
to_len = 3
from_len = 10
batch_size = 8
num_heads = 2
num_blocks = 6
tokens = TensorDict(
{
"X_encode": torch.randn(batch_size, to_len, to_dim),
"X_decode": torch.randn(batch_size, from_len, from_dim),
},
batch_size=[batch_size],
)
transformer = TransformerTensorDict(
num_blocks,
"X_encode",
"X_decode",
to_dim,
to_len,
from_dim,
from_len,
latent_dim,
num_heads,
)
transformer(tokens)
tokens
TensorDict(
fields={
Attn: Tensor(shape=torch.Size([8, 2, 10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
K: Tensor(shape=torch.Size([8, 2, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
Q: Tensor(shape=torch.Size([8, 2, 10, 5]), device=cpu, dtype=torch.float32, is_shared=False),
V: Tensor(shape=torch.Size([8, 2, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
X_decode: Tensor(shape=torch.Size([8, 10, 6]), device=cpu, dtype=torch.float32, is_shared=False),
X_encode: Tensor(shape=torch.Size([8, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
X_out: Tensor(shape=torch.Size([8, 10, 6]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([8]),
device=None,
is_shared=False)
We have managed to create a transformer with TensorDictModule. This
shows that TensorDictModule is a flexible module that can implement
complex operations.
Benchmarking¶
to_dim = 5
from_dim = 6
latent_dim = 10
to_len = 3
from_len = 10
batch_size = 8
num_heads = 2
num_blocks = 6
td_tokens = TensorDict(
{
"X_encode": torch.randn(batch_size, to_len, to_dim),
"X_decode": torch.randn(batch_size, from_len, from_dim),
},
batch_size=[batch_size],
)
tdtransformer = TransformerTensorDict(
num_blocks,
"X_encode",
"X_decode",
to_dim,
to_len,
from_dim,
from_len,
latent_dim,
num_heads,
)
transformer = Transformer(
num_blocks, to_dim, to_len, from_dim, from_len, latent_dim, num_heads
)
Inference Time
import time
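The timing code itself is not shown here; a minimal sketch that would produce two measurements such as the ones below (first through the TensorDict version, then through the plain module) is:
X_encode = torch.randn(batch_size, to_len, to_dim)
X_decode = torch.randn(batch_size, from_len, from_dim)

# Time the TensorDict-based transformer.
start = time.time()
tdtransformer(td_tokens)
print(f"Execution time: {time.time() - start} seconds")

# Time the plain nn.Module transformer on raw tensors.
start = time.time()
X_out = transformer(X_encode, X_decode)
print(f"Execution time: {time.time() - start} seconds")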
Execution time: 0.009625911712646484 seconds
Execution time: 0.006480216979980469 seconds
We can see in this minimal example that the overhead introduced by
TensorDictModule is marginal.
Have fun with TensorDictModule!
Total running time of the script: (0 minutes 10.088 seconds)