MeanActionSelector

class torchrl.envs.transforms.MeanActionSelector(observation_key: str = 'observation', action_key: str = 'action')

Bridges Gaussian belief-space policies with standard environments.

Gaussian policies used in moment-matching model-based RL (e.g. PILCO) operate on state beliefs, i.e. (mean, covariance) pairs, and produce action distributions with entries such as ("action", "mean") and ("action", "var"). This transform adapts a standard environment so that such a policy can be used directly with rollout():

  • Forward (env output -> policy input): wraps the flat "observation" tensor into ("observation", "mean") with a zero-covariance ("observation", "var"), representing a deterministic state belief.

  • Inverse (policy output -> env input): extracts ("action", "mean") from the policy output and writes it as the flat "action" for the base environment step.
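The two mappings above can be sketched without any library, using plain dicts in place of TensorDicts. This is an illustration only, not the transform's implementation: the helper names forward and inverse are hypothetical, and storing the covariance as a full D x D matrix of zeros is an assumption about the layout.

```python
# Library-free sketch of the two mappings; plain dicts stand in for TensorDicts.
# Assumption: the covariance is a full D x D matrix. The real transform may
# use a different layout for ("observation", "var").


def forward(env_out, observation_key="observation"):
    """Env output -> policy input: wrap the flat observation as a belief."""
    obs = env_out[observation_key]
    dim = len(obs)
    # Copy every entry except the flat observation, which is replaced below.
    out = {k: v for k, v in env_out.items() if k != observation_key}
    out[(observation_key, "mean")] = list(obs)
    # Zero covariance encodes a deterministic (fully certain) state belief.
    out[(observation_key, "var")] = [[0.0] * dim for _ in range(dim)]
    return out


def inverse(policy_out, action_key="action"):
    """Policy output -> env input: keep only the mean of the action distribution."""
    out = dict(policy_out)
    out[action_key] = list(policy_out[(action_key, "mean")])
    return out


belief = forward({"observation": [0.5, -0.2, 1.0]})
env_in = inverse({("action", "mean"): [0.3], ("action", "var"): [[0.01]]})
```

In this sketch the action variance is simply ignored on the inverse pass, so the base environment always receives the mean action.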

Parameters:
  • observation_key (str, optional) – The observation key to read from the base environment. Defaults to "observation".

  • action_key (str, optional) – The action key expected by the base environment. Defaults to "action".

Examples

>>> import torch
>>> from torchrl.envs import GymEnv, TransformedEnv
>>> from torchrl.envs.transforms import MeanActionSelector
>>> base_env = GymEnv("Pendulum-v1")
>>> env = TransformedEnv(base_env, MeanActionSelector())
>>> td = env.reset()
>>> # The policy now sees ("observation", "mean") and ("observation", "var")
>>> print(td["observation", "mean"].shape)
>>> print(td["observation", "var"].shape)

transform_observation_spec(observation_spec)

Transforms the observation spec so that the resulting spec matches the transform mapping.

Parameters:
  • observation_spec (TensorSpec) – The spec before the transform.

Returns:
  The expected spec after the transform.