MujocoPlaygroundWrapper
- torchrl.envs.MujocoPlaygroundWrapper(*args, **kwargs)
Google DeepMind MuJoCo Playground environment wrapper.
MuJoCo Playground is a collection of JAX-based MJX environments spanning locomotion, manipulation, and dm_control suite tasks.
GitHub: google-deepmind/mujoco_playground
- Parameters:
env (mujoco_playground._src.mjx_env.MjxEnv) – the environment to wrap.
agent_mapping (MujocoPlaygroundAgentMapping or str, optional) – if provided, the environment is decomposed into a cooperative multi-agent task. Can be either a MujocoPlaygroundAgentMapping instance or a string key into KNOWN_MARL_MAPPINGS. Known string values: "ant_4x2", "halfcheetah_6x1", "hopper_3x1", "humanoid_9|8", "walker2d_2x3". Defaults to None (single-agent mode).
- Keyword Arguments:
from_pixels (bool, optional) – Not yet supported.
frame_skip (int, optional) – if provided, indicates for how many steps the same action is to be repeated. The observation returned will be the last observation of the sequence, whereas the reward will be the sum of rewards across steps.
device (torch.device, optional) – if provided, the device on which the data is to be cast. Defaults to torch.device("cpu").
batch_size (torch.Size, optional) – the batch size of the environment. In mujoco_playground, this controls the number of environments simulated in parallel via JAX’s vmap on a single device (GPU/TPU). Defaults to torch.Size([]).
allow_done_after_reset (bool, optional) – if True, environments are allowed to be done immediately after reset() is called. Defaults to False.
- Variables:
available_envs – the names of the environments that can be built
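The string keys accepted by agent_mapping appear to follow the MaMuJoCo naming convention, where "4x2" reads as four agents controlling two actuators each and "9|8" as two agents controlling nine and eight actuators. Assuming that convention (this page does not confirm it), the minimal sketch below decodes a key into per-agent action sizes, splits a flat joint action into per-agent slices, and merges them back. parse_partition, split_action and merge_actions are hypothetical helpers, not part of torchrl; the sketch also assumes each agent owns a contiguous slice of the action vector, which is a simplification, and it does not model how the real MujocoPlaygroundAgentMapping partitions observations.

```python
from typing import List, Sequence


def parse_partition(key: str) -> List[int]:
    """Hypothetical: decode a key such as "ant_4x2" into per-agent action sizes.

    Assumes a MaMuJoCo-style suffix: "<n>x<k>" means n agents with k actuators
    each, and "a|b" means one agent with a actuators and one with b.
    """
    suffix = key.rsplit("_", 1)[1]
    if "x" in suffix:
        n_agents, per_agent = (int(v) for v in suffix.split("x"))
        return [per_agent] * n_agents
    return [int(v) for v in suffix.split("|")]


def split_action(action: Sequence[float], sizes: Sequence[int]) -> List[List[float]]:
    """Cut a flat joint action into consecutive per-agent slices (assumed contiguous)."""
    if sum(sizes) != len(action):
        raise ValueError(f"sizes sum to {sum(sizes)}, action has {len(action)} entries")
    slices, start = [], 0
    for size in sizes:
        slices.append(list(action[start:start + size]))
        start += size
    return slices


def merge_actions(per_agent: Sequence[Sequence[float]]) -> List[float]:
    """Concatenate per-agent actions back into the joint action the env consumes."""
    return [a for chunk in per_agent for a in chunk]


if __name__ == "__main__":
    sizes = parse_partition("ant_4x2")
    print(sizes)  # [2, 2, 2, 2]
    print(split_action(list(range(8)), sizes))  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

Keeping the merge step explicit reflects the cooperative setting: each agent chooses only its own slice, but the wrapped single-agent environment still receives one joint action per step.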
Note
Unlike BraxWrapper, this wrapper does not copy the underlying JAX env state into the output TensorDict. The state is kept on the env instance (self._current_state) and rolled forward by _step; this avoids round-tripping MJX/pytree state through TensorDict, which would break MJX’s metadata pytree registration. As a consequence, the output TensorDict only contains observation (or per-key observations for dict-observation envs), reward, done and terminated; there is no state key.
Warning
Because the JAX state is held on the instance rather than carried in the TensorDict, partial resets are not supported: any call to reset() re-initialises the entire vmapped batch, ignoring the "_reset" mask. When batch_size is greater than one and sub-environments terminate at different steps (e.g. early-terminating locomotion tasks driven by a data collector), prefer scaling with num_workers (one scalar env per worker) over a single large vmapped batch_size. This matches the behaviour of BraxWrapper.
Note
terminated is set equal to done; this wrapper does not expose a separate time-limit truncated signal. For finite-horizon tasks where bootstrapping at the episode boundary matters, append a StepCounter (with max_steps) or otherwise track truncations yourself.
Examples
>>> from mujoco_playground import dm_control_suite
>>> from torchrl.envs import MujocoPlaygroundWrapper
>>> import torch
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> base_env = dm_control_suite.load("CartpoleBalance")
>>> env = MujocoPlaygroundWrapper(base_env, device=device)
>>> env.set_seed(0)
>>> td = env.reset()
>>> td["action"] = env.action_spec.rand()
>>> td = env.step(td)
>>> print(td)
TensorDict(
    fields={
        action: Tensor(torch.Size([1]), dtype=torch.float32),
        done: Tensor(torch.Size([1]), dtype=torch.bool),
        next: TensorDict(
            fields={
                done: Tensor(torch.Size([1]), dtype=torch.bool),
                observation: Tensor(torch.Size([5]), dtype=torch.float32),
                reward: Tensor(torch.Size([1]), dtype=torch.float32),
                terminated: Tensor(torch.Size([1]), dtype=torch.bool)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        observation: Tensor(torch.Size([5]), dtype=torch.float32),
        terminated: Tensor(torch.Size([1]), dtype=torch.bool)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
>>> print(env.available_envs)
['AcrobotSwingup', 'AcrobotSwingupSparse', 'BallInCupCatch', ...]
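The state-handling notes and the frame_skip keyword described on this page can be made concrete with a toy stand-in. The minimal pure-Python sketch below mimics the documented behaviour under stated assumptions: the simulator state lives on the instance (in a _current_state list, named after the attribute mentioned in the first note), the returned dict carries only observation, reward, done and terminated, terminated mirrors done, frame_skip repeats the action, sums the rewards and returns only the last observation, and reset() re-initialises every sub-environment regardless of the mask it receives. ToyBatchedEnv, its counter "physics", its reward rule and its reset_mask argument are all hypothetical; this is not torchrl's implementation and it does not touch JAX or MJX.

```python
from typing import Dict, List, Optional


class ToyBatchedEnv:
    """Hypothetical stand-in mimicking the documented wrapper behaviour.

    Each sub-environment is a counter that advances by its action and
    terminates once it reaches a threshold. Not torchrl's implementation.
    """

    def __init__(self, batch_size: int, frame_skip: int = 1, threshold: float = 10.0):
        self.batch_size = batch_size
        self.frame_skip = frame_skip
        self.threshold = threshold
        # Simulator state is held on the instance and never returned to the caller.
        self._current_state: List[float] = [0.0] * batch_size

    def reset(self, reset_mask: Optional[List[bool]] = None) -> Dict[str, list]:
        # reset_mask is accepted but ignored: the whole batch is re-initialised,
        # mirroring the "partial resets are not supported" warning.
        self._current_state = [0.0] * self.batch_size
        return {
            "observation": list(self._current_state),
            "done": [False] * self.batch_size,
            "terminated": [False] * self.batch_size,
        }

    def step(self, actions: List[float]) -> Dict[str, list]:
        rewards = [0.0] * self.batch_size
        for _ in range(self.frame_skip):
            # Repeat the same action on every frame.
            self._current_state = [s + a for s, a in zip(self._current_state, actions)]
            # Toy reward: 1.0 per frame while a sub-env is still below the threshold.
            for i, s in enumerate(self._current_state):
                rewards[i] += 1.0 if s < self.threshold else 0.0
        done = [s >= self.threshold for s in self._current_state]
        # Only the last observation is returned; terminated is a copy of done
        # and there is no "state" entry.
        return {
            "observation": list(self._current_state),
            "reward": rewards,
            "done": done,
            "terminated": list(done),
        }


if __name__ == "__main__":
    env = ToyBatchedEnv(batch_size=2, frame_skip=2)
    env.reset()
    print(env.step([3.0, 1.0]))  # both sub-envs still running
    out = env.step([3.0, 1.0])   # sub-env 0 crosses the threshold
    print(out["done"])           # [True, False]
    # Asking to reset only sub-env 0 still wipes sub-env 1's progress.
    print(env.reset(reset_mask=[True, False])["observation"])  # [0.0, 0.0]
```

The last line shows why the warning recommends num_workers over a large vmapped batch_size: once one sub-environment finishes, restarting it also discards the progress of every other sub-environment in the batch.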