DistributionalDQNLoss

class torchrl.objectives.DistributionalDQNLoss(*args, **kwargs)[source]

A distributional DQN loss class.

Distributional DQN uses a value network that outputs a distribution of values over a discrete support of discounted returns, unlike regular DQN, where the value network outputs a single point estimate of the discounted return.

For more details regarding Distributional DQN, refer to “A Distributional Perspective on Reinforcement Learning”, https://arxiv.org/pdf/1707.06887.pdf
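The categorical projection at the heart of this loss (the distribution is shifted by the Bellman operator and projected back onto the fixed support, Eq. 7 of the paper) can be sketched in plain PyTorch. `project_distribution` and its arguments are illustrative names, not part of the TorchRL API:

```python
import torch

def project_distribution(next_probs, reward, gamma, support):
    """Sketch of the categorical projection from Bellemare et al. (2017),
    for a single transition with a scalar reward."""
    v_min, v_max = support[0].item(), support[-1].item()
    n_atoms = support.numel()
    delta_z = (v_max - v_min) / (n_atoms - 1)
    # Apply the Bellman operator to every atom and clamp to the support range
    tz = (reward + gamma * support).clamp(v_min, v_max)
    # Fractional index of each shifted atom on the fixed support
    b = (tz - v_min) / delta_z
    lower, upper = b.floor().long(), b.ceil().long()
    projected = torch.zeros(n_atoms)
    # Split each atom's probability mass between its two neighbouring atoms
    projected.index_add_(0, lower, next_probs * (upper.float() - b))
    projected.index_add_(0, upper, next_probs * (b - lower.float()))
    # When b lands exactly on an atom, lower == upper and both weights above
    # are zero, so re-assign the full mass to that atom
    projected.index_add_(0, lower, next_probs * (lower == upper).float())
    return projected
```

For instance, with support `[-1, -0.5, 0, 0.5, 1]`, all mass at atom `0`, reward `0.25` and `gamma=1.0`, the shifted atom `0.25` falls halfway between `0` and `0.5`, so the projected distribution splits the mass evenly between those two atoms.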

Parameters:
  • value_network (DistributionalQValueActor or nn.Module) – the distributional Q value operator.

  • gamma (scalar) –

    a discount factor for return computation.

    Note

    Unlike DQNLoss, this class does not currently support custom value functions. The next value estimation is always bootstrapped.

  • delay_value (bool) – whether to duplicate the value network into a new target value network, creating a double DQN.

  • priority_key (str, optional) – [Deprecated, use .set_keys(priority_key=priority_key) instead] The key at which priority is assumed to be stored within TensorDicts added to this ReplayBuffer. This is to be used when the sampler is of type PrioritizedSampler. Defaults to "td_error".

  • reduction (str, optional) – Specifies the reduction to apply to the output: "none" | "mean" | "sum". "none": no reduction will be applied, "mean": the sum of the output will be divided by the number of elements in the output, "sum": the output will be summed. Default: "mean".
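The effect of delay_value=True can be sketched in plain PyTorch; this illustrates the target-network idea in general, not TorchRL's internal implementation:

```python
import copy

import torch

# Illustrative sketch: delay_value=True keeps a frozen copy of the value
# network whose parameters are used to compute the bootstrap target.
value_net = torch.nn.Linear(4, 2)
target_net = copy.deepcopy(value_net).requires_grad_(False)

# The target network starts as an exact copy of the online network...
for p_online, p_target in zip(value_net.parameters(), target_net.parameters()):
    assert torch.equal(p_online, p_target)

# ...and stays fixed while the online network trains, until it is
# periodically (or softly) synchronised with the online parameters.
```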

default_keys

alias of _AcceptedKeys

forward(input_tensordict: TensorDictBase) TensorDict[source]

Reads an input TensorDict and returns another tensordict with loss keys named “loss*”.

Splitting the loss into its components allows the trainer to log the various loss values throughout training. Other scalars present in the output tensordict will be logged too.

Parameters:

tensordict – an input tensordict with the values required to compute the loss.

Returns:

A new tensordict with no batch dimension containing various loss scalars which will be named “loss*”. It is essential that the losses are returned with this name as they will be read by the trainer before backpropagation.
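The trainer's convention of backpropagating every “loss*” entry can be sketched with a plain dict standing in for the output tensordict (the key names and values here are illustrative):

```python
# Plain-dict stand-in for the TensorDict returned by forward().
loss_td = {"loss": 0.7, "td_error": 0.1}

# The trainer backpropagates only entries whose key starts with "loss";
# other scalars (such as "td_error") are logged but not differentiated.
total_loss = sum(v for k, v in loss_td.items() if k.startswith("loss"))
```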

make_value_estimator(value_type: ValueEstimators = None, **hyperparams)[source]

Value-function constructor.

If a non-default value function is wanted, it must be built using this method.

Parameters:
  • value_type (ValueEstimators, ValueEstimatorBase, or type) –

    The value estimator to use. This can be one of the following:

    • A ValueEstimators enum type indicating which value function to use. If none is provided, the default stored in the default_value_estimator attribute will be used.

    • A ValueEstimatorBase instance, which will be used directly as the value estimator.

    • A ValueEstimatorBase subclass, which will be instantiated with the provided hyperparams.

    The resulting value estimator class will be registered in self.value_type, allowing future refinements.

  • **hyperparams – hyperparameters to use for the value function. If not provided, the value indicated by default_value_kwargs() will be used. When passing a ValueEstimatorBase subclass, these hyperparameters are passed directly to the class constructor.

Returns:

Returns the loss module for method chaining.

Return type:

self

Examples

>>> import torch
>>> from torchrl.objectives import DQNLoss, ValueEstimators
>>> # initialize the DQN loss
>>> actor = torch.nn.Linear(3, 4)
>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> # updating the parameters of the default value estimator
>>> dqn_loss.make_value_estimator(gamma=0.9)
>>> dqn_loss.make_value_estimator(
...     ValueEstimators.TD1,
...     gamma=0.9)
>>> # if we want to change the gamma value
>>> dqn_loss.make_value_estimator(dqn_loss.value_type, gamma=0.9)

Using a ValueEstimatorBase subclass:

>>> from torchrl.objectives.value import TD0Estimator
>>> dqn_loss.make_value_estimator(TD0Estimator, gamma=0.99, value_network=value_net)

Using a ValueEstimatorBase instance:

>>> from torchrl.objectives.value import GAE
>>> gae = GAE(gamma=0.99, lmbda=0.95, value_network=value_net)
>>> ppo_loss.make_value_estimator(gae)
