OneHotCategorical

class torchrl.modules.OneHotCategorical(logits: Tensor | None = None, probs: Tensor | None = None, grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough, **kwargs)

One-hot categorical distribution.

This class behaves exactly like torch.distributions.Categorical, except that it accepts and returns one-hot encoded tensors instead of integer category indices.
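Since the only difference from torch.distributions.Categorical is the encoding, it can help to see the two representations side by side. The sketch below uses plain torch.distributions.Categorical and torch.nn.functional.one_hot rather than the torchrl class; the variable names are illustrative only.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

torch.manual_seed(0)
logits = torch.randn(4)
cat = Categorical(logits=logits)

# torch.distributions.Categorical produces integer class indices.
idx = cat.sample((3,))                                    # shape (3,), dtype int64
# The one-hot view of the same draws: one row per sample, a single 1 per row.
one_hot = F.one_hot(idx, num_classes=4).to(logits.dtype)  # shape (3, 4)
# Reading a one-hot tensor back into indices is an argmax over the last dim.
recovered = one_hot.argmax(dim=-1)
```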

Parameters:
  • logits (torch.Tensor) – event log probabilities (unnormalized)

  • probs (torch.Tensor) – event probabilities. Specify either logits or probs, but not both.

  • grad_method (ReparamGradientStrategy, optional) – strategy used to draw reparameterized samples. ReparamGradientStrategy.PassThrough (the default) computes the sample gradients by using the softmax-valued log-probability as a proxy. ReparamGradientStrategy.RelaxedOneHot uses torch.distributions.RelaxedOneHotCategorical to sample from the distribution.
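The two gradient strategies can be approximated with plain PyTorch. This is a sketch under two assumptions: that PassThrough behaves like a straight-through estimator routing gradients through the softmax probabilities, and that RelaxedOneHot returns a hard one-hot value that carries the relaxed sample's gradient. It is not the torchrl implementation, and the names pass_through, soft, snapped, and relaxed are hypothetical.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical, RelaxedOneHotCategorical

torch.manual_seed(0)
logits = torch.randn(4, requires_grad=True)
probs = torch.softmax(logits, dim=-1)

# Pass-through (straight-through) sketch: the value is a hard one-hot draw,
# while the backward pass sees the softmax probabilities.
hard = F.one_hot(Categorical(logits=logits).sample(), num_classes=4).to(probs.dtype)
pass_through = hard + (probs - probs.detach())

# Relaxed sketch: a Gumbel-softmax draw at temperature 1.0, whose value is then
# snapped to the one-hot argmax while its gradient path is kept.
soft = RelaxedOneHotCategorical(torch.tensor(1.0), logits=logits).rsample()
snapped = F.one_hot(soft.argmax(dim=-1), num_classes=4).to(soft.dtype)
relaxed = snapped + (soft - soft.detach())

# Both tensors are exactly one-hot in value and still differentiable w.r.t. logits.
weights = torch.arange(4.0)
(pass_through * weights + relaxed * weights).sum().backward()
```

The parentheses around `probs - probs.detach()` keep the forward value bit-for-bit equal to the hard sample, since the bracketed difference is exactly zero.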

Examples

>>> import torch
>>> from torchrl.modules import OneHotCategorical
>>> _ = torch.manual_seed(0)
>>> logits = torch.randn(4)
>>> dist = OneHotCategorical(logits=logits)
>>> print(dist.rsample((3,)))
tensor([[1., 0., 0., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.]])
entropy()

Returns the entropy of the distribution, batched over batch_shape.

Returns:

Tensor of shape batch_shape.
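For reference, the entropy of a categorical distribution is H(p) = -Σ_k p_k log p_k, summed over the category dimension only, which is why the result has shape batch_shape. The sketch below checks that formula against torch.distributions.Categorical.entropy, assuming (from the class description) that the one-hot class returns the same values.

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)
logits = torch.randn(2, 4)               # batch_shape (2,), 4 categories
dist = Categorical(logits=logits)

# H(p) = -sum_k p_k * log p_k, reduced over the category dimension only.
log_p = torch.log_softmax(logits, dim=-1)
manual = -(log_p.exp() * log_p).sum(dim=-1)   # shape (2,), i.e. batch_shape
builtin = dist.entropy()
```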

log_prob(value: Tensor) → Tensor

Returns the log of the probability mass function evaluated at value.

Parameters:

value (Tensor) – one-hot encoded value at which to evaluate the log-probability.
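Because values are one-hot, the log-mass of a value reduces to picking the log-softmax entry marked by the 1. A plain-PyTorch sketch that compares this dot product with index-based Categorical.log_prob; the particular one-hot values chosen are arbitrary.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

torch.manual_seed(0)
logits = torch.randn(2, 4)                        # batch_shape (2,), 4 categories
# Arbitrary one-hot values, one per batch entry: shape (2, 4).
value = F.one_hot(torch.tensor([3, 0]), num_classes=4).to(logits.dtype)

# With a one-hot value, the log-mass is the log-softmax entry selected by the 1.
lp_one_hot = (value * torch.log_softmax(logits, dim=-1)).sum(dim=-1)   # shape (2,)
# The same quantity from index-valued Categorical.log_prob after decoding with argmax.
lp_index = Categorical(logits=logits).log_prob(value.argmax(dim=-1))
```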

property mode: Tensor

Returns the mode of the distribution.
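A one-hot mode can be sketched by one-hot encoding the argmax over categories. How the torchrl class breaks ties and which dtype it returns are not stated on this page, so this sketch simply assumes argmax's first-index convention and a floating-point dtype.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 4)          # batch_shape (2,), 4 categories
probs = torch.softmax(logits, dim=-1)

# One-hot mode: mark the most probable category in each batch entry.
mode = F.one_hot(probs.argmax(dim=-1), num_classes=4).to(probs.dtype)   # shape (2, 4)
```

Softmax is monotonic, so taking the argmax of the logits directly would select the same categories.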

rsample(sample_shape: Size | Sequence = None) → Tensor

Generates a sample_shape-shaped reparameterized sample, or a sample_shape-shaped batch of reparameterized samples if the distribution parameters are batched. Samples are one-hot encoded, and the gradient estimator is selected by grad_method.
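The shape rule for reparameterized draws is sample_shape + batch_shape + event_shape, where the event dimension is the one-hot category axis. The sketch uses torch.distributions.RelaxedOneHotCategorical (the distribution the RelaxedOneHot strategy is documented to use) with batched logits. Its draws are soft rather than hard one-hot vectors, so it illustrates shapes and gradient flow, not the exact output of this class.

```python
import torch
from torch.distributions import RelaxedOneHotCategorical

torch.manual_seed(0)
logits = torch.randn(2, 4, requires_grad=True)    # batch_shape (2,), 4 categories
dist = RelaxedOneHotCategorical(torch.tensor(1.0), logits=logits)

# sample_shape (3,) is prepended to batch_shape (2,) and event_shape (4,).
soft = dist.rsample(torch.Size([3]))               # shape (3, 2, 4)
# A reparameterized draw stays on the autograd graph, so a loss computed
# from it produces gradients for the distribution parameters.
soft.max(dim=-1).values.sum().backward()
```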

sample(sample_shape: Size | Sequence | None = None) → Tensor

Generates a sample_shape-shaped sample, or a sample_shape-shaped batch of samples if the distribution parameters are batched. Samples are one-hot encoded.
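For contrast with rsample, plain sampling follows the same shape rule but returns draws with no gradient path. The sketch uses PyTorch's own torch.distributions.OneHotCategorical as a stand-in; it is a different class from torchrl.modules.OneHotCategorical.

```python
import torch
from torch.distributions import OneHotCategorical as TorchOneHotCategorical

torch.manual_seed(0)
logits = torch.randn(2, 4, requires_grad=True)    # batch_shape (2,), 4 categories
dist = TorchOneHotCategorical(logits=logits)

# Same (sample_shape + batch_shape + event_shape) layout as a reparameterized
# draw, but the result is detached from the autograd graph.
hard = dist.sample(torch.Size([3]))               # shape (3, 2, 4)
```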
