TD0Estimator
- class torchrl.objectives.value.TD0Estimator(*args, **kwargs)
Temporal Difference (TD(0)) estimate of the advantage function.
Also known as the bootstrapped temporal difference or 1-step return.
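The 1-step quantities can be sketched in plain Python. The sketch assumes per-step scalar lists and assumes that bootstrapping is cut only by `terminated`, so a truncated step still bootstraps from its true next observation. Under those assumptions, the target is r_t + γ·(1 − terminated_t)·V(s_{t+1}) and the advantage is that target minus V(s_t). `td0_advantage` is a hypothetical helper, not torchrl's implementation.

```python
from typing import List, Sequence, Tuple


def td0_advantage(
    reward: Sequence[float],
    value: Sequence[float],
    next_value: Sequence[float],
    terminated: Sequence[bool],
    gamma: float,
) -> Tuple[List[float], List[float]]:
    """Hypothetical helper: per-step TD(0) advantage and value target."""
    advantage: List[float] = []
    value_target: List[float] = []
    for r, v, v_next, term in zip(reward, value, next_value, terminated):
        # Bootstrap from V(s_{t+1}) unless the episode truly terminated at t.
        target = r + gamma * (0.0 if term else 1.0) * v_next
        value_target.append(target)
        # Advantage: the 1-step target minus the current value estimate.
        advantage.append(target - v)
    return advantage, value_target


adv, target = td0_advantage(
    reward=[1.0, 0.0, 1.0],
    value=[0.5, 0.5, 0.5],
    next_value=[0.5, 0.5, 2.0],
    terminated=[False, False, True],
    gamma=0.5,
)
print(adv)     # [0.75, -0.25, 0.5]
print(target)  # [1.25, 0.25, 1.0]
```

The last step terminates, so its next value (2.0) is ignored and the target reduces to the reward.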
- Keyword Arguments:
gamma (scalar) – exponential mean discount, i.e. the discount factor applied to the bootstrapped next-state value.
value_network (TensorDictModule) – value operator used to retrieve the value estimates.
shifted (bool, optional) –
controls how value and next-value are obtained from the value network.
False (default): calls the value network twice (once on the root tensordict, once on "next"), which is correct whenever "next" may differ non-trivially from obs[t+1].
Truthy values request a single call. True selects the fixed-budget single-call path: it inserts the true ("next", <in_key>) entry after every internal truncation (done & ~terminated), shifts subsequent samples to the right inside a sequence of length T + shifted_budget, and masks the displaced suffix via "shifted_valid". Terminal steps (done & terminated) do not consume budget. Retained samples use exact next observations.
Note
Single-step rollout assumption.
shifted=True relies on the standard one-step rollout layout produced by env.step + auto-reset: at every position where done[t] = False, the value-net inputs in ("next", <in_key>)[t] are expected to equal <in_key>[t+1]. The backend uses this invariant to evaluate V once over a fused [T + shifted_budget] sequence instead of twice over [T] streams.
The canonical pipeline that breaks the invariant is multi-step return processing (MultiStep / n-step bootstrapping), which rewrites ("next", obs)[t] to obs[t+n] with n > 1. shifted=True is unsupported with multi-step returns; use shifted=False instead.
Single-call paths also require that the parameters at times t and t+1 are identical (i.e. target_params is not used).
Defaults to False.
average_rewards (bool, optional) – if True, rewards are standardized before the TD is computed.
differentiable (bool, optional) – if True, gradients are propagated through the computation of the value function. Defaults to False.
Note
The proper way to make the function call non-differentiable is to wrap it in the torch.no_grad() context manager/decorator or to pass detached parameters for functional modules.
skip_existing (bool, optional) – if True, the value network skips modules whose outputs are already present in the tensordict. Defaults to None, i.e., the value of tensordict.nn.skip_existing() is not affected.
advantage_key (str or tuple of str, optional) – [Deprecated] the key of the advantage entry. Defaults to "advantage".
value_target_key (str or tuple of str, optional) – [Deprecated] the key of the value target entry. Defaults to "value_target".
value_key (str or tuple of str, optional) – [Deprecated] the value key to read from the input tensordict. Defaults to "state_value".
device (torch.device, optional) – the device where the buffers will be instantiated. Defaults to torch.get_default_device().
deactivate_vmap (bool, optional) – whether to deactivate vmap calls and replace them with a plain for loop. Defaults to False.
value_chunk_size (int, optional) – if set, splits value-network calls into chunks of this many elements along the leading dimension. Defaults to None.
num_chunks (int, optional) – if set, splits value-network calls into this many chunks along the leading dimension. Mutually exclusive with value_chunk_size. num_chunk is accepted as an alias. Defaults to None.
num_chunk (int, optional) – alias for num_chunks. Cannot be set together with a different num_chunks value. Defaults to None.
shifted_budget (int, optional) – number of extra value-network time slots used when shifted=True. 1 uses a T+1 budget, 2 can represent one internal reset plus the rollout boundary without dropping samples, and so on. Defaults to 1.
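The single-call layout behind shifted=True and shifted_budget can be illustrated with a minimal pure-Python sketch of the T + shifted_budget sequence. The helpers `fused_layout` and `td0_targets_single_call` are hypothetical and do not reproduce torchrl's backend. The sketch rests on three assumptions:

- observations are scalars;
- each sample reads V(s_{t+1}) from the slot right after its own;
- the final step needs the rollout-boundary slot unless it terminated.

```python
from typing import Callable, List, Optional, Sequence, Tuple


def fused_layout(
    obs: Sequence[float],
    next_obs: Sequence[float],
    done: Sequence[bool],
    terminated: Sequence[bool],
    shifted_budget: int = 1,
) -> Tuple[List[float], List[int], List[bool]]:
    """Hypothetical helper: inputs for one fused value call (<= T + budget slots).

    Returns the input sequence, the slot holding obs[t] for each sample t, and
    whether sample t (plus its next slot, if it bootstraps) fits in the budget.
    """
    T = len(obs)
    capacity = T + shifted_budget
    slots: List[float] = []
    index: List[int] = []
    valid: List[bool] = []
    for t in range(T):
        i = len(slots)
        index.append(i)
        slots.append(obs[t])
        truncated = done[t] and not terminated[t]
        if truncated or (t == T - 1 and not terminated[t]):
            # obs[t + 1] is a reset state (or lies past the rollout end), so the
            # true ("next", obs)[t] is inserted right after obs[t].
            slots.append(next_obs[t])
        # Terminal steps never bootstrap, so they only need their own slot.
        last_needed = i if terminated[t] else i + 1
        valid.append(last_needed < capacity)
    return slots[:capacity], index, valid


def td0_targets_single_call(
    value_fn: Callable[[float], float],
    obs: Sequence[float],
    next_obs: Sequence[float],
    reward: Sequence[float],
    done: Sequence[bool],
    terminated: Sequence[bool],
    gamma: float,
    shifted_budget: int = 1,
) -> List[Optional[float]]:
    """TD(0) targets from one pass of value_fn; None marks masked samples."""
    slots, index, valid = fused_layout(obs, next_obs, done, terminated, shifted_budget)
    values = [value_fn(x) for x in slots]  # stands in for the single V call
    targets: List[Optional[float]] = []
    for t, i in enumerate(index):
        if not valid[t]:
            targets.append(None)  # displaced suffix: masked out
        elif terminated[t]:
            targets.append(reward[t])  # no bootstrap after a true terminal
        else:
            targets.append(reward[t] + gamma * values[i + 1])
    return targets


# One internal truncation at t=2: obs[3] is a reset state, next_obs[2] is not.
obs, next_obs = [0.0, 1.0, 2.0, 3.0], [1.0, 2.0, 10.0, 4.0]
kwargs = dict(
    value_fn=lambda x: 2.0 * x,
    obs=obs, next_obs=next_obs, reward=[1.0] * 4,
    done=[False, False, True, False], terminated=[False] * 4, gamma=0.5,
)
print(td0_targets_single_call(shifted_budget=1, **kwargs))  # [2.0, 3.0, 11.0, None]
print(td0_targets_single_call(shifted_budget=2, **kwargs))  # [2.0, 3.0, 11.0, 5.0]
```

With shifted_budget=1, the entry inserted for the truncation displaces the last sample, which is masked (None here; "shifted_valid" in torchrl's terms). With shifted_budget=2, every sample fits, and the targets match the two-call result r + γ·V(next_obs).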
- forward(tensordict: TensorDictBase = None, *, params: TensorDictBase | None = None, target_params: TensorDictBase | None = None) → TensorDictBase
Computes the TD(0) advantage given the data in tensordict.
If a functional module is provided, a nested TensorDict containing the parameters (and if relevant the target parameters) can be passed to the module.
- Parameters:
tensordict (TensorDictBase) – A TensorDict containing the data (an observation key, "action", ("next", "reward"), ("next", "done"), ("next", "terminated"), and the "next" tensordict state as returned by the environment) necessary to compute the value estimates and the TD estimate. The data passed to this module should be structured as [*B, T, *F], where B is the batch size, T the time dimension and F the feature dimension(s). The tensordict must have shape [*B, T].
- Keyword Arguments:
params (TensorDictBase, optional) – A nested TensorDict containing the params to be passed to the functional value network module.
target_params (TensorDictBase, optional) – A nested TensorDict containing the target params to be passed to the functional value network module.
- Returns:
An updated TensorDict with the advantage and value target entries, stored under the keys defined in the constructor.
Examples
>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.objectives.value import TD0Estimator
>>> value_net = TensorDictModule(
...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
... )
>>> module = TD0Estimator(
...     gamma=0.98,
...     value_network=value_net,
... )
>>> obs, next_obs = torch.randn(2, 1, 10, 3)
>>> reward = torch.randn(1, 10, 1)
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> tensordict = TensorDict(
...     {"obs": obs, "next": {"obs": next_obs, "done": done, "terminated": terminated, "reward": reward}},
...     [1, 10],
... )
>>> _ = module(tensordict)
>>> assert "advantage" in tensordict.keys()
The module supports non-tensordict (i.e. unpacked tensordict) inputs too:
Examples
>>> import torch
>>> from torch import nn
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.objectives.value import TD0Estimator
>>> value_net = TensorDictModule(
...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
... )
>>> module = TD0Estimator(
...     gamma=0.98,
...     value_network=value_net,
... )
>>> obs, next_obs = torch.randn(2, 1, 10, 3)
>>> reward = torch.randn(1, 10, 1)
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> advantage, value_target = module(
...     obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated
... )
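In the unpacked call, each keyword name is the corresponding nested tensordict key joined with an underscore (for example, ("next", "reward") becomes next_reward). The tiny sketch below shows that naming convention, assuming the default "_" separator of tensordict's dispatch mechanism. `flat_kwarg_name` is a hypothetical helper.

```python
from typing import List, Tuple, Union

NestedKey = Union[str, Tuple[str, ...]]


def flat_kwarg_name(key: NestedKey, separator: str = "_") -> str:
    """Hypothetical helper: keyword-argument name for a (possibly nested) key."""
    return key if isinstance(key, str) else separator.join(key)


# Entries read by the estimator in the examples (value net in_keys=["obs"]).
in_keys: List[NestedKey] = [
    "obs",
    ("next", "obs"),
    ("next", "reward"),
    ("next", "done"),
    ("next", "terminated"),
]
print([flat_kwarg_name(k) for k in in_keys])
# ['obs', 'next_obs', 'next_reward', 'next_done', 'next_terminated']
```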
- value_estimate(tensordict, target_params: TensorDictBase | None = None, next_value: Tensor | None = None, **kwargs)
Gets a value estimate, usually used as a target value for the value network.
If the state value key is present under tensordict.get(("next", self.tensor_keys.value)), then this value is used without resorting to the value network.
- Parameters:
tensordict (TensorDictBase) – the tensordict containing the data to read.
target_params (TensorDictBase, optional) – A nested TensorDict containing the target params to be passed to the functional value network module.
next_value (torch.Tensor, optional) – the value of the next state or state-action pair. Mutually exclusive with target_params.
**kwargs – the keyword arguments to be passed to the value network.
Returns: a tensor corresponding to the state value.
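The order in which value_estimate obtains the next-state value can be sketched in plain Python. `value_estimate_sketch` is a hypothetical stand-in that uses dicts of lists instead of tensordicts and `value_fn` instead of the value network. It assumes the following precedence:

1. an explicit next_value;
2. a stored ("next", value_key) entry, as described above;
3. a value-network call.

It then bootstraps through ~terminated as in the TD(0) target.

```python
from typing import Callable, Dict, List, Optional, Sequence


def value_estimate_sketch(
    data: Dict[str, Dict[str, list]],
    gamma: float,
    value_fn: Callable[[float], float],
    next_value: Optional[Sequence[float]] = None,
    value_key: str = "state_value",
) -> List[float]:
    """Hypothetical: pick the source of V(s_{t+1}), then form the TD(0) target."""
    nxt = data["next"]
    if next_value is None:
        if value_key in nxt:
            # A next-state value is already stored: reuse it without a network call.
            next_value = nxt[value_key]
        else:
            # Otherwise evaluate the value function on the next observations.
            next_value = [value_fn(o) for o in nxt["obs"]]
    return [
        r + gamma * (0.0 if term else 1.0) * v
        for r, term, v in zip(nxt["reward"], nxt["terminated"], next_value)
    ]


calls: List[float] = []


def value_fn(o: float) -> float:
    calls.append(o)  # record every "network" evaluation
    return 2.0 * o


data = {"next": {"obs": [1.0, 2.0], "reward": [1.0, 1.0], "terminated": [False, True]}}
print(value_estimate_sketch(data, 0.5, value_fn))  # [2.0, 1.0], two calls
data["next"]["state_value"] = [4.0, 4.0]
print(value_estimate_sketch(data, 0.5, value_fn))  # [3.0, 1.0], no new call
print(value_estimate_sketch(data, 0.5, value_fn, next_value=[0.0, 0.0]))  # [1.0, 1.0]
```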