Source code for tensordict.nn.distributions.continuous

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from numbers import Number
from typing import Sequence

import numpy as np

import torch

from tensordict.nn.utils import mappings
from torch import distributions as D, nn

# We need this to build the distribution maps
__all__ = [
    "NormalParamExtractor",
    "AddStateIndependentNormalScale",
    "Delta",
]

# speeds up distribution construction
# D.Distribution.set_default_validate_args(False)


class NormalParamWrapper(nn.Module):
    def __init__(
        self,
        operator: nn.Module,
        scale_mapping: str = "biased_softplus_1.0",
        scale_lb: Number = 1e-4,
    ) -> None:
        raise RuntimeError(
            "NormalParamWrapper has been deprecated in favor of "
            "`tensordict.nn.NormalParamExtractor`; use that class instead."
        )
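# Illustrative migration sketch (not part of the library source): the deprecated
# wrapper paired an operator with the loc/scale extraction step; the same effect
# is obtained by composing the operator with ``NormalParamExtractor`` (defined
# below), e.g.:
#
#     module = nn.Sequential(nn.Linear(3, 4), NormalParamExtractor())
#     loc, scale = module(torch.randn(3))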


class NormalParamExtractor(nn.Module):
    """A non-parametric nn.Module that splits its input into loc and scale parameters.

    The scale parameters are mapped onto positive values using the specified ``scale_mapping``.

    Args:
        scale_mapping (str, optional): positive mapping function to be used with the std.
            default = ``"biased_softplus_1.0"`` (i.e. softplus map with bias such that fn(0.0) = 1.0)
            choices: ``"softplus"``, ``"exp"``, ``"relu"``, ``"biased_softplus_1"`` or ``"none"`` (no mapping).
            See :func:`~tensordict.nn.mappings` for more details.
        scale_lb (Number, optional): The minimum value that the scale can take. Default is 1e-4.

    Examples:
        >>> import torch
        >>> from tensordict.nn.distributions import NormalParamExtractor
        >>> from torch import nn
        >>> module = nn.Linear(3, 4)
        >>> normal_params = NormalParamExtractor()
        >>> tensor = torch.randn(3)
        >>> loc, scale = normal_params(module(tensor))
        >>> print(loc.shape, scale.shape)
        torch.Size([2]) torch.Size([2])
        >>> assert (scale > 0).all()
        >>> # with modules that return more than one tensor
        >>> module = nn.LSTM(3, 4)
        >>> tensor = torch.randn(4, 2, 3)
        >>> loc, scale, others = normal_params(*module(tensor))
        >>> print(loc.shape, scale.shape)
        torch.Size([4, 2, 2]) torch.Size([4, 2, 2])
        >>> assert (scale > 0).all()

    """

    def __init__(
        self,
        scale_mapping: str = "biased_softplus_1.0",
        scale_lb: Number = 1e-4,
    ) -> None:
        super().__init__()
        self.scale_mapping = mappings(scale_mapping)
        self.scale_lb = scale_lb

    def forward(self, *tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
        # only the first tensor carries the distribution parameters; any extra
        # outputs (e.g. an LSTM's hidden states) are passed through untouched
        tensor, *others = tensors
        loc, scale = tensor.chunk(2, -1)
        # map the raw scale onto strictly positive values and bound it from below
        scale = self.scale_mapping(scale).clamp_min(self.scale_lb)
        return (loc, scale, *others)
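
# Usage sketch (illustrative, not part of the library source): the (loc, scale)
# pair returned by ``NormalParamExtractor`` can be fed directly to a
# ``torch.distributions.Normal`` instance, assuming the upstream module outputs
# an even-sized last dimension so that ``chunk(2, -1)`` splits it in half:
#
#     net = nn.Sequential(nn.Linear(3, 8), NormalParamExtractor())
#     loc, scale = net(torch.randn(16, 3))      # loc, scale: [16, 4]
#     dist = D.Normal(loc, scale)               # scale is strictly positive
#     sample = dist.rsample()                   # reparametrized sample, shape [16, 4]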


class AddStateIndependentNormalScale(torch.nn.Module):
    """A nn.Module that adds trainable state-independent scale parameters.

    The scale parameters are mapped onto positive values using the specified ``scale_mapping``.

    Args:
        scale_shape (torch.Size or equivalent, optional): the shape of the scale parameter.
            Defaults to ``torch.Size(())``.

    Keyword Args:
        scale_mapping (str, optional): positive mapping function to be used with the std.
            Defaults to ``"exp"``, choices: ``"softplus"``, ``"exp"``, ``"relu"``, ``"biased_softplus_1"``.
        scale_lb (Number, optional): The minimum value that the scale can take. Defaults to ``1e-4``.
        device (torch.device, optional): the device of the module.
        make_param (bool, optional): whether the scale should be a parameter (``True``) or a buffer (``False``).
            Defaults to ``True``.
        init_value (float, optional): initial value of the state-independent scale. Defaults to ``0.0``.

    Examples:
        >>> from torch import nn
        >>> import torch
        >>> num_outputs = 4
        >>> module = nn.Linear(3, num_outputs)
        >>> module_normal = AddStateIndependentNormalScale(num_outputs)
        >>> tensor = torch.randn(3)
        >>> loc, scale = module_normal(module(tensor))
        >>> print(loc.shape, scale.shape)
        torch.Size([4]) torch.Size([4])
        >>> assert (scale > 0).all()
        >>> # with modules that return more than one tensor
        >>> module = nn.LSTM(3, num_outputs)
        >>> module_normal = AddStateIndependentNormalScale(num_outputs)
        >>> tensor = torch.randn(4, 2, 3)
        >>> loc, scale, others = module_normal(*module(tensor))
        >>> print(loc.shape, scale.shape)
        torch.Size([4, 2, 4]) torch.Size([4, 2, 4])
        >>> assert (scale > 0).all()

    """

    def __init__(
        self,
        scale_shape: torch.Size | int | tuple | None = None,
        *,
        scale_mapping: str = "exp",
        scale_lb: Number = 1e-4,
        device: torch.device | None = None,
        make_param: bool = True,
        init_value: float = 0.0,
    ) -> None:
        super().__init__()
        if scale_shape is None:
            scale_shape = torch.Size(())
        self.scale_lb = scale_lb
        if isinstance(scale_shape, int):
            scale_shape = (scale_shape,)
        self.scale_shape = torch.Size(scale_shape)
        self.scale_mapping = scale_mapping
        if make_param:
            self.state_independent_scale = torch.nn.Parameter(
                torch.full(
                    scale_shape,
                    init_value,
                    dtype=torch.get_default_dtype(),
                    device=device,
                )
            )
        else:
            self.state_independent_scale = torch.nn.Buffer(
                torch.full(
                    scale_shape,
                    init_value,
                    dtype=torch.get_default_dtype(),
                    device=device,
                )
            )

    def forward(
        self, loc: torch.Tensor, *others: torch.Tensor
    ) -> tuple[torch.Tensor, ...]:
        """Forward of AddStateIndependentNormalScale.

        Args:
            loc (torch.Tensor): a location parameter.
            *others: other unused parameters.

        Returns:
            a tuple of two or more tensors containing the ``(loc, scale, *others)`` values.
        """
        if self.scale_shape != loc.shape[-len(self.scale_shape) :]:
            raise RuntimeError(
                f"Last dimensions of loc ({loc.shape[-len(self.scale_shape):]}) do not match the number of dimensions "
                f"in scale ({self.state_independent_scale.shape})"
            )
        scale = self.state_independent_scale.expand_as(loc)
        scale = mappings(self.scale_mapping)(scale).clamp_min(self.scale_lb)
        return (loc, scale, *others)
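
# Usage sketch (illustrative, not part of the library source): unlike
# ``NormalParamExtractor``, which reads the scale from the network output,
# ``AddStateIndependentNormalScale`` appends a learnable, input-independent scale
# to a loc-only head (a common parametrization for continuous-control policies):
#
#     policy_head = nn.Sequential(nn.Linear(3, 4), AddStateIndependentNormalScale(4))
#     loc, scale = policy_head(torch.randn(16, 3))   # scale broadcast to loc's shape [16, 4]
#     dist = D.Normal(loc, scale)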
""" arg_constraints: dict = {} def __init__( self, param: torch.Tensor, atol: float = 1e-6, rtol: float = 1e-6, batch_shape: torch.Size | Sequence[int] | None = None, event_shape: torch.Size | Sequence[int] | None = None, ) -> None: if batch_shape is None: batch_shape = torch.Size([]) if event_shape is None: event_shape = torch.Size([]) self.update(param) self.atol = atol self.rtol = rtol if not len(batch_shape) and not len(event_shape): batch_shape = param.shape[:-1] event_shape = param.shape[-1:] super().__init__(batch_shape=batch_shape, event_shape=event_shape) def update(self, param: torch.Tensor) -> None: self.param = param def _is_equal(self, value: torch.Tensor) -> torch.Tensor: param = self.param.expand_as(value) is_equal = abs(value - param) < self.atol + self.rtol * abs(param) for i in range(-1, -len(self.event_shape) - 1, -1): is_equal = is_equal.all(i) return is_equal def log_prob(self, value: torch.Tensor) -> torch.Tensor: is_equal = self._is_equal(value) out = torch.zeros_like(is_equal, dtype=value.dtype) out.masked_fill_(is_equal, np.inf) out.masked_fill_(~is_equal, -np.inf) return out @torch.no_grad() def sample( self, sample_shape: torch.Size | Sequence[int] | None = None, ) -> torch.Tensor: if sample_shape is None: sample_shape = torch.Size([]) return self.param.expand((*sample_shape, *self.param.shape)) def rsample( self, sample_shape: torch.Size | Sequence[int] | None = None, ) -> torch.Tensor: if sample_shape is None: sample_shape = torch.Size([]) return self.param.expand((*sample_shape, *self.param.shape)) @property def mode(self) -> torch.Tensor: return self.param @property def mean(self) -> torch.Tensor: return self.param @property def deterministic_sample(self) -> torch.Tensor: return self.param @property def _logistic_deterministic_sample(self): s = self.loc for t in self.transforms: s = t(s) return s D.LogisticNormal.deterministic_sample = _logistic_deterministic_sample
