RSSMRollout¶
- class torchrl.modules.RSSMRollout(*args, **kwargs)[source]¶
Rollout the RSSM network.
Given a set of encoded observations and actions, this module will rollout the RSSM network to compute all the intermediate states and beliefs. The previous posterior is used as the prior for the next time step. The forward method returns a stack of all intermediate states and beliefs.
Reference: https://arxiv.org/abs/1811.04551
- Parameters:
rssm_prior (TensorDictModule) – Prior network.
rssm_posterior (TensorDictModule) – Posterior network.
- forward(tensordict)[source]¶
Runs a rollout of simulated transitions in the latent space given a sequence of actions and environment observations.
The rollout requires a belief and posterior state primer.
At each step, two probability distributions are built and sampled from:
- A prior distribution p(s_{t+1} | s_t, a_t, b_t) where b_t is a
deterministic transform of the form b_t(s_{t-1}, a_{t-1}). The previous state s_t is sampled according to the posterior distribution (see below), creating a chain of posterior-to-priors that accumulates evidence to compute a prior distribution over the current event distribution: p(s_{t+1} s_t | o_t, a_t, s_{t-1}, a_{t-1}) = p(s_{t+1} | s_t, a_t, b_t) q(s_t | b_t, o_t)
- A posterior distribution of the form q(s_{t+1} | b_{t+1}, o_{t+1})
which amends to q(s_{t+1} | s_t, a_t, o_{t+1})