Shortcuts

Source code for torchrl.modules.models.gp

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Variable naming follows Deisenroth & Rasmussen (2011), "PILCO: A Model-Based
# and Data-Efficient Approach to Policy Search" (cited inline as "Eq. N").
#
# Key symbols
# -----------
#   x̃  := [x, u]              concatenated state-action input  (Eq. 1 / Sec. 2.1)
#   Δ  := x_t - x_{t-1}       transition residual              (Sec. 2.1)
#   K_a                        Gram matrix  K_{a,ij}=k_a(x̃_i,x̃_j)  (Eq. 6)
#   β_a := (K_a + σ²_ε I)^{-1}y_a  GP weight vector           (Eq. 7)
#   q_a                        kernel-mean vector               (Eq. 15)
#   Q_{ab}                     cross-kernel matrix              (Eqs. 21-22)
#   μ̃  / Σ̃                    joint state-action mean/cov      (Sec. 2.2)
#   μ_Δ / Σ_Δ                  predictive mean/cov of Δ        (Eqs. 14, 17-23)
#   μ_t / Σ_t                  next-state mean/cov             (Eqs. 10-11)

import importlib.util

import torch
import torch.nn as nn
from tensordict import TensorDictBase

# Optional-dependency probes. GPWorldModel needs both gpytorch and botorch,
# but they are imported lazily inside the methods that use them so that the
# rest of the module stays importable without them.
_has_gpytorch = importlib.util.find_spec("gpytorch") is not None
_has_botorch = importlib.util.find_spec("botorch") is not None


