TensorDictModule
- class tensordict.nn.TensorDictModule(*args, **kwargs)
 A TensorDictModule is a Python wrapper around an nn.Module that reads from and writes to a TensorDict.
- Parameters:
 module (Callable[[Any], Any]) – a callable, typically a torch.nn.Module, used to map the input to the output parameter space. Its forward method can return a single tensor, a tuple of tensors or even a dictionary. In the latter case, the output keys of the TensorDictModule will be used to populate the output tensordict (i.e., the keys present in out_keys should be present in the dictionary returned by the module forward method).
 in_keys (iterable of NestedKeys, Dict[NestedStr, str]) – keys to be read from the input tensordict and passed to the module. If it contains more than one element, the values will be passed in the order given by the in_keys iterable. If in_keys is a dictionary, its keys must correspond to the keys to be read in the tensordict and its values must match the names of the keyword arguments in the function signature. If out_to_in_map is True, the mapping gets inverted so that the keys correspond to the keyword arguments in the function signature.
 out_keys (iterable of str) – keys to be written to the input tensordict. The length of out_keys must match the number of tensors returned by the embedded module. Using “_” as a key prevents the corresponding tensor from being written to the output.
- Keyword Arguments:
 out_to_in_map (bool, optional) – if True (default), in_keys is read as if the keys are the argument names of the forward() method and the values are the keys in the input TensorDict. If False, keys are considered to be the input keys and values the method’s argument names.
 inplace (bool or string, optional) – if True (default), the outputs of the module are written in the tensordict provided to the forward() method. If False, a new TensorDict with an empty batch-size and no device is created. If "empty", empty() will be used to create the output tensordict.
 Note
 If inplace=False and the tensordict passed to the module is a TensorDictBase subclass other than TensorDict, the output will still be a TensorDict instance. Its batch-size will be empty, and it will have no device. Set inplace to "empty" to get the same TensorDictBase subtype, an identical batch-size and device. Use tensordict_out at runtime (see below) to have more fine-grained control over the output.
 Note
 If inplace=False and a tensordict_out is passed to the forward() method, the tensordict_out will prevail. This is how one can control the output type when the tensordict passed to the module is a TensorDictBase subclass other than TensorDict, since the output would otherwise still be a TensorDict instance.
 method (str, optional) – the method to be called in the module, if any. Defaults to __call__.
 method_kwargs (Dict[str, Any], optional) – additional keyword arguments to be passed to the module’s method being called.
 strict (bool, optional) – if True, the module will raise an exception if any of the inputs is missing from the input tensordict. Otherwise, a None value will be used as a placeholder. Defaults to False.
 get_kwargs (dict[str, Any], optional) – additional keyword arguments to be passed to the get() method. This is particularly useful when dealing with ragged tensors (see get()). Defaults to {}.
Embedding a neural network in a TensorDictModule only requires specifying the input and output keys. TensorDictModule supports functional and regular nn.Module objects. In the functional case, the ‘params’ (and ‘buffers’) keyword argument must be specified:
Examples
>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> # one can wrap regular nn.Module
>>> module = TensorDictModule(nn.Transformer(128), in_keys=["input", "tgt"], out_keys=["out"])
>>> input = torch.ones(2, 3, 128)
>>> tgt = torch.zeros(2, 3, 128)
>>> data = TensorDict({"input": input, "tgt": tgt}, batch_size=[2, 3])
>>> data = module(data)
>>> print(data)
TensorDict(
    fields={
        input: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False),
        out: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False),
        tgt: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2, 3]),
    device=None,
    is_shared=False)
We can also pass the tensors directly:
Examples
>>> out = module(input, tgt)
>>> assert out.shape == input.shape
>>> # we can also wrap regular functions
>>> module = TensorDictModule(lambda x: (x-1, x+1), in_keys=[("input", "x")], out_keys=[("output", "x-1"), ("output", "x+1")])
>>> module(TensorDict({("input", "x"): torch.zeros(())}, batch_size=[]))
TensorDict(
    fields={
        input: TensorDict(
            fields={
                x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        output: TensorDict(
            fields={
                x+1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                x-1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
We can use TensorDictModule to populate a tensordict:
Examples
>>> module = TensorDictModule(lambda: torch.randn(3), in_keys=[], out_keys=["x"])
>>> print(module(TensorDict({}, batch_size=[])))
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
Another feature is passing a dictionary as input keys, to control the dispatching of values to specific keyword arguments.
Examples
>>> module = TensorDictModule(lambda x, *, y: x+y,
...     in_keys={'1': 'x', '2': 'y'}, out_keys=['z'], out_to_in_map=False
...     )
>>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, []))
>>> td['z']
tensor(3.)
If out_to_in_map is set to True, then the in_keys mapping is reversed. This way, one can use the same input key for different keyword arguments.
Examples
>>> module = TensorDictModule(lambda x, *, y, z: x+y+z,
...     in_keys={'x': '1', 'y': '2', 'z': '2'}, out_keys=['t'], out_to_in_map=True
...     )
>>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, []))
>>> td['t']
tensor(5.)
We can specify the method to be called within a module. Compared to using a lambda function or similar around the module’s method, this has the advantage that the module attributes (params, buffers, submodules) will be exposed.
Examples
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> from torch import nn
>>> import torch
>>>
>>> class MyNet(nn.Module):
...     def my_func(self, tensor: torch.Tensor, *, an_integer: int):
...         return tensor + an_integer
...
>>> s = Seq(
...     {
...         "a": lambda td: td+1,
...         "b": lambda td: td * 2,
...         "c": Mod(MyNet(), in_keys=["a"], out_keys=["b"], method="my_func", method_kwargs={"an_integer": 2}),
...     }
... )
>>> td = s(TensorDict(a=0))
>>> print(td)
>>>
>>> assert td["b"] == 4
Functional calls to a tensordict module are easy:
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,])
>>> module = torch.nn.GRUCell(4, 8)
>>> td_module = TensorDictModule(
...     module=module, in_keys=["input", "hidden"], out_keys=["output"]
... )
>>> params = TensorDict.from_module(td_module)
>>> # functional API
>>> with params.to_module(td_module):
...     td_functional = td_module(td.clone())
>>> print(td_functional)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
- In the stateful case:
 >>> module = torch.nn.GRUCell(4, 8)
 >>> td_module = TensorDictModule(
 ...     module=module, in_keys=["input", "hidden"], out_keys=["output"]
 ... )
 >>> td_stateful = td_module(td.clone())
 >>> print(td_stateful)
 TensorDict(
     fields={
         hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
         input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
         output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False)},
     batch_size=torch.Size([3]),
     device=None,
     is_shared=False)
- forward(tensordict: TensorDictBase = None, args=None, *, tensordict_out: tensordict.base.TensorDictBase | None = None, **kwargs: Any) → TensorDictBase
 When the tensordict parameter is not set, kwargs are used to create an instance of TensorDict.