VTrace

class torchrl.objectives.value.VTrace(*args, **kwargs)

A class wrapper around the V-Trace estimate functional.

Refer to “IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures” (https://arxiv.org/abs/1802.01561) for more context.

Keyword Arguments:
  • gamma (scalar) – exponential discount factor.

  • value_network (TensorDictModule) – value operator used to retrieve the value estimates.

  • actor_network (TensorDictModule) – actor operator used to retrieve the log prob.

  • rho_thresh (Union[float, Tensor]) – rho clipping parameter for importance weights. Defaults to 1.0.

  • c_thresh (Union[float, Tensor]) – c clipping parameter for importance weights. Defaults to 1.0.

  • average_adv (bool) – if True, the resulting advantage values will be standardized. Default is False.

  • differentiable (bool, optional) –

    if True, gradients are propagated through the computation of the value function. Default is False.

    Note

    The proper way to make the function call non-differentiable is to wrap it in torch.no_grad() (used as a context manager or as a decorator) or to pass detached parameters for functional modules.

  • skip_existing (bool, optional) – if True, the value network will skip modules whose outputs are already present in the tensordict. Defaults to None, i.e., the value of tensordict.nn.skip_existing() is not affected.

  • advantage_key (str or tuple of str, optional) – [Deprecated] the key of the advantage entry. Defaults to "advantage".

  • value_target_key (str or tuple of str, optional) – [Deprecated] the key of the value target entry. Defaults to "value_target".

  • value_key (str or tuple of str, optional) – [Deprecated] the value key to read from the input tensordict. Defaults to "state_value".

  • shifted (bool, optional) –

    controls how value and next-value are obtained from the value network. False (default) calls the value network twice (once on the root tensordict, once on "next"), which is correct whenever "next" may differ non-trivially from obs[t+1]. Truthy values request a single call:

    • True: fixed-budget single-call path. Inserts the true ("next", <in_key>) entry after every internal truncation (done & ~terminated), shifts subsequent samples to the right inside a sequence of length T + shifted_budget and masks the displaced suffix via "shifted_valid". Terminal steps (done & terminated) do not consume budget. Retained samples use exact next observations.

    Note

    Single-step rollout assumption. shifted=True relies on the standard one-step rollout layout produced by env.step + auto-reset: at every position where done[t] = False, the value-net inputs in ("next", <in_key>)[t] are expected to equal <in_key>[t+1]. The backend uses this invariant to evaluate V once over a fused [T + shifted_budget] sequence instead of twice over [T] streams.

    The canonical pipeline that breaks the invariant is multi-step return processing (MultiStep / n-step bootstrapping), which rewrites ("next", obs)[t] to obs[t+n] with n > 1. shifted=True is unsupported with multi-step returns — use shifted=False instead.

    Single-call paths also require that the parameters at time t and t+1 are identical (i.e. target_params is not used).

    Defaults to False.

  • device (torch.device, optional) – the device where the buffers will be instantiated. Defaults to torch.get_default_device().

  • time_dim (int, optional) – the dimension corresponding to the time in the input tensordict. If not provided, defaults to the dimension marked with the "time" name if any, and to the last dimension otherwise. Can be overridden during a call to value_estimate(). Negative dimensions are considered with respect to the input tensordict.

  • value_chunk_size (int, optional) – if set, splits value-network calls into chunks of this many elements along the leading dimension. Defaults to None.

  • num_chunks (int, optional) – if set, splits value-network calls into this many chunks along the leading dimension. Mutually exclusive with value_chunk_size. num_chunk is accepted as an alias. Defaults to None.

  • num_chunk (int, optional) – alias for num_chunks. Cannot be set together with a different num_chunks value. Defaults to None.

  • shifted_budget (int, optional) – number of extra value-network time slots used when shifted=True. 1 uses a T+1 budget, 2 can represent one internal reset plus the rollout boundary without dropping samples, and so on. Defaults to 1.
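
The interaction between shifted=True and shifted_budget described above can be pictured with a small plain-Python sketch. This is only an illustration of the layout under the stated single-step invariant, not TorchRL's actual backend: the helper name fuse_shifted_sequence, its return values and the use of Python lists in place of tensors are all assumptions made for this example.

```python
def fuse_shifted_sequence(obs, next_obs, done, terminated, shifted_budget=1):
    """Hypothetical helper: lay out value-net inputs for one fused call.

    Returns (fused, value_idx, next_idx, valid). next_idx[t] is None for
    terminal steps, whose bootstrap value is not needed.
    """
    T = len(obs)
    budget = T + shifted_budget
    full, value_idx, next_idx = [], [], []
    for t in range(T):
        value_idx.append(len(full))
        full.append(obs[t])
        terminal = done[t] and terminated[t]
        if terminal:
            # Terminal steps do not consume any extra slot.
            next_idx.append(None)
        elif done[t] or t == T - 1:
            # Truncation or rollout boundary: the true next observation
            # differs from obs[t + 1], so it gets a slot of its own.
            next_idx.append(len(full))
            full.append(next_obs[t])
        else:
            # Single-step invariant: next_obs[t] == obs[t + 1], which is
            # the entry appended on the following iteration.
            next_idx.append(len(full))
    # A step is retained only if every slot it reads fits in the budget.
    valid = [
        (v if n is None else n) < budget
        for v, n in zip(value_idx, next_idx)
    ]
    # Pad or cut to the fixed length T + shifted_budget.
    fused = (full + [None] * shifted_budget)[:budget]
    return fused, value_idx, next_idx, valid


# One internal truncation at t=1 plus the rollout boundary needs T + 2 slots.
obs = ["o0", "o1", "o2", "o3"]
next_obs = ["n0", "n1", "n2", "n3"]
done = [False, True, False, False]
terminated = [False, False, False, False]
fused_1, _, _, valid_1 = fuse_shifted_sequence(obs, next_obs, done, terminated, 1)
fused_2, _, _, valid_2 = fuse_shifted_sequence(obs, next_obs, done, terminated, 2)
```

In this toy rollout, a budget of 1 leaves no room for the final next observation, so the last step is masked, whereas a budget of 2 keeps every step.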

VTrace will return an "advantage" entry containing the advantage value. It will also return a "value_target" entry with the V-Trace target value.
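
As a rough illustration of what these two entries contain, the sketch below follows the V-Trace recursion from the IMPALA paper in plain Python, for a single trajectory without episode ends. The helper name vtrace_targets and its arguments are hypothetical; TorchRL's own implementation operates on batched tensors, handles done/terminated flags and may differ in detail.

```python
import math


def vtrace_targets(log_pi, log_mu, rewards, values, next_values,
                   gamma=0.98, rho_thresh=1.0, c_thresh=1.0):
    """Hypothetical helper returning (advantages, value_targets).

    Assumes one trajectory with no episode end, so that
    next_values[t] == values[t + 1] for every t < T - 1.
    """
    T = len(rewards)
    # Importance ratios pi(a|x) / mu(a|x), clipped from above.
    ratios = [math.exp(lp - lm) for lp, lm in zip(log_pi, log_mu)]
    rhos = [min(rho_thresh, r) for r in ratios]
    cs = [min(c_thresh, r) for r in ratios]
    # Clipped temporal-difference errors.
    deltas = [
        rhos[t] * (rewards[t] + gamma * next_values[t] - values[t])
        for t in range(T)
    ]
    # Backward recursion:
    # v_t - V(x_t) = delta_t + gamma * c_t * (v_{t+1} - V(x_{t+1}))
    targets = [0.0] * T
    correction = 0.0  # v_T - V(x_T) is zero past the end of the rollout
    for t in reversed(range(T)):
        correction = deltas[t] + gamma * cs[t] * correction
        targets[t] = values[t] + correction
    # Advantage: rho_t * (r_t + gamma * v_{t+1} - V(x_t)), bootstrapping
    # the last step with the value of the final next observation.
    next_targets = targets[1:] + [next_values[-1]]
    advantages = [
        rhos[t] * (rewards[t] + gamma * next_targets[t] - values[t])
        for t in range(T)
    ]
    return advantages, targets


# On-policy data: all ratios equal 1, so the targets reduce to
# bootstrapped discounted returns.
advantages, targets = vtrace_targets(
    log_pi=[0.0, 0.0, 0.0],
    log_mu=[0.0, 0.0, 0.0],
    rewards=[1.0, 1.0, 1.0],
    values=[0.5, 0.4, 0.3],
    next_values=[0.4, 0.3, 0.2],
    gamma=0.5,
)
```

With both thresholds at 1.0, any ratio above 1 is clipped to 1, so the learner never up-weights a transition beyond what on-policy data would receive.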

Note

As other advantage functions do, if the value_key is already present in the input tensordict, the VTrace module will ignore the calls to the value network (if any) and use the provided value instead.

classmethod for_loss(loss_module, **hyperparams)

V-Trace needs both the critic and the actor.

When the loss is functional, the actor stored on the loss module is a stateless template — we deep-copy it and bake the current params in, since V-Trace doesn’t support a functional actor call.

forward(tensordict: TensorDictBase = None, *, params: list[Tensor] | None = None, target_params: list[Tensor] | None = None, time_dim: int | None = None) → TensorDictBase

Computes the V-Trace correction given the data in tensordict.

If a functional module is provided, a nested TensorDict containing the parameters (and if relevant the target parameters) can be passed to the module.

Parameters:

tensordict (TensorDictBase) – A TensorDict containing the data (an observation key, “action”, “reward”, “done” and “next” tensordict state as returned by the environment) necessary to compute the value estimates and the V-Trace advantage. The data passed to this module should be structured as [*B, T, F] where B denotes the batch dimension(s), T the time dimension and F the feature dimension(s).

Keyword Arguments:
  • params (TensorDictBase, optional) – A nested TensorDict containing the params to be passed to the functional value network module.

  • target_params (TensorDictBase, optional) – A nested TensorDict containing the target params to be passed to the functional value network module.

  • time_dim (int, optional) – the dimension corresponding to the time in the input tensordict. If not provided, defaults to the dimension marked with the "time" name if any, and to the last dimension otherwise. Negative dimensions are considered with respect to the input tensordict.

Returns:

An updated TensorDict with the advantage and value_target keys as defined in the constructor.
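
The time_dim resolution rule described for this method can be sketched in plain Python. The helper resolve_time_dim is hypothetical and only mirrors the documented behaviour; names stands for the dimension names of the input tensordict's batch dimensions, with None for unnamed ones.

```python
def resolve_time_dim(names, time_dim=None):
    """Hypothetical helper mirroring the documented time_dim rule."""
    ndim = len(names)
    if time_dim is None:
        # Prefer the dimension named "time", else fall back to the last one.
        time_dim = names.index("time") if "time" in names else ndim - 1
    elif time_dim < 0:
        # Negative values count from the end of the tensordict's batch dims.
        time_dim = ndim + time_dim
    if not 0 <= time_dim < ndim:
        raise ValueError(f"time_dim out of range for {ndim} batch dims")
    return time_dim


named = resolve_time_dim((None, "time", None))    # uses the "time" name
unnamed = resolve_time_dim((None, None))          # falls back to last dim
explicit = resolve_time_dim((None, None), -2)     # negative index
```

Because negative values are resolved against the tensordict's batch dimensions rather than against the feature dimensions of individual tensors, -1 points at T in a [*B, T] tensordict even when the entries have shape [*B, T, F].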

Examples

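>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.modules import OneHotCategorical, ProbabilisticActor
>>> from torchrl.objectives.value import VTrace
>>> value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"])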
>>> actor_net = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"])
>>> actor_net = ProbabilisticActor(
...     module=actor_net,
...     in_keys=["logits"],
...     out_keys=["action"],
...     distribution_class=OneHotCategorical,
...     return_log_prob=True,
... )
>>> module = VTrace(
...     gamma=0.98,
...     value_network=value_net,
...     actor_network=actor_net,
...     differentiable=False,
... )
>>> obs, next_obs = torch.randn(2, 1, 10, 3)
>>> reward = torch.randn(1, 10, 1)
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> sample_log_prob = torch.randn(1, 10, 1)
>>> tensordict = TensorDict({
...     "obs": obs,
...     "done": done,
...     "terminated": terminated,
...     "sample_log_prob": sample_log_prob,
...     "next": {"obs": next_obs, "reward": reward, "done": done, "terminated": terminated},
... }, batch_size=[1, 10])
>>> _ = module(tensordict)
>>> assert "advantage" in tensordict.keys()

The module supports non-tensordict (i.e. unpacked tensordict) inputs too:

Examples

>>> value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"])
>>> actor_net = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"])
>>> actor_net = ProbabilisticActor(
...     module=actor_net,
...     in_keys=["logits"],
...     out_keys=["action"],
...     distribution_class=OneHotCategorical,
...     return_log_prob=True,
... )
>>> module = VTrace(
...     gamma=0.98,
...     value_network=value_net,
...     actor_network=actor_net,
...     differentiable=False,
... )
>>> obs, next_obs = torch.randn(2, 1, 10, 3)
>>> reward = torch.randn(1, 10, 1)
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> sample_log_prob = torch.randn(1, 10, 1)
>>> tensordict = TensorDict({
...     "obs": obs,
...     "done": done,
...     "terminated": terminated,
...     "sample_log_prob": sample_log_prob,
...     "next": {"obs": next_obs, "reward": reward, "done": done, "terminated": terminated},
... }, batch_size=[1, 10])
>>> advantage, value_target = module(
...     obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated, sample_log_prob=sample_log_prob
... )