DiffusionBCLoss¶
- class torchrl.objectives.DiffusionBCLoss(*args, **kwargs)[source]¶
Behavioural Cloning loss for diffusion-based policies.
Implements the ε-prediction (noise-prediction) denoising loss from Diffusion Policy: Visuomotor Policy Learning via Action Diffusion (Chi et al., RSS 2023).
Given a batch of (observation, clean_action) pairs from a demonstration dataset, the loss:
1. Samples a random diffusion timestep t for each item in the batch.
2. Corrupts the clean action with Gaussian noise via the DDPM forward process: noisy_action = sqrt(ᾱ_t) * action + sqrt(1 - ᾱ_t) * ε.
3. Asks the score network to predict the noise ε.
4. Returns the MSE between the predicted and actual noise.
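The four steps above can be sketched in plain PyTorch. This is a hedged, minimal illustration, not the class's actual implementation: the linear beta schedule, the stand-in linear score network, and the way the timestep is fed to it are all assumptions made for the sketch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_steps, batch, act_dim, obs_dim = 10, 8, 2, 4

# Assumed DDPM schedule: alpha_bar_t = prod_i (1 - beta_i)
betas = torch.linspace(1e-4, 0.02, num_steps)
alpha_bar = torch.cumprod(1.0 - betas, dim=0)

# Stand-in score network (hypothetical): predicts noise from
# (noisy_action, observation, t); not torchrl's real module.
score_net = nn.Linear(act_dim + obs_dim + 1, act_dim)

obs = torch.randn(batch, obs_dim)
clean_action = torch.randn(batch, act_dim)

# Step 1: sample a random timestep per batch element.
t = torch.randint(0, num_steps, (batch,))

# Step 2: corrupt the clean action with the DDPM forward process.
eps = torch.randn_like(clean_action)
ab = alpha_bar[t].unsqueeze(-1)
noisy_action = ab.sqrt() * clean_action + (1.0 - ab).sqrt() * eps

# Step 3: predict the noise from the corrupted action and conditioning.
inp = torch.cat([noisy_action, obs, t.float().unsqueeze(-1)], dim=-1)
eps_pred = score_net(inp)

# Step 4: MSE between predicted and actual noise (a scalar with "mean").
loss = F.mse_loss(eps_pred, eps)
```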
This loss is designed to be used together with DiffusionActor. The actor's inner _DDPMModule is accessed via actor_network.module, and its add_noise method is used for step 2.
- Parameters:
actor_network (TensorDictModule) – a DiffusionActor (or any TensorDictModule whose .module exposes add_noise(clean_action, t) and a score_network attribute).
- Keyword Arguments:
reduction (str, optional) – Specifies the reduction to apply to the output: "none" | "mean" | "sum". Defaults to "mean".
Note
The tensordict passed to forward() must contain:
self.tensor_keys.action — the clean (demonstration) action.
self.tensor_keys.observation — the conditioning observation.
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.modules import DiffusionActor
>>> from torchrl.objectives import DiffusionBCLoss
>>> actor = DiffusionActor(action_dim=2, obs_dim=4, num_steps=10)
>>> loss_fn = DiffusionBCLoss(actor)
>>> td = TensorDict(
...     {
...         "observation": torch.randn(8, 4),
...         "action": torch.randn(8, 2),
...     },
...     batch_size=[8],
... )
>>> loss_td = loss_fn(td)
>>> loss_td["loss_diffusion_bc"].backward()
- forward(tensordict: TensorDictBase) → TensorDict[source]¶
Compute the diffusion BC loss.
- Parameters:
tensordict (TensorDictBase) – input data containing observations and clean demonstration actions.
- Returns:
TensorDict with key "loss_diffusion_bc".