ACTLoss
- class torchrl.objectives.ACTLoss(*args, **kwargs)
Loss module for Action Chunking with Transformers (ACT).
Implements the training objective from Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (Zhao et al., 2023), pairing an L1 chunk-reconstruction term with a KL-divergence penalty on the CVAE latent:
\[\mathcal{L} = \underbrace{\|a_{\text{pred}} - a_{\text{chunk}}\|_1}_{\text{reconstruction}} + \beta \cdot \underbrace{D_{\mathrm{KL}}\!\left(q(z|o,a)\,\|\, \mathcal{N}(0,I)\right)}_{\text{KL}}\]
The actor_network must read "observation" and "action_chunk" and write "action_pred", "mu", and "log_var". This matches the contract of ACTModel when wrapped with a TensorDictModule.
Three values are returned in the output TensorDict:
- "loss_act": the full (differentiable) training loss.
- "loss_reconstruction": detached L1 reconstruction term (for logging).
- "loss_kl": detached KL term (for logging).
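The objective above can be sketched in plain PyTorch. This is a minimal illustration, not the ACTLoss implementation: the helper name act_loss_sketch, the latent size of 32, and the choice to average the L1 term over all elements and the KL term over the batch are assumptions made for the example.

```python
import torch
import torch.nn.functional as F


def act_loss_sketch(action_pred, action_chunk, mu, log_var, kl_weight=10.0):
    """Hypothetical helper: L1 reconstruction plus beta-weighted KL to N(0, I)."""
    # L1 reconstruction, averaged over batch, chunk, and action dims (assumed).
    loss_reconstruction = F.l1_loss(action_pred, action_chunk, reduction="mean")
    # Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I)):
    # summed over latent dims, then averaged over the batch (assumed).
    kl_per_sample = -0.5 * (1.0 + log_var - mu.pow(2) - log_var.exp()).sum(dim=-1)
    loss_kl = kl_per_sample.mean()
    loss_act = loss_reconstruction + kl_weight * loss_kl
    # Only the total carries gradients; the two terms are detached for logging.
    return loss_act, loss_reconstruction.detach(), loss_kl.detach()


torch.manual_seed(0)
action_chunk = torch.randn(4, 10, 7)                      # [batch, chunk, action_dim]
action_pred = torch.randn(4, 10, 7, requires_grad=True)   # stand-in for model output
mu = torch.ones(4, 32, requires_grad=True)                # latent size 32 is illustrative
log_var = torch.zeros(4, 32, requires_grad=True)

loss_act, loss_rec, loss_kl = act_loss_sketch(action_pred, action_chunk, mu, log_var)
loss_act.backward()
```

With mu set to ones and log_var to zeros, each latent dimension contributes 0.5 to the KL term, so the per-sample KL is 16 for a 32-dimensional latent, and the total exceeds the reconstruction term by kl_weight times that value.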
- Parameters:
actor_network (TensorDictModule) – ACT policy. Must expose in_keys containing "observation" and "action_chunk" and write "action_pred", "mu", and "log_var".
- Keyword Arguments:
kl_weight (float, optional) – β, the weight on the KL divergence term. Defaults to 10.0 (as in the original paper).
reduction (str, optional) – "none" | "mean" | "sum". Defaults to "mean".
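The three reduction modes can be illustrated on a vector of per-sample losses. This sketch assumes ACTLoss follows the usual PyTorch convention for reduction strings; the source does not spell out the exact axes it reduces over, and reduce_sketch is a hypothetical helper, not part of TorchRL.

```python
import torch


def reduce_sketch(per_sample_loss, reduction="mean"):
    """Hypothetical helper mirroring the common PyTorch reduction convention."""
    if reduction == "none":
        return per_sample_loss          # keep one value per sample
    if reduction == "mean":
        return per_sample_loss.mean()   # scalar average
    if reduction == "sum":
        return per_sample_loss.sum()    # scalar total
    raise ValueError(f"Unknown reduction: {reduction!r}")


per_sample = torch.tensor([1.0, 2.0, 3.0, 6.0])
unreduced = reduce_sketch(per_sample, "none")
averaged = reduce_sketch(per_sample, "mean")
summed = reduce_sketch(per_sample, "sum")
```

Note that with "none" the result is not a scalar, so it would need to be reduced before calling backward().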
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.modules.models import ACTModel
>>> from torchrl.objectives import ACTLoss
>>> model = ACTModel(obs_dim=14, action_dim=7, chunk_size=10)
>>> actor = TensorDictModule(
...     model,
...     in_keys=["observation", "action_chunk"],
...     out_keys=["action_pred", "mu", "log_var"],
... )
>>> loss_fn = ACTLoss(actor, kl_weight=10.0)
>>> td = TensorDict(
...     {
...         "observation": torch.randn(4, 14),
...         "action_chunk": torch.randn(4, 10, 7),
...     },
...     batch_size=[4],
... )
>>> loss_td = loss_fn(td)
>>> loss_td["loss_act"].backward()
- forward(tensordict: TensorDictBase) -> TensorDict
Compute the ACT loss.
- Parameters:
tensordict (TensorDictBase) – Input data containing "observation" and "action_chunk".
- Returns:
TensorDict with keys "loss_act", "loss_reconstruction", and "loss_kl".