.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/tensordict_module.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_tensordict_module.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_tensordict_module.py:

TensorDictModule
================

In this tutorial you will learn how to use :class:`~.TensorDictModule` and
:class:`~.TensorDictSequential` to create generic and reusable modules that can
accept :class:`~.TensorDict` as input.

.. GENERATED FROM PYTHON SOURCE LINES 10-18

For convenient use of the :class:`~.TensorDict` class with ``nn.Module``,
:mod:`tensordict` provides an interface between the two, named ``TensorDictModule``.

The ``TensorDictModule`` class is an ``nn.Module`` that takes a
:class:`~.TensorDict` as input when called. It is up to the user to define the
keys to be read as input and the keys to be written as output.

TensorDictModule by examples
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 18-25

.. code-block:: Python

    import torch
    import torch.nn as nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule, TensorDictSequential

.. GENERATED FROM PYTHON SOURCE LINES 31-35

Example 1: Simple usage
--------------------------------------

We have a :class:`~.TensorDict` with two entries, ``"a"`` and ``"b"``, but only
the value associated with ``"a"`` needs to be read by the network.

.. GENERATED FROM PYTHON SOURCE LINES 35-45

.. code-block:: Python

    tensordict = TensorDict(
        {"a": torch.randn(5, 3), "b": torch.zeros(5, 4, 3)},
        batch_size=[5],
    )
    linear = TensorDictModule(nn.Linear(3, 10), in_keys=["a"], out_keys=["a_out"])
    linear(tensordict)
    assert (tensordict.get("b") == 0).all()
    print(tensordict)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
            a_out: Tensor(shape=torch.Size([5, 10]), device=cpu, dtype=torch.float32, is_shared=False),
            b: Tensor(shape=torch.Size([5, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([5]),
        device=None,
        is_shared=False)
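``TensorDictModule`` modifies its input tensordict in-place and returns it
(more on this below), so nothing prevents us from writing back to an existing
entry: if an ``out_keys`` entry matches a key that is already present, its
value is simply overwritten. Below is a minimal sketch (reusing the
``tensordict`` defined above) that overwrites ``"a"`` and checks that the
module returns the very tensordict it received:

.. code-block:: Python

    overwrite = TensorDictModule(nn.Linear(3, 3), in_keys=["a"], out_keys=["a"])
    out = overwrite(tensordict)
    assert out is tensordict  # modified in-place and returned, not copied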

.. GENERATED FROM PYTHON SOURCE LINES 46-52

Example 2: Multiple inputs
--------------------------------------

Suppose we have a slightly more complex network that takes two entries and
averages them into a single output tensor. To make a ``TensorDictModule``
instance read multiple input values, the keys must be registered in the
``in_keys`` keyword argument of the constructor.

.. GENERATED FROM PYTHON SOURCE LINES 52-64

.. code-block:: Python

    class MergeLinear(nn.Module):
        def __init__(self, in_1, in_2, out):
            super().__init__()
            self.linear_1 = nn.Linear(in_1, out)
            self.linear_2 = nn.Linear(in_2, out)

        def forward(self, x_1, x_2):
            return (self.linear_1(x_1) + self.linear_2(x_2)) / 2

.. GENERATED FROM PYTHON SOURCE LINES 65-80

.. code-block:: Python

    tensordict = TensorDict(
        {
            "a": torch.randn(5, 3),
            "b": torch.randn(5, 4),
        },
        batch_size=[5],
    )

    mergelinear = TensorDictModule(
        MergeLinear(3, 4, 10), in_keys=["a", "b"], out_keys=["output"]
    )

    mergelinear(tensordict)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
            b: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
            output: Tensor(shape=torch.Size([5, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([5]),
        device=None,
        is_shared=False)

.. GENERATED FROM PYTHON SOURCE LINES 81-87

Example 3: Multiple outputs
--------------------------------------

Similarly, ``TensorDictModule`` supports not only multiple inputs but also
multiple outputs. To make a ``TensorDictModule`` instance write to multiple
output values, the keys must be registered in the ``out_keys`` keyword
argument of the constructor.

.. GENERATED FROM PYTHON SOURCE LINES 87-99

.. code-block:: Python

    class MultiHeadLinear(nn.Module):
        def __init__(self, in_1, out_1, out_2):
            super().__init__()
            self.linear_1 = nn.Linear(in_1, out_1)
            self.linear_2 = nn.Linear(in_1, out_2)

        def forward(self, x):
            return self.linear_1(x), self.linear_2(x)

.. GENERATED FROM PYTHON SOURCE LINES 100-110

.. code-block:: Python

    tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])
    splitlinear = TensorDictModule(
        MultiHeadLinear(3, 4, 10),
        in_keys=["a"],
        out_keys=["output_1", "output_2"],
    )
    splitlinear(tensordict)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
            output_1: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
            output_2: Tensor(shape=torch.Size([5, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([5]),
        device=None,
        is_shared=False)

.. GENERATED FROM PYTHON SOURCE LINES 111-133

When there are multiple input or output keys, make sure their order matches
the order of the arguments and outputs of the wrapped module.

``TensorDictModule`` can work with :class:`~.TensorDict` instances that contain
more tensors than what the ``in_keys`` attribute indicates.

Unless a ``vmap`` operator is used, the :class:`~.TensorDict` is modified
in-place.

**Ignoring some outputs**

Note that it is possible to avoid writing some of the tensors to the
:class:`~.TensorDict` output by using ``"_"`` in ``out_keys``.
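For instance, here is a minimal sketch reusing ``MultiHeadLinear`` from above,
where the second head is computed but never written to the tensordict:

.. code-block:: Python

    tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])
    single_head = TensorDictModule(
        MultiHeadLinear(3, 4, 10),
        in_keys=["a"],
        out_keys=["output_1", "_"],  # discard the second output
    )
    single_head(tensordict)
    assert "_" not in tensordict.keys()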
Example 4: Combining multiple ``TensorDictModule`` with ``TensorDictSequential``
----------------------------------------------------------------------------------

To combine multiple ``TensorDictModule`` instances, we can use
``TensorDictSequential``. We pass the ``TensorDictModule`` instances that must
be executed in sequence; ``TensorDictSequential`` will read and write keys to
the tensordict following the order of the modules provided. We can also gather
the inputs needed by ``TensorDictSequential`` through the ``in_keys`` property,
and the output keys are found in the ``out_keys`` attribute.

.. GENERATED FROM PYTHON SOURCE LINES 133-151

.. code-block:: Python

    tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])

    splitlinear = TensorDictModule(
        MultiHeadLinear(3, 4, 10),
        in_keys=["a"],
        out_keys=["output_1", "output_2"],
    )
    mergelinear = TensorDictModule(
        MergeLinear(4, 10, 13),
        in_keys=["output_1", "output_2"],
        out_keys=["output"],
    )

    split_and_merge_linear = TensorDictSequential(splitlinear, mergelinear)

    assert split_and_merge_linear(tensordict)["output"].shape == torch.Size([5, 13])
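As mentioned above, the ``in_keys`` property gathers the inputs the whole
sequence needs from the outside, while ``out_keys`` lists everything the
sequence writes. A quick check (a sketch; the exact ordering of the lists may
vary across versions):

.. code-block:: Python

    # "output_1" and "output_2" are produced internally, so only "a" is required
    print(split_and_merge_linear.in_keys)   # e.g. ['a']
    print(split_and_merge_linear.out_keys)  # e.g. ['output_1', 'output_2', 'output']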

.. GENERATED FROM PYTHON SOURCE LINES 152-187

Dos and don'ts with TensorDictModule
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Don't use ``nn.Sequential`` (or, similarly, a plain ``nn.Module`` container)
around ``TensorDictModule`` instances, as it would break features such as
``functorch`` compatibility. Do use ``TensorDictSequential`` instead.

Don't assign the output tensordict to a new variable, as the output tensordict
is just the input modified in-place:

.. code-block:: Python

    tensordict = module(tensordict)  # ok!
    tensordict_out = module(tensordict)  # don't!

``ProbabilisticTensorDictModule``
----------------------------------

``ProbabilisticTensorDictModule`` is a non-parametric module representing a
probability distribution. Distribution parameters are read from the input
tensordict, and the output is written to an output tensordict. The output is
sampled according to a rule specified by the ``default_interaction_type``
constructor argument and the ``exploration_mode()`` global function (a context
manager). If they conflict, the context manager takes precedence.

It can be wired together with a ``TensorDictModule`` that returns a tensordict
updated with the distribution parameters using
``ProbabilisticTensorDictSequential``. This is a special case of
``TensorDictSequential`` that terminates in a ``ProbabilisticTensorDictModule``.

``ProbabilisticTensorDictModule`` is responsible for constructing the
distribution (through the ``get_dist()`` method) and/or sampling from this
distribution (through a regular ``__call__()`` to the module). The same
``get_dist()`` method is exposed on ``ProbabilisticTensorDictSequential``.

One can find the parameters in the output tensordict as well as the log
probability if needed.

.. GENERATED FROM PYTHON SOURCE LINES 187-214

.. code-block:: Python

    from tensordict.nn import (
        ProbabilisticTensorDictModule,
        ProbabilisticTensorDictSequential,
    )
    from tensordict.nn.distributions import NormalParamExtractor
    from torch import distributions as dist

    td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3])
    net = torch.nn.GRUCell(4, 8)
    net = TensorDictModule(net, in_keys=["input", "hidden"], out_keys=["hidden"])
    extractor = NormalParamExtractor()
    extractor = TensorDictModule(extractor, in_keys=["hidden"], out_keys=["loc", "scale"])
    td_module = ProbabilisticTensorDictSequential(
        net,
        extractor,
        ProbabilisticTensorDictModule(
            in_keys=["loc", "scale"],
            out_keys=["action"],
            distribution_class=dist.Normal,
            return_log_prob=True,
        ),
    )
    print(f"TensorDict before going through module: {td}")
    td_module(td)
    print(f"TensorDict after going through module now has keys action, loc and scale: {td}")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict before going through module: TensorDict(
        fields={
            hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
            input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([3]),
        device=None,
        is_shared=False)
    TensorDict after going through module now has keys action, loc and scale: TensorDict(
        fields={
            action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
            hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
            input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
            loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
            sample_log_prob: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
            scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([3]),
        device=None,
        is_shared=False)
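Since ``get_dist()`` is exposed on ``ProbabilisticTensorDictSequential`` as
well, we can build the distribution without sampling from it. A minimal sketch
(the exact signature may vary across versions):

.. code-block:: Python

    d = td_module.get_dist(td)  # runs the deterministic modules, then builds the distribution
    print(d)  # a torch.distributions.Normal parametrized by td["loc"] and td["scale"]
    print(d.log_prob(td["action"]))  # log-probability of the previously sampled action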

.. GENERATED FROM PYTHON SOURCE LINES 215-229

Showcase: Implementing a transformer using TensorDictModule
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

To demonstrate the flexibility of ``TensorDictModule``, we are going to create
a transformer that reads :class:`~.TensorDict` objects using ``TensorDictModule``.

The following figure shows the classical transformer architecture
(Vaswani et al, 2017).

.. image:: /reference/generated/tutorials/media/transformer.png
   :alt: The transformer png

We have left the positional encoders aside for simplicity.

Let's rewrite the classical transformer blocks:

.. GENERATED FROM PYTHON SOURCE LINES 229-394

.. code-block:: Python

    class TokensToQKV(nn.Module):
        def __init__(self, to_dim, from_dim, latent_dim):
            super().__init__()
            self.q = nn.Linear(to_dim, latent_dim)
            self.k = nn.Linear(from_dim, latent_dim)
            self.v = nn.Linear(from_dim, latent_dim)

        def forward(self, X_to, X_from):
            Q = self.q(X_to)
            K = self.k(X_from)
            V = self.v(X_from)
            return Q, K, V


    class SplitHeads(nn.Module):
        def __init__(self, num_heads):
            super().__init__()
            self.num_heads = num_heads

        def forward(self, Q, K, V):
            batch_size, to_num, latent_dim = Q.shape
            _, from_num, _ = K.shape
            d_tensor = latent_dim // self.num_heads
            Q = Q.reshape(batch_size, to_num, self.num_heads, d_tensor).transpose(1, 2)
            K = K.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
            V = V.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
            return Q, K, V


    class Attention(nn.Module):
        def __init__(self, latent_dim, to_dim):
            super().__init__()
            self.softmax = nn.Softmax(dim=-1)
            self.out = nn.Linear(latent_dim, to_dim)

        def forward(self, Q, K, V):
            batch_size, n_heads, to_num, d_in = Q.shape
            # note: the classical formulation rescales by sqrt(d_in) instead
            attn = self.softmax(Q @ K.transpose(2, 3) / d_in)
            out = attn @ V
            out = self.out(out.transpose(1, 2).reshape(batch_size, to_num, n_heads * d_in))
            return out, attn


    class SkipLayerNorm(nn.Module):
        def __init__(self, to_len, to_dim):
            super().__init__()
            self.layer_norm = nn.LayerNorm((to_len, to_dim))

        def forward(self, x_0, x_1):
            return self.layer_norm(x_0 + x_1)


    class FFN(nn.Module):
        def __init__(self, to_dim, hidden_dim, dropout_rate=0.2):
            super().__init__()
            self.FFN = nn.Sequential(
                nn.Linear(to_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, to_dim),
                nn.Dropout(dropout_rate),
            )

        def forward(self, X):
            return self.FFN(X)


    class AttentionBlock(nn.Module):
        def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads):
            super().__init__()
            self.tokens_to_qkv = TokensToQKV(to_dim, from_dim, latent_dim)
            self.split_heads = SplitHeads(num_heads)
            self.attention = Attention(latent_dim, to_dim)
            self.skip = SkipLayerNorm(to_len, to_dim)

        def forward(self, X_to, X_from):
            Q, K, V = self.tokens_to_qkv(X_to, X_from)
            Q, K, V = self.split_heads(Q, K, V)
            out, attention = self.attention(Q, K, V)
            out = self.skip(X_to, out)
            return out


    class EncoderTransformerBlock(nn.Module):
        def __init__(self, to_dim, to_len, latent_dim, num_heads):
            super().__init__()
            self.attention_block = AttentionBlock(
                to_dim, to_len, to_dim, latent_dim, num_heads
            )
            self.FFN = FFN(to_dim, 4 * to_dim)
            self.skip = SkipLayerNorm(to_len, to_dim)

        def forward(self, X_to):
            X_to = self.attention_block(X_to, X_to)
            X_out = self.FFN(X_to)
            return self.skip(X_out, X_to)


    class DecoderTransformerBlock(nn.Module):
        def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads):
            super().__init__()
            self.attention_block = AttentionBlock(
                to_dim, to_len, from_dim, latent_dim, num_heads
            )
            self.encoder_block = EncoderTransformerBlock(
                to_dim, to_len, latent_dim, num_heads
            )

        def forward(self, X_to, X_from):
            X_to = self.attention_block(X_to, X_from)
            X_to = self.encoder_block(X_to)
            return X_to


    class TransformerEncoder(nn.Module):
        def __init__(self, num_blocks, to_dim, to_len, latent_dim, num_heads):
            super().__init__()
            self.encoder = nn.ModuleList(
                [
                    EncoderTransformerBlock(to_dim, to_len, latent_dim, num_heads)
                    for i in range(num_blocks)
                ]
            )

        def forward(self, X_to):
            for i in range(len(self.encoder)):
                X_to = self.encoder[i](X_to)
            return X_to


    class TransformerDecoder(nn.Module):
        def __init__(self, num_blocks, to_dim, to_len, from_dim, latent_dim, num_heads):
            super().__init__()
            self.decoder = nn.ModuleList(
                [
                    DecoderTransformerBlock(to_dim, to_len, from_dim, latent_dim, num_heads)
                    for i in range(num_blocks)
                ]
            )

        def forward(self, X_to, X_from):
            for i in range(len(self.decoder)):
                X_to = self.decoder[i](X_to, X_from)
            return X_to


    class Transformer(nn.Module):
        def __init__(
            self, num_blocks, to_dim, to_len, from_dim, from_len, latent_dim, num_heads
        ):
            super().__init__()
            self.encoder = TransformerEncoder(
                num_blocks, to_dim, to_len, latent_dim, num_heads
            )
            self.decoder = TransformerDecoder(
                num_blocks, from_dim, from_len, to_dim, latent_dim, num_heads
            )

        def forward(self, X_to, X_from):
            X_to = self.encoder(X_to)
            X_out = self.decoder(X_from, X_to)
            return X_out

.. GENERATED FROM PYTHON SOURCE LINES 395-403

We first create the ``AttentionBlockTensorDict``, the attention block built
with ``TensorDictModule`` and ``TensorDictSequential``.

The wiring operation that connects the modules to each other requires us to
indicate which key each of them must read and write. Unlike ``nn.Sequential``,
a ``TensorDictSequential`` can read/write more than one input/output. Moreover,
the inputs of its components need not be identical to the outputs of the
previous layers, which allows us to code complicated neural architectures.

.. GENERATED FROM PYTHON SOURCE LINES 403-440

.. code-block:: Python

    class AttentionBlockTensorDict(TensorDictSequential):
        def __init__(
            self,
            to_name,
            from_name,
            to_dim,
            to_len,
            from_dim,
            latent_dim,
            num_heads,
        ):
            super().__init__(
                TensorDictModule(
                    TokensToQKV(to_dim, from_dim, latent_dim),
                    in_keys=[to_name, from_name],
                    out_keys=["Q", "K", "V"],
                ),
                TensorDictModule(
                    SplitHeads(num_heads),
                    in_keys=["Q", "K", "V"],
                    out_keys=["Q", "K", "V"],
                ),
                TensorDictModule(
                    Attention(latent_dim, to_dim),
                    in_keys=["Q", "K", "V"],
                    out_keys=["X_out", "Attn"],
                ),
                TensorDictModule(
                    SkipLayerNorm(to_len, to_dim),
                    in_keys=[to_name, "X_out"],
                    out_keys=[to_name],
                ),
            )

.. GENERATED FROM PYTHON SOURCE LINES 441-443

We build the encoder and decoder blocks that will be part of the transformer,
thanks to ``TensorDictModule``.

.. GENERATED FROM PYTHON SOURCE LINES 443-512

.. code-block:: Python

    class TransformerBlockEncoderTensorDict(TensorDictSequential):
        def __init__(
            self,
            to_name,
            from_name,
            to_dim,
            to_len,
            from_dim,
            latent_dim,
            num_heads,
        ):
            super().__init__(
                AttentionBlockTensorDict(
                    to_name,
                    from_name,
                    to_dim,
                    to_len,
                    from_dim,
                    latent_dim,
                    num_heads,
                ),
                TensorDictModule(
                    FFN(to_dim, 4 * to_dim),
                    in_keys=[to_name],
                    out_keys=["X_out"],
                ),
                TensorDictModule(
                    SkipLayerNorm(to_len, to_dim),
                    in_keys=[to_name, "X_out"],
                    out_keys=[to_name],
                ),
            )


    class TransformerBlockDecoderTensorDict(TensorDictSequential):
        def __init__(
            self,
            to_name,
            from_name,
            to_dim,
            to_len,
            from_dim,
            latent_dim,
            num_heads,
        ):
            super().__init__(
                AttentionBlockTensorDict(
                    to_name,
                    to_name,
                    to_dim,
                    to_len,
                    to_dim,
                    latent_dim,
                    num_heads,
                ),
                TransformerBlockEncoderTensorDict(
                    to_name,
                    from_name,
                    to_dim,
                    to_len,
                    from_dim,
                    latent_dim,
                    num_heads,
                ),
            )

.. GENERATED FROM PYTHON SOURCE LINES 513-520

We create the transformer encoder and decoder.

For an encoder, we just need to take the same tokens for the queries, keys and
values.

For a decoder, we can now extract information from ``X_from`` into ``X_to``:
``X_to`` will map to the queries whereas ``X_from`` will map to the keys and
values.

.. GENERATED FROM PYTHON SOURCE LINES 520-615

.. code-block:: Python

    class TransformerEncoderTensorDict(TensorDictSequential):
        def __init__(
            self,
            num_blocks,
            to_name,
            from_name,
            to_dim,
            to_len,
            from_dim,
            latent_dim,
            num_heads,
        ):
            super().__init__(
                *[
                    TransformerBlockEncoderTensorDict(
                        to_name,
                        from_name,
                        to_dim,
                        to_len,
                        from_dim,
                        latent_dim,
                        num_heads,
                    )
                    for _ in range(num_blocks)
                ]
            )


    class TransformerDecoderTensorDict(TensorDictSequential):
        def __init__(
            self,
            num_blocks,
            to_name,
            from_name,
            to_dim,
            to_len,
            from_dim,
            latent_dim,
            num_heads,
        ):
            super().__init__(
                *[
                    TransformerBlockDecoderTensorDict(
                        to_name,
                        from_name,
                        to_dim,
                        to_len,
                        from_dim,
                        latent_dim,
                        num_heads,
                    )
                    for _ in range(num_blocks)
                ]
            )


    class TransformerTensorDict(TensorDictSequential):
        def __init__(
            self,
            num_blocks,
            to_name,
            from_name,
            to_dim,
            to_len,
            from_dim,
            from_len,
            latent_dim,
            num_heads,
        ):
            super().__init__(
                TransformerEncoderTensorDict(
                    num_blocks,
                    to_name,
                    to_name,
                    to_dim,
                    to_len,
                    to_dim,
                    latent_dim,
                    num_heads,
                ),
                TransformerDecoderTensorDict(
                    num_blocks,
                    from_name,
                    to_name,
                    from_dim,
                    from_len,
                    to_dim,
                    latent_dim,
                    num_heads,
                ),
            )

.. GENERATED FROM PYTHON SOURCE LINES 616-617

We now test our new ``TransformerTensorDict``.

.. GENERATED FROM PYTHON SOURCE LINES 617-650

.. code-block:: Python

    to_dim = 5
    from_dim = 6
    latent_dim = 10
    to_len = 3
    from_len = 10
    batch_size = 8
    num_heads = 2
    num_blocks = 6

    tokens = TensorDict(
        {
            "X_encode": torch.randn(batch_size, to_len, to_dim),
            "X_decode": torch.randn(batch_size, from_len, from_dim),
        },
        batch_size=[batch_size],
    )

    transformer = TransformerTensorDict(
        num_blocks,
        "X_encode",
        "X_decode",
        to_dim,
        to_len,
        from_dim,
        from_len,
        latent_dim,
        num_heads,
    )

    transformer(tokens)
    tokens

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict(
        fields={
            Attn: Tensor(shape=torch.Size([8, 2, 10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
            K: Tensor(shape=torch.Size([8, 2, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
            Q: Tensor(shape=torch.Size([8, 2, 10, 5]), device=cpu, dtype=torch.float32, is_shared=False),
            V: Tensor(shape=torch.Size([8, 2, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
            X_decode: Tensor(shape=torch.Size([8, 10, 6]), device=cpu, dtype=torch.float32, is_shared=False),
            X_encode: Tensor(shape=torch.Size([8, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
            X_out: Tensor(shape=torch.Size([8, 10, 6]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([8]),
        device=None,
        is_shared=False)

.. GENERATED FROM PYTHON SOURCE LINES 651-657

We have managed to create a transformer with ``TensorDictModule``. This shows
that ``TensorDictModule`` is a flexible module that can implement complex
operations.

Benchmarking
------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 659-669

.. code-block:: Python

    to_dim = 5
    from_dim = 6
    latent_dim = 10
    to_len = 3
    from_len = 10
    batch_size = 8
    num_heads = 2
    num_blocks = 6

.. GENERATED FROM PYTHON SOURCE LINES 670-679

.. code-block:: Python

    td_tokens = TensorDict(
        {
            "X_encode": torch.randn(batch_size, to_len, to_dim),
            "X_decode": torch.randn(batch_size, from_len, from_dim),
        },
        batch_size=[batch_size],
    )

.. GENERATED FROM PYTHON SOURCE LINES 680-684

.. code-block:: Python

    X_encode = torch.randn(batch_size, to_len, to_dim)
    X_decode = torch.randn(batch_size, from_len, from_dim)

.. GENERATED FROM PYTHON SOURCE LINES 685-698

.. code-block:: Python

    tdtransformer = TransformerTensorDict(
        num_blocks,
        "X_encode",
        "X_decode",
        to_dim,
        to_len,
        from_dim,
        from_len,
        latent_dim,
        num_heads,
    )

.. GENERATED FROM PYTHON SOURCE LINES 699-704

.. code-block:: Python

    transformer = Transformer(
        num_blocks, to_dim, to_len, from_dim, from_len, latent_dim, num_heads
    )

.. GENERATED FROM PYTHON SOURCE LINES 705-706

**Inference Time**

.. GENERATED FROM PYTHON SOURCE LINES 706-709

.. code-block:: Python

    import time

.. GENERATED FROM PYTHON SOURCE LINES 710-716

.. code-block:: Python

    t1 = time.time()
    tokens = tdtransformer(td_tokens)
    t2 = time.time()
    print("Execution time:", t2 - t1, "seconds")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Execution time: 0.009625911712646484 seconds

.. GENERATED FROM PYTHON SOURCE LINES 717-723

.. code-block:: Python

    t3 = time.time()
    X_out = transformer(X_encode, X_decode)
    t4 = time.time()
    print("Execution time:", t4 - t3, "seconds")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Execution time: 0.006480216979980469 seconds
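A single pair of ``time.time()`` calls is a noisy measurement. For a more
reliable comparison, one could use ``torch.utils.benchmark``; below is a sketch
(timings will vary with hardware and library versions):

.. code-block:: Python

    from torch.utils import benchmark

    # averaging over many runs gives a fairer picture than a single call
    t_td = benchmark.Timer(
        stmt="tdtransformer(td_tokens)",
        globals={"tdtransformer": tdtransformer, "td_tokens": td_tokens},
    )
    t_plain = benchmark.Timer(
        stmt="transformer(X_encode, X_decode)",
        globals={
            "transformer": transformer,
            "X_encode": X_encode,
            "X_decode": X_decode,
        },
    )
    print(t_td.timeit(100))
    print(t_plain.timeit(100))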

.. GENERATED FROM PYTHON SOURCE LINES 724-728

We can see on this minimal example that the overhead introduced by
``TensorDictModule`` is marginal.

Have fun with TensorDictModule!

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 10.088 seconds)

.. _sphx_glr_download_tutorials_tensordict_module.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: tensordict_module.ipynb <tensordict_module.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: tensordict_module.py <tensordict_module.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: tensordict_module.zip <tensordict_module.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_