DQNLoss¶
- class torchrl.objectives.DQNLoss(*args, **kwargs)[source]¶
The DQN Loss class.
- Parameters:
value_network (QValueActor or nn.Module) – a Q value operator.
- Keyword Arguments:
loss_function (str, optional) – loss function for the value discrepancy. Can be one of “l1”, “l2” or “smooth_l1”. Defaults to “l2”.
delay_value (bool, optional) – whether to duplicate the value network into a new target value network to create a DQN with a target network. Defaults to True.
double_dqn (bool, optional) – whether to use Double DQN, as described in https://arxiv.org/abs/1509.06461. Defaults to False.
action_space (str or TensorSpec, optional) – Action space. Must be one of "one-hot", "mult_one_hot", "binary" or "categorical", or an instance of the corresponding specs (torchrl.data.OneHot, torchrl.data.MultiOneHot, torchrl.data.Binary or torchrl.data.Categorical). If not provided, an attempt to retrieve it from the value network will be made.
priority_key (NestedKey, optional) – [Deprecated, use .set_keys(priority_key=priority_key) instead] The key at which priority is assumed to be stored within TensorDicts added to this ReplayBuffer. This is to be used when the sampler is of type PrioritizedSampler. Defaults to "td_error".
reduction (str, optional) – Specifies the reduction to apply to the output: "none" | "mean" | "sum". "none": no reduction will be applied, "mean": the sum of the output will be divided by the number of elements in the output, "sum": the output will be summed. Defaults to "mean".
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import OneHot
>>> from torchrl.modules import MLP, QValueActor
>>> from torchrl.objectives import DQNLoss
>>> n_obs, n_act = 4, 3
>>> value_net = MLP(in_features=n_obs, out_features=n_act)
>>> spec = OneHot(n_act)
>>> actor = QValueActor(value_net, in_keys=["observation"], action_space=spec)
>>> loss = DQNLoss(actor, action_space=spec)
>>> batch = [10,]
>>> data = TensorDict({
...     "observation": torch.randn(*batch, n_obs),
...     "action": spec.rand(batch),
...     ("next", "observation"): torch.randn(*batch, n_obs),
...     ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
...     ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
...     ("next", "reward"): torch.randn(*batch, 1)
... }, batch)
>>> loss(data)
TensorDict(
    fields={
        loss: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
This class is compatible with non-tensordict based modules too and can be used without resorting to any tensordict-related primitive. In this case, the expected keyword arguments are:
["observation", "next_observation", "action", "next_reward", "next_done", "next_terminated"], and a single loss value is returned.
Examples
>>> from torchrl.objectives import DQNLoss
>>> from torchrl.data import OneHot
>>> from torch import nn
>>> import torch
>>> n_obs = 3
>>> n_action = 4
>>> action_spec = OneHot(n_action)
>>> value_network = nn.Linear(n_obs, n_action)  # a simple value model
>>> dqn_loss = DQNLoss(value_network, action_space=action_spec)
>>> # define data
>>> observation = torch.randn(n_obs)
>>> next_observation = torch.randn(n_obs)
>>> action = action_spec.rand()
>>> next_reward = torch.randn(1)
>>> next_done = torch.zeros(1, dtype=torch.bool)
>>> next_terminated = torch.zeros(1, dtype=torch.bool)
>>> loss_val = dqn_loss(
...     observation=observation,
...     next_observation=next_observation,
...     next_reward=next_reward,
...     next_done=next_done,
...     next_terminated=next_terminated,
...     action=action)
- default_keys¶
alias of _AcceptedKeys
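A short hedged sketch of inspecting the resolved keys and remapping one of them with set_keys(); the "action" key name and the custom "action_onehot" target are assumptions made for illustration, and the nn.Linear network is a placeholder:
>>> from torch import nn
>>> from torchrl.objectives import DQNLoss
>>> loss = DQNLoss(nn.Linear(4, 3), action_space="one-hot")
>>> keys = loss.tensor_keys    # resolved accepted keys for this loss instance
>>> # remap the action entry to a custom name used by your data
>>> # ("action" is assumed to be among the accepted keys, per forward() below)
>>> loss.set_keys(action="action_onehot")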
- forward(tensordict: TensorDictBase = None) → TensorDict[source]¶
Computes the DQN loss given a tensordict sampled from the replay buffer.
- This function will also write a “td_error” key that can be used by prioritized replay buffers to assign a priority to items in the tensordict.
- Parameters:
tensordict (TensorDictBase) – a tensordict with an “action” key and the in_keys of the value network (the observations), as well as “done”, “terminated” and “reward” entries in a nested “next” tensordict.
- Returns:
a tensor containing the DQN loss.
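To make the “td_error” side effect concrete, here is a hedged, self-contained sketch in the spirit of the first example above: after the call, the input tensordict carries a per-sample "td_error" entry that a PrioritizedSampler-backed buffer can use to refresh priorities. The linear value network and the sizes are placeholders.
>>> import torch
>>> from tensordict import TensorDict
>>> from torch import nn
>>> from torchrl.data import OneHot
>>> from torchrl.objectives import DQNLoss
>>> n_obs, n_act, batch = 4, 3, [10]
>>> spec = OneHot(n_act)
>>> loss = DQNLoss(nn.Linear(n_obs, n_act), action_space=spec)
>>> data = TensorDict({
...     "observation": torch.randn(*batch, n_obs),
...     "action": spec.rand(batch),
...     ("next", "observation"): torch.randn(*batch, n_obs),
...     ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
...     ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
...     ("next", "reward"): torch.randn(*batch, 1),
... }, batch)
>>> _ = loss(data)
>>> # forward() has written the per-sample TD error back into the tensordict,
>>> # where a prioritized sampler can pick it up to refresh priorities
>>> assert "td_error" in data.keys()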
- make_value_estimator(value_type: ValueEstimators = None, **hyperparams)[source]¶
Value-function constructor.
If a non-default value function is wanted, it must be built using this method.
- Parameters:
value_type (ValueEstimators, ValueEstimatorBase, or type) –
The value estimator to use. This can be one of the following:
A ValueEstimators enum type indicating which value function to use. If none is provided, the default stored in the default_value_estimator attribute will be used.
A ValueEstimatorBase instance, which will be used directly as the value estimator.
A ValueEstimatorBase subclass, which will be instantiated with the provided hyperparams.
The resulting value estimator class will be registered in self.value_type, allowing future refinements.
**hyperparams – hyperparameters to use for the value function. If not provided, the value indicated by default_value_kwargs() will be used. When passing a ValueEstimatorBase subclass, these hyperparameters are passed directly to the class constructor.
- Returns:
Returns the loss module for method chaining.
- Return type:
self
Examples
>>> import torch
>>> from torchrl.objectives import DQNLoss, ValueEstimators
>>> # initialize the DQN loss
>>> actor = torch.nn.Linear(3, 4)
>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> # updating the parameters of the default value estimator
>>> dqn_loss.make_value_estimator(gamma=0.9)
>>> dqn_loss.make_value_estimator(
...     ValueEstimators.TD1,
...     gamma=0.9)
>>> # if we want to change the gamma value
>>> dqn_loss.make_value_estimator(dqn_loss.value_type, gamma=0.9)
Using a ValueEstimatorBase subclass:
>>> from torchrl.objectives.value import TD0Estimator
>>> dqn_loss.make_value_estimator(TD0Estimator, gamma=0.99, value_network=value_net)
Using a ValueEstimatorBase instance:
>>> from torchrl.objectives.value import GAE
>>> gae = GAE(gamma=0.99, lmbda=0.95, value_network=value_net)
>>> dqn_loss.make_value_estimator(gae)