class GPWorldModel(nn.Module):
    """Gaussian Process world model with moment-matching uncertainty propagation.

    Implements the probabilistic dynamics model from PILCO (Deisenroth &
    Rasmussen, 2011). One independent GP is fit per state dimension, each
    predicting the transition residual ``Δ = x_t - x_{t-1}`` from the
    concatenated state-action input ``x̃ = [x, u]`` (Sec. 2.1).

    :meth:`forward` supports two modes depending on whether the input
    observation carries non-zero variance:

    - **Deterministic**: uses the GP posterior mean and variance directly
      (Eqs. 7-8).
    - **Uncertain** (moment-matching): propagates a Gaussian belief
      ``N(μ, Σ)`` through the GP analytically (Eqs. 10-23).

    .. note:: Requires ``botorch`` and ``gpytorch`` as optional dependencies.

    Args:
        obs_dim (int): Dimension D of the observation (state) space.
        action_dim (int): Dimension F of the action (control) space.
        in_keys (list of NestedKey, optional): Keys to read from the input
            :class:`~tensordict.TensorDictBase`. Must contain five entries in
            order: action mean, action covariance, state-action
            cross-covariance, observation mean, observation covariance.
            Defaults to ``[("action", "mean"), ("action", "var"),
            ("action", "cross_covariance"), ("observation", "mean"),
            ("observation", "var")]``.
        out_keys (list of NestedKey, optional): Keys to write the predicted
            next-state mean and covariance to. Defaults to
            ``[("next", "observation", "mean"),
            ("next", "observation", "var")]``.

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> model = GPWorldModel(obs_dim=4, action_dim=1)
        >>> dataset = TensorDict(
        ...     {
        ...         "observation": torch.randn(50, 4),
        ...         "action": torch.randn(50, 1),
        ...         ("next", "observation"): torch.randn(50, 4),
        ...     },
        ...     batch_size=[50],
        ... )
        >>> model.fit(dataset)

    Reference:
        Deisenroth, M. P. & Rasmussen, C. E. (2011). PILCO: A model-based and
        data-efficient approach to policy search. *ICML*.
    """

    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        in_keys: list[str | tuple[str, ...]] | None = None,
        out_keys: list[str | tuple[str, ...]] | None = None,
    ) -> None:
        # Fail fast: every useful code path below needs both libraries.
        if not _has_botorch or not _has_gpytorch:
            raise ImportError(
                "botorch and gpytorch are required to use GPWorldModel. "
                "Please install them to proceed."
            )
        super().__init__()
        self.obs_dim = obs_dim  # D in the paper
        self.action_dim = action_dim  # F in the paper
        self.state_action_dim = obs_dim + action_dim  # D+F, dimension of x̃ (Sec. 2.1)
        self.in_keys = (
            in_keys
            if in_keys is not None
            else [
                ("action", "mean"),
                ("action", "var"),
                ("action", "cross_covariance"),
                ("observation", "mean"),
                ("observation", "var"),
            ]
        )
        self.out_keys = (
            out_keys
            if out_keys is not None
            else [
                ("next", "observation", "mean"),
                ("next", "observation", "var"),
            ]
        )
        # Populated by fit(); a ModelListGP with one SingleTaskGP per state dim.
        self.model_list = None
        # X̃ = [x̃_1, ..., x̃_n] ∈ R^{n×(D+F)} – training inputs (Sec. 2.1)
        self.register_buffer("X_tilde_train", torch.empty(0))
        # ℓ_a ∈ R^{D+F} – ARD length-scales for each output dimension a (Eq. 6).
        # Stored as [D, D+F]; the full matrix Λ_a = diag(ℓ_a²) is never
        # materialised — ℓ_a is squared on the fly wherever needed.
        # Note: GPyTorch's .lengthscale returns ℓ directly (not ℓ²).
        self.register_buffer("ell", torch.zeros(self.obs_dim, self.state_action_dim))
        # α²_a – signal variance for each output dimension a (Eq. 6); shape [D, 1]
        self.register_buffer("alpha_sq", torch.zeros(self.obs_dim, 1))
        # σ²_{ε_a} – noise variance for each output dimension a (Sec. 2.1); shape [D]
        self.register_buffer("sigma_sq_eps", torch.zeros(self.obs_dim))
        # (K_a + σ²_{ε_a} I)^{-1} – cached inverse Gram matrices (Eq. 7); shape [D, n, n].
        # Registered as buffers so they survive .to(device) and state_dict round-trips.
        self.register_buffer("_cached_inv_K_noisy", None)
        # β_a = (K_a + σ²_{ε_a} I)^{-1} y_a – GP weight vectors (Eq. 7); shape [D, n].
        # Registered as a buffer so it survives .to(device) and state_dict round-trips.
        self.register_buffer("_cached_beta", None)

    @property
    def device(self) -> torch.device:
        # The module's device is wherever its buffers live; `ell` always exists.
        return self.ell.device
    def fit(self, dataset: TensorDictBase) -> None:
        """Fit one GP per state dimension to a dataset of transitions.

        Constructs training inputs ``X̃ = [x, u]`` and targets
        ``Δ_a = x_{t,a} - x_{t-1,a}``, then maximises the marginal
        log-likelihood to learn SE kernel hyper-parameters
        (ℓ_a, α²_a, σ²_{ε_a}) for each output dimension (Sec. 2.1, Eq. 6).

        .. note::
            The dataset is expected to be flat with shape ``[n, *]``. If your
            replay buffer returns multi-dimensional batches (e.g. ``[B, T, *]``),
            call ``dataset.reshape(-1)`` before passing it here.

        Args:
            dataset (TensorDictBase): Transition dataset with keys
                ``"observation"`` of shape ``(n, D)``, ``"action"`` of shape
                ``(n, F)``, and ``("next", "observation")`` of shape ``(n, D)``.
        """
        # Lazy imports keep the optional dependencies out of module import time.
        from botorch.fit import fit_gpytorch_mll
        from botorch.models import ModelListGP, SingleTaskGP
        from gpytorch.kernels import RBFKernel, ScaleKernel
        from gpytorch.mlls import SumMarginalLogLikelihood
        from gpytorch.priors import GammaPrior

        x_t_minus_1 = dataset["observation"]  # x_{t-1} ∈ R^{n×D}
        u_t_minus_1 = dataset["action"]  # u_{t-1} ∈ R^{n×F}
        x_t = dataset[("next", "observation")]  # x_t ∈ R^{n×D}

        # x̃ = [x_{t-1}, u_{t-1}] ∈ R^{n×(D+F)} – training inputs (Sec. 2.1).
        # Detached: hyper-parameter fitting must not backprop into the data.
        X_tilde_train = (
            torch.cat([x_t_minus_1, u_t_minus_1], dim=-1).detach().to(self.device)
        )
        # Δ ∈ R^{n×D}, Δ_{i,a} = x_{t,a} - x_{t-1,a} – training targets (Sec. 2.1)
        Delta_train = (x_t - x_t_minus_1).detach().to(self.device)
        self.X_tilde_train = X_tilde_train

        models = []
        for a in range(self.obs_dim):
            # Each GP_a models p(Δ_a | x̃) independently (Sec. 2.1)
            Delta_a = Delta_train[:, a].unsqueeze(-1)  # y_a ∈ R^{n×1}
            covar_module = ScaleKernel(
                # SE kernel k_a(x̃, x̃') with ARD length-scales (one ℓ_{a,i}
                # per input dimension, Eq. 6)
                RBFKernel(
                    ard_num_dims=self.state_action_dim,
                    lengthscale_prior=GammaPrior(1.1, 0.1),
                ),
                outputscale_prior=GammaPrior(1.5, 0.5),  # prior on α²_a (Eq. 6)
            )
            gp_a = SingleTaskGP(
                train_X=X_tilde_train,
                train_Y=Delta_a,
                covar_module=covar_module,
            )
            gp_a.likelihood.noise_covar.register_prior(
                "noise_prior",
                GammaPrior(1.2, 0.05),
                "noise",  # prior on σ²_{ε_a} (Sec. 2.1)
            )
            models.append(gp_a)

        self.model_list = ModelListGP(*models).to(self.device)
        mll = SumMarginalLogLikelihood(self.model_list.likelihood, self.model_list)
        fit_gpytorch_mll(mll)  # evidence maximisation (Sec. 2.1)
        # Freeze the fitted hyper-parameters into plain tensors/buffers so the
        # forward passes never need to touch gpytorch objects again.
        self._extract_and_cache_parameters(Delta_train)
    def _extract_and_cache_parameters(self, Delta_train: torch.Tensor) -> None:
        # Extract learned hyper-parameters from each GP_a and pre-compute the
        # quantities that are fixed after fitting:
        #   ℓ_a, α²_a, σ²_{ε_a}                 (Eq. 6 / Sec. 2.1)
        #   (K_a + σ²_{ε_a} I)^{-1}             (Eq. 7)
        #   β_a = (K_a + σ²_{ε_a} I)^{-1} y_a   (Eq. 7)
        #
        # Args:
        #   Delta_train: residual targets Δ of shape [n, D], one column per GP.
        ell_list, alpha_sq_list, sigma_sq_eps_list = [], [], []
        inv_K_noisy_list, beta_list = [], []
        n = self.X_tilde_train.shape[0]  # number of training points
        for a, gp_a in enumerate(self.model_list.models):
            # eval() mode: we only read fitted parameters, never train here.
            gp_a.eval()
            gp_a.likelihood.eval()
            # ℓ_a ∈ R^{D+F} – ARD length-scales for GP_a (Eq. 6).
            # GPyTorch's .lengthscale returns ℓ directly (not ℓ²).
            ell_a = gp_a.covar_module.base_kernel.lengthscale.squeeze().detach()
            # α²_a – signal variance for GP_a (Eq. 6)
            alpha_sq_a = gp_a.covar_module.outputscale.detach()
            # σ²_{ε_a} – noise variance for GP_a (Sec. 2.1)
            sigma_sq_eps_a = gp_a.likelihood.noise.squeeze().detach()
            ell_list.append(ell_a)
            alpha_sq_list.append(alpha_sq_a)
            sigma_sq_eps_list.append(sigma_sq_eps_a)

            # K_{a,ij} = α²_a exp(-½ (x̃_i-x̃_j)^T Λ_a^{-1} (x̃_i-x̃_j))  (Eq. 6)
            # Dividing X̃ by ℓ_a gives Λ_a^{-1/2}-scaled inputs for cdist.
            X_tilde_scaled = self.X_tilde_train / ell_a
            sq_dist = torch.cdist(X_tilde_scaled, X_tilde_scaled, p=2) ** 2
            K_a = alpha_sq_a * torch.exp(-0.5 * sq_dist)
            # K_{a,noisy} = K_a + σ²_{ε_a} I (denominator in Eq. 7);
            # the 1e-6 jitter guards the Cholesky against near-singular K.
            K_a_noisy = K_a + (sigma_sq_eps_a + 1e-6) * torch.eye(n, device=self.device)
            L_a = torch.linalg.cholesky(K_a_noisy)
            eye_n = torch.eye(n, dtype=L_a.dtype, device=L_a.device)
            # (K_a + σ²_{ε_a} I)^{-1} via Cholesky solve (Eq. 7)
            inv_K_a_noisy = torch.cholesky_solve(eye_n, L_a)
            # y_a = [Δ_{1,a}, ..., Δ_{n,a}]^T – targets for GP_a (Sec. 2.1)
            y_a = Delta_train[:, a].unsqueeze(-1)
            # β_a = (K_a + σ²_{ε_a} I)^{-1} y_a (Eq. 7)
            beta_a = torch.cholesky_solve(y_a, L_a).squeeze(-1)
            inv_K_noisy_list.append(inv_K_a_noisy)
            beta_list.append(beta_a)

        self.ell = torch.stack(ell_list)  # [D, D+F]
        self.alpha_sq = torch.stack(alpha_sq_list).unsqueeze(-1)  # [D, 1]
        self.sigma_sq_eps = torch.stack(sigma_sq_eps_list)  # [D]
        self._cached_inv_K_noisy = torch.stack(inv_K_noisy_list)  # [D, n, n]
        self._cached_beta = torch.stack(beta_list)  # [D, n]
