TD3BCLoss
- class torchrl.objectives.TD3BCLoss(*args, **kwargs)
TD3+BC Loss Module.
Implementation of the TD3+BC loss presented in the paper “A Minimalist Approach to Offline Reinforcement Learning” <https://arxiv.org/pdf/2106.06860>.
This class incorporates two loss functions, executed sequentially within the forward method:
1. qvalue_loss()
2. actor_loss()
Users also have the option to call these functions directly in the same order if preferred.
- Parameters:
actor_network (TensorDictModule) – the actor to be trained
qvalue_network (TensorDictModule) – a single Q-value network or a list of Q-value networks. If a single instance of qvalue_network is provided, it will be duplicated num_qvalue_nets times. If a list of modules is passed, their parameters will be stacked unless they share the same identity (in which case the original parameter will be expanded).
Warning
When a list of parameters is passed, it will not be compared against the policy parameters, and all the parameters will be considered as untied.
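The parameter stacking described above can be illustrated with plain PyTorch. This is a minimal sketch, not TorchRL's internal code: it assumes PyTorch 2.x and its torch.func utilities (stack_module_state, functional_call) together with torch.vmap, and compares a single vectorized evaluation of several independent critics with the plain Python loop that a no-vmap fallback (such as deactivate_vmap=True) roughly corresponds to.

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state

torch.manual_seed(0)
n_obs, n_act, num_qvalue_nets = 3, 4, 2

# Independent critics, each mapping concat(obs, act) -> scalar Q-value.
critics = [nn.Linear(n_obs + n_act, 1) for _ in range(num_qvalue_nets)]

# Stack every parameter along a new leading dimension of size num_qvalue_nets.
params, buffers = stack_module_state(critics)

# Stateless template on the meta device; functional_call supplies the weights.
template = copy.deepcopy(critics[0]).to("meta")


def call_critic(p, b, x):
    return functional_call(template, (p, b), (x,))


obs = torch.randn(5, n_obs)
act = torch.randn(5, n_act)
x = torch.cat([obs, act], dim=-1)

# Vectorized path: one batched call over the stacked parameters.
q_vmap = torch.vmap(call_critic, in_dims=(0, 0, None))(params, buffers, x)

# Loop path: evaluate each critic separately and stack the results.
q_loop = torch.stack([critic(x) for critic in critics])

print(q_vmap.shape)  # torch.Size([2, 5, 1])
print(torch.allclose(q_vmap, q_loop, atol=1e-6))
```

Both paths produce the same values; the vectorized one replaces a Python loop with a single batched call, while the loop is easier to debug and also works with modules that vmap cannot trace.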
- Keyword Arguments:
bounds (tuple of float, optional) – the bounds of the action space. Exclusive with action_spec. Either this or action_spec must be provided.
action_spec (TensorSpec, optional) – the action spec. Exclusive with bounds. Either this or bounds must be provided.
num_qvalue_nets (int, optional) – number of Q-value networks to be trained. Default is 2.
policy_noise (float, optional) – standard deviation of the noise added to the target policy action. Default is 0.2.
noise_clip (float, optional) – clipping range for the sampled target policy action noise. Default is 0.5.
alpha (float, optional) – coefficient α of the TD3+BC objective, used to compute λ = α / mean(|Q(s, π(s))|), which weights the Q-value term of the actor loss relative to the behavioral cloning loss (a larger alpha gives behavioral cloning less relative weight). Defaults to 2.5.
priority_key (str, optional) – key where to write the priority value for prioritized replay buffers. Default is "td_error".
loss_function (str, optional) – loss function to be used for the Q-value. Can be one of "smooth_l1", "l2" or "l1". Default is "smooth_l1".
delay_actor (bool, optional) – whether to separate the target actor networks from the actor networks used for data collection. Default is True.
delay_qvalue (bool, optional) – whether to separate the target Q-value networks from the Q-value networks used for data collection. Default is True.
spec (TensorSpec, optional) – the action tensor spec. If not provided and the target entropy is "auto", it will be retrieved from the actor.
separate_losses (bool, optional) – if True, shared parameters between policy and critic will only be trained on the policy loss. Defaults to False, i.e., gradients are propagated to shared parameters for both policy and critic losses.
reduction (str, optional) – specifies the reduction to apply to the output: "none" | "mean" | "sum". "none": no reduction will be applied; "mean": the sum of the output will be divided by the number of elements in the output; "sum": the output will be summed. Default: "mean".
deactivate_vmap (bool, optional) – whether to deactivate vmap calls and replace them with a plain for loop. Defaults to False.
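In TD3 (and therefore TD3+BC), policy_noise and noise_clip drive target policy smoothing: clipped Gaussian noise is added to the target actor's action before bootstrapping, and the result is kept inside the action bounds. The sketch below illustrates this in plain PyTorch under stated assumptions: the bounds are scalar low/high values, the noise is scaled by policy_noise and then clipped to ±noise_clip, and the helper name smoothed_target_action is hypothetical (it is not part of TorchRL).

```python
import torch


def smoothed_target_action(
    target_action: torch.Tensor,
    policy_noise: float = 0.2,
    noise_clip: float = 0.5,
    low: float = -1.0,
    high: float = 1.0,
) -> torch.Tensor:
    """Hypothetical helper: TD3-style target policy smoothing."""
    # Gaussian noise with standard deviation policy_noise...
    noise = torch.randn_like(target_action) * policy_noise
    # ...clipped to [-noise_clip, noise_clip] so the perturbation stays small.
    noise = noise.clamp(-noise_clip, noise_clip)
    # Project the perturbed action back into the action bounds.
    return (target_action + noise).clamp(low, high)


torch.manual_seed(0)
a_target = torch.tensor([[0.9, -0.95, 0.0, 0.3]])
a_smooth = smoothed_target_action(a_target)
print(a_smooth)
```

Smoothing the target action in this way regularizes the critic so that it does not exploit sharp peaks in its own Q-value estimates.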
Examples
>>> import torch
>>> from torch import nn
>>> from torchrl.data import Bounded
>>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
>>> from torchrl.objectives.td3_bc import TD3BCLoss
>>> from tensordict import TensorDict
>>> n_act, n_obs = 4, 3
>>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> module = nn.Linear(n_obs, n_act)
>>> actor = Actor(
...     module=module,
...     spec=spec)
>>> class ValueClass(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = nn.Linear(n_obs + n_act, 1)
...     def forward(self, obs, act):
...         return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> qvalue = ValueOperator(
...     module=module,
...     in_keys=['observation', 'action'])
>>> loss = TD3BCLoss(actor, qvalue, action_spec=actor.spec)
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> data = TensorDict({
...     "observation": torch.randn(*batch, n_obs),
...     "action": action,
...     ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
...     ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
...     ("next", "reward"): torch.randn(*batch, 1),
...     ("next", "observation"): torch.randn(*batch, n_obs),
... }, batch)
>>> loss(data)
TensorDict(
    fields={
        bc_loss: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        lmbd: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        next_state_value: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        pred_value: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        state_action_value_actor: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        target_value: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
This class is also compatible with non-tensordict-based modules and can be used without resorting to any tensordict-related primitive. In this case, the expected keyword arguments are:
["action", "next_reward", "next_done", "next_terminated"] + the in_keys of the actor and Q-value networks. The return value is a tuple of tensors in the following order: ["loss_actor", "loss_qvalue", "bc_loss", "lmbd", "pred_value", "state_action_value_actor", "next_state_value", "target_value"].
Examples
>>> import torch
>>> from torch import nn
>>> from torchrl.data import Bounded
>>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
>>> from torchrl.objectives.td3_bc import TD3BCLoss
>>> n_act, n_obs = 4, 3
>>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> module = nn.Linear(n_obs, n_act)
>>> actor = Actor(
...     module=module,
...     spec=spec)
>>> class ValueClass(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = nn.Linear(n_obs + n_act, 1)
...     def forward(self, obs, act):
...         return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> qvalue = ValueOperator(
...     module=module,
...     in_keys=['observation', 'action'])
>>> loss = TD3BCLoss(actor, qvalue, action_spec=actor.spec)
>>> _ = loss.select_out_keys("loss_actor", "loss_qvalue")
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> loss_actor, loss_qvalue = loss(
...     observation=torch.randn(*batch, n_obs),
...     action=action,
...     next_done=torch.zeros(*batch, 1, dtype=torch.bool),
...     next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
...     next_reward=torch.randn(*batch, 1),
...     next_observation=torch.randn(*batch, n_obs))
>>> loss_actor.backward()
- actor_loss(tensordict) → tuple[torch.Tensor, dict]
Compute the actor loss.
The actor loss should be computed after qvalue_loss() and is usually delayed by 1-3 critic updates.
- Parameters:
tensordict (TensorDictBase) – the input data for the loss. Check the class’s in_keys to see what fields are required for this to be computed.
- Returns: a differentiable tensor with the actor loss, along with a metadata dictionary containing the detached “bc_loss” used in the combined actor loss, the detached “state_action_value_actor” used to calculate the lambda value, and the lambda value “lmbd” itself.
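The lambda value mentioned above rescales the Q-value term of the TD3+BC objective so that its magnitude is comparable to the behavioral cloning term: λ = α / mean(|Q(s, π(s))|), computed from detached Q-values. The sketch below writes that objective in plain PyTorch. The function name td3bc_actor_loss is hypothetical, the BC term is assumed to be a mean squared error averaged over all elements, and TorchRL's actual reduction over action dimensions and batch may differ.

```python
import torch


def td3bc_actor_loss(
    q_pi: torch.Tensor,
    pi_action: torch.Tensor,
    data_action: torch.Tensor,
    alpha: float = 2.5,
):
    """Hypothetical sketch of the TD3+BC actor objective (to be minimized).

    q_pi:        Q(s, pi(s)) from one critic, shape [batch]
    pi_action:   actions proposed by the actor, shape [batch, n_act]
    data_action: actions stored in the offline dataset, shape [batch, n_act]
    """
    # lambda rescales the Q term by the batch's mean |Q|; it is detached so
    # the normalizer itself receives no gradient.
    lmbd = alpha / q_pi.abs().mean().detach()
    # Behavioral cloning term: mean squared error to the dataset actions.
    bc_loss = (pi_action - data_action).pow(2).mean()
    # Maximizing lambda * Q is written as minimizing its negation.
    loss_actor = -lmbd * q_pi.mean() + bc_loss
    return loss_actor, bc_loss.detach(), lmbd


torch.manual_seed(0)
pi_action = torch.randn(8, 4, requires_grad=True)
data_action = torch.rand(8, 4) * 2 - 1
# Stand-in for Q(s, pi(s)); in practice this comes from the critic.
q_pi = -(pi_action ** 2).sum(-1)
loss_actor, bc_loss, lmbd = td3bc_actor_loss(q_pi, pi_action, data_action)
loss_actor.backward()
print(pi_action.grad.shape)  # torch.Size([8, 4])
```

Because λ is normalized by the scale of the Q-values, a single alpha can be reused across tasks whose rewards have very different magnitudes.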
- default_keys
alias of _AcceptedKeys
- forward(tensordict: TensorDictBase = None) → TensorDictBase
The forward method.
Computes successively the actor_loss() and qvalue_loss(), and returns a tensordict with these values. To see which keys are expected in the input tensordict and which keys are produced as output, check the class’s “in_keys” and “out_keys” attributes.
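When actor_loss() and qvalue_loss() are called separately instead of through forward(), TD3-style training usually updates the critic at every step and the actor (together with the target networks) only every few steps. The sketch below shows that schedule in plain PyTorch under simplifying assumptions: a single critic, no target policy smoothing, random tensors standing in for an offline minibatch, and Polyak averaging through a hypothetical soft_update helper for the frozen target copies (the separation that delay_actor and delay_qvalue provide). It illustrates the pattern only; it is not how TorchRL wires its objectives together.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)
n_obs, n_act, batch = 3, 4, 16
policy_delay, tau, gamma, alpha = 2, 0.005, 0.99, 2.5

actor = nn.Sequential(nn.Linear(n_obs, n_act), nn.Tanh())
critic = nn.Linear(n_obs + n_act, 1)
# Frozen copies used only to build bootstrapped targets.
actor_target = copy.deepcopy(actor).requires_grad_(False)
critic_target = copy.deepcopy(critic).requires_grad_(False)
actor_opt = torch.optim.Adam(actor.parameters(), lr=3e-4)
critic_opt = torch.optim.Adam(critic.parameters(), lr=3e-4)


@torch.no_grad()
def soft_update(target: nn.Module, online: nn.Module, tau: float) -> None:
    # Polyak averaging: target <- (1 - tau) * target + tau * online.
    for p_t, p in zip(target.parameters(), online.parameters()):
        p_t.lerp_(p, tau)


actor_updates = 0
for step in range(10):
    # Random stand-in for a minibatch sampled from an offline dataset.
    obs, next_obs = torch.randn(batch, n_obs), torch.randn(batch, n_obs)
    action = torch.rand(batch, n_act) * 2 - 1
    reward, terminated = torch.randn(batch, 1), torch.zeros(batch, 1)

    # 1) Critic update at every step.
    with torch.no_grad():
        next_q = critic_target(torch.cat([next_obs, actor_target(next_obs)], -1))
        target = reward + gamma * (1 - terminated) * next_q
    q = critic(torch.cat([obs, action], -1))
    critic_loss = nn.functional.mse_loss(q, target)
    critic_opt.zero_grad()  # also clears critic grads left by the last actor step
    critic_loss.backward()
    critic_opt.step()

    # 2) Actor update and target refresh only every `policy_delay` steps.
    if step % policy_delay == 0:
        pi = actor(obs)
        q_pi = critic(torch.cat([obs, pi], -1))
        lmbd = alpha / q_pi.abs().mean().detach()
        actor_loss = -lmbd * q_pi.mean() + nn.functional.mse_loss(pi, action)
        actor_opt.zero_grad()
        actor_loss.backward()
        actor_opt.step()
        soft_update(actor_target, actor, tau)
        soft_update(critic_target, critic, tau)
        actor_updates += 1

print(actor_updates)  # 5
```

With policy_delay=2, the actor is updated once for every two critic updates, so 10 critic steps yield 5 actor steps.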
- make_value_estimator(value_type: ValueEstimators = None, **hyperparams)
Value-function constructor.
If a non-default value function is needed, it must be built using this method.
- Parameters:
value_type (ValueEstimators) – a ValueEstimators enum type indicating the value function to use. If none is provided, the default stored in the default_value_estimator attribute will be used. The resulting value estimator class will be registered in self.value_type, allowing future refinements.
**hyperparams – hyperparameters to use for the value function. If not provided, the value indicated by default_value_kwargs() will be used.
Examples
>>> import torch
>>> from torchrl.objectives import DQNLoss
>>> from torchrl.objectives.utils import ValueEstimators
>>> # initialize the DQN loss
>>> actor = torch.nn.Linear(3, 4)
>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> # updating the parameters of the default value estimator
>>> dqn_loss.make_value_estimator(gamma=0.9)
>>> dqn_loss.make_value_estimator(
...     ValueEstimators.TD1,
...     gamma=0.9)
>>> # if we want to change the gamma value
>>> dqn_loss.make_value_estimator(dqn_loss.value_type, gamma=0.9)
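The value_type argument selects how the bootstrapped target is built. As a rough illustration of the difference between a one-step TD(0) target and a multi-step TD(λ) target, the sketch below computes both for a single short trajectory in plain PyTorch. The helper names td0_target and td_lambda_target are hypothetical; TorchRL's own estimators (selected with ValueEstimators.TD0, ValueEstimators.TD1 or ValueEstimators.TDLambda) additionally handle batching, done flags and trajectory boundaries.

```python
import torch


def td0_target(reward, next_value, terminated, gamma=0.99):
    """Hypothetical one-step target: r_t + gamma * V(s_{t+1}), cut at terminal states."""
    return reward + gamma * (~terminated).float() * next_value


def td_lambda_target(reward, next_value, terminated, gamma=0.99, lmbda=0.95):
    """Hypothetical recursive TD(lambda) target for one trajectory (1-D tensors).

    next_value[t] holds V(s_{t+1}); terminated[t] marks s_{t+1} as terminal.
    """
    not_term = (~terminated).float()
    returns = torch.empty_like(reward)
    T = reward.shape[0]
    for t in reversed(range(T)):
        if t == T - 1:
            # Last step of the trajectory: plain bootstrap from V(s_T).
            mix = next_value[t]
        else:
            # Blend the one-step bootstrap with the lambda-return of step t+1.
            mix = (1 - lmbda) * next_value[t] + lmbda * returns[t + 1]
        returns[t] = reward[t] + gamma * not_term[t] * mix
    return returns


reward = torch.tensor([1.0, 0.0, 2.0])
next_value = torch.tensor([0.5, 1.0, 3.0])
terminated = torch.tensor([False, False, False])
print(td0_target(reward, next_value, terminated, gamma=0.9))  # approx. [1.45, 0.90, 4.70]
print(td_lambda_target(reward, next_value, terminated, gamma=0.9, lmbda=1.0))  # approx. [4.807, 4.23, 4.70]
```

Setting lmbda=0 recovers the TD(0) target, while lmbda=1 gives a TD(1)-style discounted return that bootstraps only at the end of the trajectory.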
- qvalue_loss(tensordict) → tuple[torch.Tensor, dict]
Compute the Q-value loss.
The Q-value loss should be computed before actor_loss().
- Parameters:
tensordict (TensorDictBase) – the input data for the loss. Check the class’s in_keys to see what fields are required for this to be computed.
- Returns: a differentiable tensor with the Q-value loss, along with a metadata dictionary containing the detached “td_error” to be used for prioritized sampling, the detached “next_state_value”, the detached “pred_value”, and the detached “target_value”.
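The metadata listed above (pred_value, next_state_value, target_value, td_error) corresponds to the usual TD3 critic update with clipped double-Q learning. The sketch below is a hypothetical plain-PyTorch rendering of that update: it takes the minimum over the target critics' estimates (evaluated at the smoothed target actions), builds a one-step target, and applies the loss selected by loss_function to every critic. The function name, the sum-over-critics and mean-over-batch reduction, and the use of the largest absolute error across critics as td_error are assumptions that may not match TorchRL's internals.

```python
import torch
import torch.nn.functional as F


def td3_qvalue_loss(pred_value, next_q_values, reward, terminated,
                    gamma=0.99, loss_function="smooth_l1"):
    """Hypothetical sketch of a TD3-style critic loss.

    pred_value:    Q_i(s, a) for each critic, shape [num_qvalue_nets, batch]
    next_q_values: target Q_i(s', a') for each critic, shape [num_qvalue_nets, batch]
    reward, terminated: shape [batch]
    """
    # Clipped double-Q: bootstrap from the smallest target estimate.
    next_state_value = next_q_values.min(dim=0).values
    target_value = reward + gamma * (~terminated).float() * next_state_value
    target_value = target_value.detach()
    # Every critic regresses onto the same target.
    target = target_value.expand_as(pred_value)
    if loss_function == "l2":
        per_elem = F.mse_loss(pred_value, target, reduction="none")
    elif loss_function == "l1":
        per_elem = F.l1_loss(pred_value, target, reduction="none")
    elif loss_function == "smooth_l1":
        per_elem = F.smooth_l1_loss(pred_value, target, reduction="none")
    else:
        raise ValueError(f"unknown loss_function {loss_function!r}")
    # Sum over critics, then average over the batch.
    loss_qvalue = per_elem.sum(0).mean()
    # Per-sample error usable as a replay priority (detached).
    td_error = (pred_value - target).abs().max(0).values.detach()
    return loss_qvalue, td_error, target_value


pred = torch.tensor([[1.0, 2.0], [1.5, 2.5]])
next_q = torch.tensor([[2.0, 0.0], [1.0, 1.0]])
reward = torch.tensor([0.0, 1.0])
terminated = torch.tensor([False, True])
loss_qvalue, td_error, target_value = td3_qvalue_loss(
    pred, next_q, reward, terminated, gamma=0.5, loss_function="l2"
)
print(target_value)  # the second target is just the reward, since its next state is terminal
```

Taking the minimum over critics counteracts the overestimation bias that a single bootstrapped critic tends to accumulate.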