ProbabilisticTensorDictSequential¶
- class tensordict.nn.ProbabilisticTensorDictSequential(*args, **kwargs)¶
A sequence of TensorDictModules containing at least one ProbabilisticTensorDictModule.

This class extends TensorDictSequential and is typically configured with a sequence of modules where the final module is an instance of ProbabilisticTensorDictModule. However, it also supports configurations where one or more intermediate modules are instances of ProbabilisticTensorDictModule, while the last module may or may not be probabilistic. In all cases, it exposes the get_dist() method to recover the distribution object from the ProbabilisticTensorDictModule instances in the sequence.

Multiple probabilistic modules can co-exist in a single ProbabilisticTensorDictSequential. If return_composite is False (default), only the last one will produce a distribution and the others will be executed as regular TensorDictModule instances. However, if a ProbabilisticTensorDictModule is not the last module in the sequence and return_composite=False, a ValueError will be raised when trying to query the module. If return_composite=True, all intermediate ProbabilisticTensorDictModule instances will contribute to a single CompositeDistribution instance.

The resulting log-probabilities will be conditional probabilities if samples are interdependent: whenever

\[Z = F(X, Y)\]

then the log-probability of Z will be

\[\log p(z \mid x, y)\]

- Parameters:
*modules (sequence or OrderedDict of TensorDictModuleBase or ProbabilisticTensorDictModule) – An ordered sequence of TensorDictModule instances, usually terminating in a ProbabilisticTensorDictModule, to be run sequentially. The modules can be instances of TensorDictModuleBase or any other function that matches this signature. Note that if a non-TensorDictModuleBase callable is used, its input and output keys will not be tracked, and thus will not affect the in_keys and out_keys attributes of the TensorDictSequential.
- Keyword Arguments:
partial_tolerant (bool, optional) – If True, the input tensordict can miss some of the input keys. If so, only the modules that can be executed given the keys that are present will be executed. Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is True AND if the stack does not have the required keys, then TensorDictSequential will scan through the sub-tensordicts looking for those that have the required keys, if any. Defaults to False.

return_composite (bool, optional) – If True and multiple ProbabilisticTensorDictModule or ProbabilisticTensorDictSequential instances are found, a CompositeDistribution instance will be used. Otherwise, only the last module will be used to build the distribution. Defaults to True whenever there is more than one probabilistic module or the last module is not probabilistic; an error is raised if return_composite is False and either of these conditions is met.

inplace (bool, optional) – If True, the input tensordict is modified in-place. If False, a new empty TensorDict instance is created. If "empty", input.empty() is used instead (i.e., the output preserves the type, device and batch size of the input). Defaults to None (relies on sub-modules).
- Raises:
ValueError – If the input sequence of modules is empty.
TypeError – If the final module is not an instance of ProbabilisticTensorDictModule or ProbabilisticTensorDictSequential.
Examples
>>> # Typical usage: a single distribution is computed last in the sequence
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import (ProbabilisticTensorDictModule as Prob,
...     ProbabilisticTensorDictSequential as Seq, TensorDictModule as Mod)
>>> torch.manual_seed(0)
>>>
>>> module = Seq(
...     Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
...     Prob(in_keys=["loc"], out_keys=["sample"], distribution_class=torch.distributions.Normal,
...         distribution_kwargs={"scale": 1}),
... )
>>> input = TensorDict(x=torch.ones(3))
>>> td = module(input.copy())
>>> print(td)
TensorDict(
    fields={
        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(module.get_dist(input))
Normal(loc: torch.Size([3]), scale: torch.Size([3]))
>>> print(module.log_prob(td))
tensor([-0.9189, -0.9189, -0.9189])
>>> # Intermediate distributions are ignored when return_composite=False
>>> module = Seq(
...     Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
...     Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal,
...         distribution_kwargs={"scale": 1}),
...     Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["loc2"]),
...     Prob(in_keys={"loc": "loc2"}, out_keys=["sample1"], distribution_class=torch.distributions.Normal,
...         distribution_kwargs={"scale": 1}),
...     return_composite=False,
... )
>>> td = module(TensorDict(x=torch.ones(3)))
>>> print(td)
TensorDict(
    fields={
        loc2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(module.get_dist(input))
Normal(loc: torch.Size([3]), scale: torch.Size([3]))
>>> print(module.log_prob(td))
tensor([-0.9189, -0.9189, -0.9189])
>>> # Intermediate distributions produce a CompositeDistribution when return_composite=True
>>> module = Seq(
...     Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
...     Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal,
...         distribution_kwargs={"scale": 1}),
...     Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["loc2"]),
...     Prob(in_keys={"loc": "loc2"}, out_keys=["sample1"], distribution_class=torch.distributions.Normal,
...         distribution_kwargs={"scale": 1}),
...     return_composite=True,
... )
>>> input = TensorDict(x=torch.ones(3))
>>> td = module(input.copy())
>>> print(td)
TensorDict(
    fields={
        loc2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(module.get_dist(input))
CompositeDistribution({'sample0': Normal(loc: torch.Size([3]), scale: torch.Size([3])), 'sample1': Normal(loc: torch.Size([3]), scale: torch.Size([3]))})
>>> print(module.log_prob(td))
TensorDict(
    fields={
        sample0_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample1_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> # Even a single intermediate distribution is wrapped in a CompositeDistribution when
>>> # return_composite=True
>>> module = Seq(
...     Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
...     Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal,
...         distribution_kwargs={"scale": 1}),
...     Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["y"]),
...     return_composite=True,
... )
>>> td = module(TensorDict(x=torch.ones(3)))
>>> print(td)
TensorDict(
    fields={
        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        y: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(module.get_dist(input))
CompositeDistribution({'sample0': Normal(loc: torch.Size([3]), scale: torch.Size([3]))})
>>> print(module.log_prob(td))
TensorDict(
    fields={
        sample0_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
- build_dist_from_params(tensordict: TensorDictBase) Distribution¶
Constructs a distribution from the input parameters without evaluating other modules in the sequence.
This method searches for the last ProbabilisticTensorDictModule in the sequence and uses it to build the distribution.
- Parameters:
tensordict (TensorDictBase) – The input tensordict containing the distribution parameters.
- Returns:
The constructed distribution object.
- Return type:
D.Distribution
- Raises:
RuntimeError – If no ProbabilisticTensorDictModule is found in the sequence.
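For illustration, a minimal sketch reusing the Prob/Seq/Mod aliases from the examples above. Since the distribution parameters are read directly from the input, the deterministic modules are never executed:

>>> import torch
>>> from tensordict import TensorDict
>>> module = Seq(
...     Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
...     Prob(in_keys=["loc"], out_keys=["sample"], distribution_class=torch.distributions.Normal,
...         distribution_kwargs={"scale": 1}),
... )
>>> # "loc" is supplied directly, so the first module never runs
>>> dist = module.build_dist_from_params(TensorDict(loc=torch.zeros(3)))
>>> print(dist)
Normal(loc: torch.Size([3]), scale: torch.Size([3]))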
- property default_interaction_type¶
Returns the default_interaction_type of the module using an iterative heuristic.
This property iterates over all modules in reverse order, attempting to retrieve the default_interaction_type attribute from any child module. The first non-None value encountered is returned. If no such value is found, a default interaction_type() is returned.
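A short sketch, assuming the probabilistic module was built with an explicit default_interaction_type (the InteractionType enum is exposed by tensordict.nn):

>>> from tensordict.nn import InteractionType
>>> module = Seq(
...     Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
...     Prob(in_keys=["loc"], out_keys=["sample"], distribution_class=torch.distributions.Normal,
...         distribution_kwargs={"scale": 1}, default_interaction_type=InteractionType.MODE),
... )
>>> # the reverse scan finds the value on the probabilistic child module
>>> print(module.default_interaction_type)
InteractionType.MODE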
- property dist_params_keys: List[NestedKey]¶
Returns all the keys pointing at the distribution params.
- property dist_sample_keys: List[NestedKey]¶
Returns all the keys pointing at the distribution samples.
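For the single-distribution module from the examples above, these properties would yield something like the following (output shown is indicative):

>>> print(module.dist_params_keys)
['loc']
>>> print(module.dist_sample_keys)
['sample']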
- forward(tensordict: TensorDictBase = None, tensordict_out: tensordict.base.TensorDictBase | None = None, **kwargs) TensorDictBase¶
When the tensordict parameter is not set, kwargs are used to create an instance of TensorDict.
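A sketch of the keyword-argument path, reusing the two-module sequence from the examples above; the keyword tensors are packed into a TensorDict before the sequence runs:

>>> out = module(x=torch.ones(3))  # same inputs as module(TensorDict(x=torch.ones(3)))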
- get_dist(tensordict: TensorDictBase, tensordict_out: Optional[TensorDictBase] = None, **kwargs) Distribution¶
Returns the distribution resulting from passing the input tensordict through the sequence.
If return_composite is False (default), this method will only consider the last probabilistic module in the sequence.

Otherwise, it will return a CompositeDistribution instance containing the distributions of all probabilistic modules.
- Parameters:
tensordict (TensorDictBase) – The input tensordict.
tensordict_out (TensorDictBase, optional) – The output tensordict. If None, a new tensordict will be created. Defaults to None.
- Keyword Arguments:
**kwargs – Additional keyword arguments passed to the underlying modules.
- Returns:
The resulting distribution object.
- Return type:
D.Distribution
- Raises:
RuntimeError – If no probabilistic module is found in the sequence.
Note
When return_composite is True, the distributions are conditioned on the previous samples in the sequence. This means that if a module depends on the output of a previous probabilistic module, its distribution will be conditional.
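To make this concrete, a sketch reusing the return_composite=True example with two probabilistic modules from the class docstring above:

>>> composite = module.get_dist(TensorDict(x=torch.ones(3)))
>>> # the 'sample1' component is parameterized by the 'sample0' draw made while
>>> # running the sequence, so it represents the conditional p(sample1 | sample0)
>>> print(composite)
CompositeDistribution({'sample0': Normal(loc: torch.Size([3]), scale: torch.Size([3])), 'sample1': Normal(loc: torch.Size([3]), scale: torch.Size([3]))})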
- get_dist_params(tensordict: TensorDictBase, tensordict_out: Optional[TensorDictBase] = None, **kwargs) tuple[torch.distributions.distribution.Distribution, tensordict.base.TensorDictBase]¶
Returns the distribution parameters and output tensordict.
This method runs the deterministic part of the ProbabilisticTensorDictSequential module to obtain the distribution parameters. The interaction type is set to the current global interaction type if available; otherwise, it defaults to the interaction type of the last module.
- Parameters:
tensordict (TensorDictBase) – The input tensordict.
tensordict_out (TensorDictBase, optional) – The output tensordict. If None, a new tensordict will be created. Defaults to None.
- Keyword Arguments:
**kwargs – Additional keyword arguments passed to the deterministic part of the module.
- Returns:
A tuple containing the distribution object and the output tensordict.
- Return type:
tuple[D.Distribution, TensorDictBase]
Note
The interaction type is temporarily set to the specified value during the execution of this method.
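A minimal sketch using the single-distribution module from the examples above; the returned tuple carries both the distribution and the tensordict holding the computed parameters:

>>> dist, td_out = module.get_dist_params(TensorDict(x=torch.ones(3)))
>>> print(dist)
Normal(loc: torch.Size([3]), scale: torch.Size([3]))
>>> assert "loc" in td_out.keys()  # the deterministic part wrote the params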
- log_prob(tensordict, tensordict_out: Optional[TensorDictBase] = None, *, dist: Optional[Distribution] = None, **kwargs) tensordict.base.TensorDictBase | torch.Tensor¶
Returns the log-probability of the input tensordict.
If self.return_composite is True and the distribution is a CompositeDistribution, this method will return the log-probability of the entire composite distribution. Otherwise, it will only consider the last probabilistic module in the sequence.
- Parameters:
tensordict (TensorDictBase) – The input tensordict.
tensordict_out (TensorDictBase, optional) – The output tensordict. If None, a new tensordict will be created. Defaults to None.
- Keyword Arguments:
dist (torch.distributions.Distribution, optional) – The distribution object. If None, it will be computed using get_dist. Defaults to None.
- Returns:
The log-probability of the input tensordict.
- Return type:
TensorDictBase | torch.Tensor
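A sketch of the dist shortcut, reusing the single-distribution module from the examples above; passing a precomputed distribution avoids a second pass through the deterministic modules:

>>> input = TensorDict(x=torch.ones(3))
>>> td = module(input.copy())
>>> dist = module.get_dist(input)
>>> lp = module.log_prob(td, dist=dist)  # log-prob of the drawn sample, distribution not recomputed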