[docs] def compute_factorizations(self) -> tuple[torch.Tensor, torch.Tensor]: """Return the cached kernel inverses and GP weight vectors. Returns: tuple[Tensor, Tensor]: A pair ``(inv_K_noisy, beta)`` where ``inv_K_noisy`` has shape ``(D, n, n)`` and contains ``(K_a + σ²_{ε_a} I)^{-1}`` for each output dimension (Eq. 7), and ``beta`` has shape ``(D, n)`` and contains ``β_a = (K_a + σ²_{ε_a} I)^{-1} y_a`` (Eq. 7). """ return self._cached_inv_K_noisy, self._cached_beta
def _gather_gp_hyperparams(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # Returns (ell, alpha_sq, sigma_sq_eps) — the SE kernel hyper-parameters # for each GP_a (Eq. 6 / Sec. 2.1): # ell: ℓ_{a,i}, shape [D, D+F] (ℓ, not ℓ²) # alpha_sq: α²_a, shape [D, 1] # sigma_sq_eps: σ²_{ε_a}, shape [D] return self.ell, self.alpha_sq, self.sigma_sq_eps
[docs] def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Predict the next-state distribution given the current state and action. Routes to :meth:`uncertain_forward` (moment-matching, Eqs. 10-23) when any input covariance is non-zero, and to :meth:`deterministic_forward` (Eqs. 7-8) otherwise. Args: tensordict (TensorDictBase): Input tensordict containing keys defined by ``in_keys``. Observation and action tensors may be unbatched ``(D,)`` / ``(F,)`` or batched ``(B, D)`` / ``(B, F)``; a leading batch dimension will be added and removed automatically for unbatched inputs. The observation covariance, when present, must be a full matrix of shape ``(..., D, D)`` — per-dimension variance vectors are not accepted; use :func:`torch.diag_embed` to convert them first. Returns: TensorDictBase: The same tensordict, updated in-place with the predicted next-state mean and covariance written to ``out_keys``. """ u_mean_key, u_var_key, u_cc_key, x_mean_key, x_var_key = self.in_keys Sigma_x = tensordict.get(x_var_key, None) if Sigma_x is not None and Sigma_x.dim() < 2: raise ValueError( f"Expected observation covariance to have at least 2 dimensions " f"(..., D, D), got shape {tuple(Sigma_x.shape)}. " "Convert per-dimension variances with torch.diag_embed() first." ) Sigma_u = tensordict.get(u_var_key, None) Sigma_xu = tensordict.get(u_cc_key, None) def _has_nonzero_covariance(covariance: torch.Tensor | None) -> bool: return covariance is not None and not torch.all( torch.isclose(covariance, torch.zeros_like(covariance)) ) observation_uncertain = _has_nonzero_covariance(Sigma_x) action_uncertain = _has_nonzero_covariance(Sigma_u) cross_uncertain = _has_nonzero_covariance(Sigma_xu) if observation_uncertain or action_uncertain or cross_uncertain: return self.uncertain_forward(tensordict) else: return self.deterministic_forward(tensordict)
    def uncertain_forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Moment-matching forward pass for a Gaussian input belief (Eqs. 10-23).

        Propagates the joint Gaussian belief
        ``p(x̃_{t-1}) = N(μ̃_{t-1}, Σ̃_{t-1})`` (Sec. 2.2) through the GP
        dynamics model and returns a Gaussian approximation to ``p(x_t)``
        via exact moment matching.

        Args:
            tensordict (TensorDictBase): Input tensordict with keys defined by
                ``in_keys``. Supports unbatched ``(D,)`` inputs or batched
                inputs with a single leading batch dimension ``(B, D)``.

        Returns:
            TensorDictBase: The same tensordict updated with next-state mean
            ``μ_t`` (Eq. 10) and covariance ``Σ_t`` (Eq. 11) at ``out_keys``.
        """
        inv_K_noisy, beta = self.compute_factorizations()
        ell, alpha_sq, sigma_sq_eps = self._gather_gp_hyperparams()

        u_mean_key, u_var_key, u_cc_key, x_mean_key, x_var_key = self.in_keys
        mu_x = tensordict.get(x_mean_key)  # μ_x, shape (B×)D
        Sigma_x = tensordict.get(x_var_key)  # Σ_x, shape (B×)D×D
        mu_u = tensordict.get(u_mean_key)  # μ_u, shape (B×)F
        Sigma_u = tensordict.get(u_var_key)  # Σ_u, shape (B×)F×F
        C_xu = tensordict.get(
            u_cc_key
        )  # ∂μ_u/∂μ_x (policy Jacobian), (B×)D×F (Eq. 12)

        # Support unbatched inputs by temporarily adding a leading batch dimension.
        unbatched = mu_x.dim() == 1
        if unbatched:
            mu_x, Sigma_x, mu_u, Sigma_u, C_xu = (
                mu_x.unsqueeze(0),
                Sigma_x.unsqueeze(0),
                mu_u.unsqueeze(0),
                Sigma_u.unsqueeze(0),
                C_xu.unsqueeze(0),
            )

        device, dtype = mu_x.device, mu_x.dtype
        B = mu_x.shape[0]  # batch size
        n = self.X_tilde_train.shape[0]  # number of training points
        D = self.obs_dim  # state dimension
        DF = self.state_action_dim  # D+F, dimension of x̃

        # ---- Build joint state-action distribution p(x̃_{t-1}) (Sec. 2.2) ----
        # μ̃_{t-1} = [μ_x; μ_u] ∈ R^{B×(D+F)}
        mu_tilde = torch.cat([mu_x, mu_u], dim=-1)
        # Σ̃_{t-1} = [[Σ_x,              Σ_x ∂π/∂μ_x],
        #             [(∂π/∂μ_x)^T Σ_x,  Σ_u        ]]  ∈ R^{B×(D+F)×(D+F)}
        # C_xu is the policy Jacobian ∂μ_u/∂μ_x, so cov(x,u) = Σ_x @ C_xu (Eq. 12)
        Sigma_x_C_xu = Sigma_x @ C_xu  # upper-right block [B, D, F]
        Sigma_tilde = torch.cat(
            [
                torch.cat([Sigma_x, Sigma_x_C_xu], dim=-1),
                torch.cat([Sigma_x_C_xu.transpose(-1, -2), Sigma_u], dim=-1),
            ],
            dim=-2,
        )  # [B, D+F, D+F]

        # ---- Compute q_a (mean-prediction kernel vector, Eq. 15) ----
        # ν_i = x̃_i - μ̃_{t-1} (Eq. 16); shape [B, n, D+F]
        nu = self.X_tilde_train - mu_tilde.unsqueeze(1)
        # diag(1/ℓ_a) = Λ_a^{-1/2}; shape [D, D+F, D+F].
        # ell stores ℓ_a (not ℓ²_a); applying this matrix on BOTH sides of a
        # quadratic form realises the full Λ_a^{-1} = diag(1/ℓ_a)².
        inv_Lambda_diag_mats = torch.diag_embed(1.0 / ell).to(
            device=device, dtype=dtype
        )
        # ν_i scaled by Λ_a^{-1/2}; shape [B, D, n, D+F]
        inv_Lambda_nu = nu.unsqueeze(1) @ inv_Lambda_diag_mats.unsqueeze(0)
        # R_a = Λ_a^{-1/2} Σ̃_{t-1} Λ_a^{-1/2} + I – symmetrised normalising
        # matrix of Eq. 15 (same determinant as Σ̃ Λ_a^{-1} + I); shape
        # [B, D, D+F, D+F]
        R_a = (
            inv_Lambda_diag_mats.unsqueeze(0)
            @ Sigma_tilde.unsqueeze(1)
            @ inv_Lambda_diag_mats.unsqueeze(0)
        )
        R_a = R_a + torch.eye(DF, device=device, dtype=dtype).view(1, 1, DF, DF)
        # Solve R_a t = (Λ_a^{-1/2} ν_i)^T  →  t = R_a^{-1} Λ_a^{-1/2} ν_i^T
        t = torch.linalg.solve(R_a, inv_Lambda_nu.transpose(-2, -1)).transpose(-2, -1)
        # exp(-½ ν_i^T (Σ̃ + Λ_a)^{-1} ν_i) – exponent in Eq. 15; shape [B, D, n]
        scaled_exp = torch.exp(-0.5 * torch.sum(inv_Lambda_nu * t, dim=-1))
        # Scalar prefactor α²_a / sqrt(|Σ̃_{t-1} Λ_a^{-1} + I|) from Eq. 15; shape [B, D]
        det_R_a = torch.linalg.det(R_a)
        c_a = alpha_sq.squeeze(-1).unsqueeze(0) / torch.sqrt(det_R_a)
        # β_a ⊙ q_a (pointwise); shape [B, D, n]
        beta_q_a = scaled_exp * beta.unsqueeze(0)
        # μ^a_Δ = β_a^T q_a (Eq. 14); shape [B, D]
        # NOTE(review): c_a.squeeze(0) only drops the batch axis when B == 1;
        # broadcasting yields the same result either way.
        mu_Delta = torch.sum(beta_q_a, dim=-1) * c_a.squeeze(0)

        # ---- Cross-covariance cov[x̃_{t-1}, Δ_t] (used in Eq. 12) ----
        # Derivative of μ_Δ w.r.t. μ̃, contracted with Σ̃ (Deisenroth 2010);
        # shape [B, D+F, D]
        t_inv_Lambda = t @ inv_Lambda_diag_mats.unsqueeze(0)
        cov_xtilde_Delta = (
            torch.matmul(
                t_inv_Lambda.transpose(-2, -1), beta_q_a.unsqueeze(-1)
            ).squeeze(-1)
            * c_a.unsqueeze(-1)
        ).transpose(-2, -1)

        # ---- Compute Q_{ab} (cross-kernel matrix, Eqs. 21-22) ----
        X_i = self.X_tilde_train.unsqueeze(1)  # [n, 1, D+F]
        X_j = self.X_tilde_train.unsqueeze(0)  # [1, n, D+F]
        diff_ij = X_i - X_j  # x̃_i - x̃_j; [n, n, D+F] (Eq. 22)
        # ell stores ℓ_a; ℓ²_a is the diagonal of Λ_a (Eq. 6)
        ell_sq_a = (ell**2)[:, None, :]  # [D, 1, D+F]
        ell_sq_b = (ell**2)[None, :, :]  # [1, D, D+F]
        # Λ_{ab} = (Λ_a^{-1} + Λ_b^{-1})^{-1}, diagonal entries; [D, D, D+F]
        inv_ell_sq_sum = 1.0 / ell_sq_a + 1.0 / ell_sq_b
        Lambda_ab = 1.0 / inv_ell_sq_sum

        # First exponential in Q_{ab,ij}: kernel product at training inputs (Eq. 22)
        # -½ (x̃_i - x̃_j)^T (Λ_a + Λ_b)^{-1} (x̃_i - x̃_j); shape [D, D, n, n]
        inv_ell_sq_sum_ab = 1.0 / (ell_sq_a + ell_sq_b)
        exp1 = -0.5 * torch.sum(
            diff_ij.unsqueeze(0).unsqueeze(0)
            * inv_ell_sq_sum_ab.unsqueeze(2).unsqueeze(2)
            * diff_ij.unsqueeze(0).unsqueeze(0),
            dim=-1,
        )  # [D, D, n, n]

        # z̄_{ij} = Λ_{ab} (Λ_a^{-1} x̃_i + Λ_b^{-1} x̃_j) – midpoint (Eq. 22);
        # shape [D, D, n, n, D+F]
        z_bar = Lambda_ab.unsqueeze(2).unsqueeze(2) * (
            X_i.unsqueeze(0).unsqueeze(0) / ell_sq_a.unsqueeze(2).unsqueeze(2)
            + X_j.unsqueeze(0).unsqueeze(0) / ell_sq_b.unsqueeze(2).unsqueeze(2)
        )
        # z_{ij} = z̄_{ij} - μ̃_{t-1}; shape [B, D, D, n, n, D+F]
        z_bar = z_bar.unsqueeze(0).expand(B, -1, -1, -1, -1, -1)
        z_ij = z_bar - mu_tilde[:, None, None, None, None, :]
        z_ij_flat = z_ij.view(B, D, D, n * n, DF)

        # M_{ab} = Σ̃_{t-1} + diag(Λ_{ab}) – matrix in second exp of Eq. 22;
        # shape [B, D, D, D+F, D+F]
        M_ab = Sigma_tilde[:, None, None] + torch.diag_embed(Lambda_ab)
        # Second exponential: -½ z_{ij}^T M_{ab}^{-1} z_{ij}; shape [B, D, D, n, n]
        M_ab_solved = torch.linalg.solve(M_ab, z_ij_flat.transpose(-2, -1)).transpose(
            -2, -1
        )
        exp2 = (-0.5 * torch.sum(z_ij_flat * M_ab_solved, dim=-1)).view(B, D, D, n, n)

        # R_{ab} = Σ̃_{t-1} (Λ_a^{-1} + Λ_b^{-1}) + I – normalising matrix (Eq. 22);
        # shape [B, D, D, D+F, D+F]
        R_ab = Sigma_tilde[:, None, None] @ torch.diag_embed(
            inv_ell_sq_sum
        ) + torch.eye(DF, device=device, dtype=dtype)
        det_R_ab = torch.linalg.det(R_ab)  # [B, D, D]
        # Scalar prefactor α²_a α²_b / sqrt(|R_{ab}|) (Eq. 22); shape [B, D, D]
        c_ab = (alpha_sq.view(1, D, 1) * alpha_sq.view(1, 1, D)) / torch.sqrt(det_R_ab)
        # Q_{ab,ij} (Eq. 22); shape [B, D, D, n, n]
        Q_ab = c_ab.unsqueeze(-1).unsqueeze(-1) * torch.exp(exp1.unsqueeze(0) + exp2)

        # ---- Σ_Δ = predictive covariance of Δ (Eqs. 17-23) ----
        # Off-diagonal entries: σ²_{ab} = β_a^T Q_{ab} β_b - μ^a_Δ μ^b_Δ (Eqs. 18, 20)
        beta_a = beta.view(1, D, 1, n)  # [1, D, 1, n]
        beta_b = beta.view(1, 1, D, n)  # [1, 1, D, n]
        Q_ab_beta_b = torch.matmul(Q_ab, beta_b.unsqueeze(-1)).squeeze(
            -1
        )  # [B, D, D, n]
        Sigma_Delta = (
            torch.matmul(beta_a.unsqueeze(-2), Q_ab_beta_b.unsqueeze(-1))
            .squeeze(-1)
            .squeeze(-1)
        )  # [B, D, D] – β_a^T Q_{ab} β_b (Eq. 20)

        # Diagonal correction E_{x̃}[var_f[Δ_a | x̃]] = α²_a - tr(K_a^{-1} Q_{aa})
        # added to σ²_{aa} (Eqs. 17, 23)
        invK_Q = torch.matmul(
            inv_K_noisy.unsqueeze(0).unsqueeze(2),  # [1, D, 1, n, n]
            Q_ab,  # [B, D, D, n, n]
        )  # [B, D, D, n, n]
        trace_invK_Q = torch.diagonal(invK_Q, dim1=-2, dim2=-1).sum(-1)  # [B, D, D]

        diag_idx = torch.arange(D, device=device)
        alpha_sq_b = alpha_sq.squeeze(-1).unsqueeze(0).expand(B, -1)  # [B, D]
        sigma_sq_eps_b = sigma_sq_eps.unsqueeze(0).expand(B, -1)  # [B, D]
        # Add α²_a - tr(K_a^{-1} Q_{aa}) + σ²_{ε_a} to the diagonal (Eqs. 17, 23)
        Sigma_Delta[:, diag_idx, diag_idx] += (
            alpha_sq_b - trace_invK_Q[:, diag_idx, diag_idx] + sigma_sq_eps_b
        )
        # Subtract outer product of means: Σ_Δ -= μ_Δ μ_Δ^T (Eqs. 17-18)
        Sigma_Delta = Sigma_Delta - torch.bmm(
            mu_Delta.unsqueeze(-1), mu_Delta.unsqueeze(-2)
        )
        Sigma_Delta = (
            Sigma_Delta + Sigma_Delta.transpose(-2, -1)
        ) / 2  # enforce symmetry

        # ---- Propagate to next-state belief (Eqs. 10-12) ----
        # cov[x_{t-1}, Δ_t] = cov[x_{t-1}, x̃_{t-1}] · cov_xtilde_Delta (Eq. 12)
        # cov[x_{t-1}, x̃_{t-1}] is the top-D rows of Σ̃_{t-1}: shape [B, D, D+F].
        # Using only Sigma_x_C_xu ([B, D, F]) here would be wrong — it drops
        # the Σ_x block and produces a [B, D, F] @ [B, D+F, D] shape mismatch.
        Sigma_x_rows = Sigma_tilde[:, :D, :]  # [B, D, D+F]
        cov_x_Delta = Sigma_x_rows @ cov_xtilde_Delta  # [B, D, D]
        # μ_t = μ_{t-1} + μ_Δ (Eq. 10)
        mu_t = mu_x + mu_Delta
        # Σ_t = Σ_{t-1} + Σ_Δ + cov[x_{t-1},Δ_t] + cov[Δ_t,x_{t-1}] (Eq. 11)
        Sigma_t = Sigma_x + Sigma_Delta + cov_x_Delta + cov_x_Delta.transpose(-2, -1)
        Sigma_t = (Sigma_t + Sigma_t.transpose(-2, -1)) / 2  # enforce symmetry
        Sigma_t = Sigma_t + 1e-8 * torch.eye(D, device=device).expand(
            B, -1, -1
        )  # jitter

        if unbatched:
            mu_t = mu_t.squeeze(0)
            Sigma_t = Sigma_t.squeeze(0)

        out_mean_key, out_var_key = self.out_keys
        tensordict.set(out_mean_key, mu_t)
        tensordict.set(out_var_key, Sigma_t)
        return tensordict
def deterministic_forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Deterministic forward pass using GP posterior mean and variance (Eqs. 7-8). Used when the input observation is a point estimate with no uncertainty. Returns the GP posterior mean ``m_f(x̃_*)`` (Eq. 7) and per-dimension variance ``σ²_f(x̃_*)`` (Eq. 8) for each state dimension. Args: tensordict (TensorDictBase): Input tensordict with keys defined by ``in_keys``. Supports arbitrary leading batch dimensions ``(*batch, D)`` / ``(*batch, F)``, as well as unbatched ``(D,)`` / ``(F,)`` inputs. Returns: TensorDictBase: The same tensordict updated with next-state mean ``μ_t`` and diagonal covariance ``Σ_t = diag(σ²_Δ)`` at ``out_keys``. """ u_mean_key, u_var_key, u_cc_key, x_mean_key, x_var_key = self.in_keys mu_x = tensordict.get(x_mean_key) # x_{t-1}, shape (*batch, D) or (D,) mu_u = tensordict.get(u_mean_key) # u_{t-1}, shape (*batch, F) or (F,) batch_shape = mu_x.shape[:-1] # leading dims; () for unbatched inputs # Flatten all leading batch dimensions to a single axis for the GP # posterior call, then restore the original shape afterwards. x_flat = mu_x.reshape(-1, self.obs_dim) # [B_flat, D] u_flat = mu_u.reshape(-1, self.action_dim) # [B_flat, F] # x̃_* = [x_{t-1}, u_{t-1}] ∈ R^{B_flat×(D+F)} (Sec. 2.1) X_tilde_test = torch.cat([x_flat, u_flat], dim=-1) # GP posterior mean m_f(x̃_*) (Eq. 7) and std σ_f(x̃_*) (Eq. 8) mu_Delta_list, sigma_Delta_list = [], [] with torch.no_grad(): for gp_a in self.model_list.models: posterior_a = gp_a.posterior(X_tilde_test) mu_Delta_list.append(posterior_a.mean.squeeze(-1)) # m_f (Eq. 7) sigma_Delta_list.append( torch.sqrt(posterior_a.variance).squeeze(-1) # σ_f (Eq. 
8) ) # μ_Δ – predicted residual mean; restore original batch shape mu_Delta = torch.stack(mu_Delta_list, dim=-1).view(*batch_shape, self.obs_dim) # σ_Δ – predicted residual std; restore original batch shape sigma_Delta = torch.stack(sigma_Delta_list, dim=-1).view( *batch_shape, self.obs_dim ) # μ_t = x_{t-1} + μ_Δ (deterministic version of Eq. 10) mu_t = mu_x + mu_Delta # Σ_t = diag(σ²_Δ) – diagonal covariance from independent GP variances (Eq. 8) Sigma_t = torch.diag_embed(sigma_Delta**2) out_mean_key, out_var_key = self.out_keys tensordict.set(out_mean_key, mu_t) tensordict.set(out_var_key, Sigma_t) return tensordict

Docs

Lorem ipsum dolor sit amet, consectetur

View Docs

Tutorials

Lorem ipsum dolor sit amet, consectetur

View Tutorials

Resources

Lorem ipsum dolor sit amet, consectetur

View Resources