Shortcuts

Source code for torchrl.trainers.algorithms.configs.envs_libs

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from dataclasses import dataclass
from typing import Any

from omegaconf import MISSING
from torchrl.envs.libs.gym import set_gym_backend
from torchrl.envs.transforms.transforms import DoubleToFloat
from torchrl.trainers.algorithms.configs.common import ConfigBase


@dataclass
class EnvLibsConfig(ConfigBase):
    """Common parent of the environment-library configuration dataclasses.

    Declares the ``_partial_`` flag (``False`` by default) and a
    ``__post_init__`` hook that contains no statements.
    """

    # NOTE(review): presumably the Hydra-style "partial instantiation"
    # switch -- confirm against the code that consumes these configs.
    _partial_: bool = False

    def __post_init__(self) -> None:
        """Run after dataclass initialization; performs no work here."""
@dataclass
class GymEnvConfig(EnvLibsConfig):
    """Settings describing a ``GymEnv`` built through ``make_gym_env``.

    ``env_name`` has no usable default (it is set to ``MISSING``); every
    other field carries the default shown below. ``_target_`` is the dotted
    path of ``make_gym_env`` in this module.
    """

    # Required: no default value is supplied.
    env_name: str = MISSING

    # Environment options and their defaults.
    categorical_action_encoding: bool = False
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int = 1
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False
    convert_actions_to_numpy: bool = True
    missing_obs_value: Any = None
    disable_env_checker: bool | None = None
    render_mode: str | None = None
    num_envs: int = 0
    backend: str = "gymnasium"

    # NOTE(review): make_gym_env also accepts ``double_to_float``, which has
    # no matching field here -- confirm whether that omission is intended.
    _target_: str = "torchrl.trainers.algorithms.configs.envs_libs.make_gym_env"

    def __post_init__(self) -> None:
        """Chain to the parent class's post-initialization hook."""
        super().__post_init__()
def make_gym_env(
    env_name: str,
    backend: str = "gymnasium",
    from_pixels: bool = False,
    double_to_float: bool = False,
    **kwargs,
):
    """Build a ``GymEnv`` and optionally attach a ``DoubleToFloat`` transform.

    Args:
        env_name: Environment name, passed positionally to ``GymEnv``.
        backend: Value handed to ``set_gym_backend`` while the environment
            is constructed. When ``None``, the environment is built without
            entering that context manager.
        from_pixels: Forwarded to ``GymEnv`` as ``from_pixels``.
        double_to_float: When truthy,
            ``DoubleToFloat(in_keys=["observation"])`` is appended to the
            environment.
        **kwargs: Extra keyword arguments forwarded unchanged to ``GymEnv``.

    Returns:
        The constructed ``GymEnv``, or the value returned by its
        ``append_transform`` call when ``double_to_float`` is set.
    """
    from torchrl.envs.libs.gym import GymEnv

    def _construct():
        # Single place where the environment is actually instantiated.
        return GymEnv(env_name, from_pixels=from_pixels, **kwargs)

    if backend is None:
        env = _construct()
    else:
        with set_gym_backend(backend):
            env = _construct()

    if not double_to_float:
        return env
    return env.append_transform(DoubleToFloat(in_keys=["observation"]))
@dataclass
class MOGymEnvConfig(EnvLibsConfig):
    """Settings describing a ``MOGymEnv`` environment.

    ``env_name`` has no usable default (it is set to ``MISSING``).
    ``_target_`` holds the dotted path ``torchrl.envs.libs.gym.MOGymEnv``.
    """

    # Required: no default value is supplied.
    env_name: str = MISSING

    # Environment options and their defaults.
    categorical_action_encoding: bool = False
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False
    convert_actions_to_numpy: bool = True
    missing_obs_value: Any = None
    backend: str | None = None
    disable_env_checker: bool | None = None
    render_mode: str | None = None
    num_envs: int = 0

    _target_: str = "torchrl.envs.libs.gym.MOGymEnv"

    def __post_init__(self) -> None:
        """Chain to the parent class's post-initialization hook."""
        super().__post_init__()
@dataclass
class BraxEnvConfig(EnvLibsConfig):
    """Settings describing a ``BraxEnv`` environment.

    ``env_name`` has no usable default (it is set to ``MISSING``).
    ``_target_`` holds the dotted path ``torchrl.envs.libs.brax.BraxEnv``.
    """

    # Required: no default value is supplied.
    env_name: str = MISSING

    # Environment options and their defaults.
    categorical_action_encoding: bool = False
    cache_clear_frequency: int | None = None
    from_pixels: bool = False
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False
    requires_grad: bool = False

    _target_: str = "torchrl.envs.libs.brax.BraxEnv"

    def __post_init__(self) -> None:
        """Chain to the parent class's post-initialization hook."""
        super().__post_init__()
@dataclass
class DMControlEnvConfig(EnvLibsConfig):
    """Settings describing a ``DMControlEnv`` environment.

    Both ``env_name`` and ``task_name`` lack usable defaults (each is set to
    ``MISSING``). ``_target_`` holds the dotted path
    ``torchrl.envs.libs.dm_control.DMControlEnv``.
    """

    # Required: no default values are supplied.
    env_name: str = MISSING
    task_name: str = MISSING

    # Environment options and their defaults.
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.dm_control.DMControlEnv"

    def __post_init__(self) -> None:
        """Chain to the parent class's post-initialization hook."""
        super().__post_init__()
@dataclass
class HabitatEnvConfig(EnvLibsConfig):
    """Settings describing a ``HabitatEnv`` environment.

    ``env_name`` has no usable default (it is set to ``MISSING``).
    ``_target_`` holds the dotted path ``torchrl.envs.libs.habitat.HabitatEnv``.
    """

    # Required: no default value is supplied.
    env_name: str = MISSING

    # Environment options and their defaults.
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.habitat.HabitatEnv"

    def __post_init__(self) -> None:
        """Chain to the parent class's post-initialization hook."""
        super().__post_init__()
@dataclass
class IsaacGymEnvConfig(EnvLibsConfig):
    """Settings describing an ``IsaacGymEnv`` environment.

    ``env_name`` has no usable default (it is set to ``MISSING``).
    ``_target_`` holds the dotted path
    ``torchrl.envs.libs.isaacgym.IsaacGymEnv``.
    """

    # Required: no default value is supplied.
    env_name: str = MISSING

    # Environment options and their defaults.
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.isaacgym.IsaacGymEnv"

    def __post_init__(self) -> None:
        """Chain to the parent class's post-initialization hook."""
        super().__post_init__()
@dataclass
class JumanjiEnvConfig(EnvLibsConfig):
    """Settings describing a ``JumanjiEnv`` environment.

    ``env_name`` has no usable default (it is set to ``MISSING``).
    ``_target_`` holds the dotted path ``torchrl.envs.libs.jumanji.JumanjiEnv``.
    """

    # Required: no default value is supplied.
    env_name: str = MISSING

    # Environment options and their defaults.
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.jumanji.JumanjiEnv"

    def __post_init__(self) -> None:
        """Chain to the parent class's post-initialization hook."""
        super().__post_init__()
@dataclass
class MeltingpotEnvConfig(EnvLibsConfig):
    """Settings describing a ``MeltingpotEnv`` environment.

    ``env_name`` has no usable default (it is set to ``MISSING``).
    ``_target_`` holds the dotted path
    ``torchrl.envs.libs.meltingpot.MeltingpotEnv``.
    """

    # Required: no default value is supplied.
    env_name: str = MISSING

    # Environment options and their defaults.
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.meltingpot.MeltingpotEnv"

    def __post_init__(self) -> None:
        """Chain to the parent class's post-initialization hook."""
        super().__post_init__()
@dataclass
class OpenMLEnvConfig(EnvLibsConfig):
    """Settings describing an ``OpenMLEnv`` environment.

    ``env_name`` has no usable default (it is set to ``MISSING``).
    ``_target_`` holds the dotted path ``torchrl.envs.libs.openml.OpenMLEnv``.
    """

    # Required: no default value is supplied.
    env_name: str = MISSING

    # Environment options and their defaults.
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.openml.OpenMLEnv"

    def __post_init__(self) -> None:
        """Chain to the parent class's post-initialization hook."""
        super().__post_init__()
@dataclass
class OpenSpielEnvConfig(EnvLibsConfig):
    """Settings describing an ``OpenSpielEnv`` environment.

    ``env_name`` has no usable default (it is set to ``MISSING``).
    ``_target_`` holds the dotted path
    ``torchrl.envs.libs.openspiel.OpenSpielEnv``.
    """

    # Required: no default value is supplied.
    env_name: str = MISSING

    # Environment options and their defaults.
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.openspiel.OpenSpielEnv"

    def __post_init__(self) -> None:
        """Chain to the parent class's post-initialization hook."""
        super().__post_init__()
@dataclass
class PettingZooEnvConfig(EnvLibsConfig):
    """Settings describing a ``PettingZooEnv`` environment.

    ``env_name`` has no usable default (it is set to ``MISSING``).
    ``_target_`` holds the dotted path
    ``torchrl.envs.libs.pettingzoo.PettingZooEnv``.
    """

    # Required: no default value is supplied.
    env_name: str = MISSING

    # Environment options and their defaults.
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.pettingzoo.PettingZooEnv"

    def __post_init__(self) -> None:
        """Chain to the parent class's post-initialization hook."""
        super().__post_init__()
@dataclass
class RoboHiveEnvConfig(EnvLibsConfig):
    """Settings describing a ``RoboHiveEnv`` environment.

    ``env_name`` has no usable default (it is set to ``MISSING``).
    ``_target_`` holds the dotted path
    ``torchrl.envs.libs.robohive.RoboHiveEnv``.
    """

    # Required: no default value is supplied.
    env_name: str = MISSING

    # Environment options and their defaults.
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.robohive.RoboHiveEnv"

    def __post_init__(self) -> None:
        """Chain to the parent class's post-initialization hook."""
        super().__post_init__()
@dataclass
class SMACv2EnvConfig(EnvLibsConfig):
    """Settings describing a ``SMACv2Env`` environment.

    ``env_name`` has no usable default (it is set to ``MISSING``).
    ``_target_`` holds the dotted path ``torchrl.envs.libs.smacv2.SMACv2Env``.
    """

    # Required: no default value is supplied.
    env_name: str = MISSING

    # Environment options and their defaults.
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.smacv2.SMACv2Env"

    def __post_init__(self) -> None:
        """Chain to the parent class's post-initialization hook."""
        super().__post_init__()
@dataclass
class UnityMLAgentsEnvConfig(EnvLibsConfig):
    """Settings describing a ``UnityMLAgentsEnv`` environment.

    ``env_name`` has no usable default (it is set to ``MISSING``).
    ``_target_`` holds the dotted path
    ``torchrl.envs.libs.unity_mlagents.UnityMLAgentsEnv``.
    """

    # Required: no default value is supplied.
    env_name: str = MISSING

    # Environment options and their defaults.
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.unity_mlagents.UnityMLAgentsEnv"

    def __post_init__(self) -> None:
        """Chain to the parent class's post-initialization hook."""
        super().__post_init__()
@dataclass
class VmasEnvConfig(EnvLibsConfig):
    """Settings describing a ``VmasEnv`` environment.

    ``env_name`` has no usable default (it is set to ``MISSING``).
    ``_target_`` holds the dotted path ``torchrl.envs.libs.vmas.VmasEnv``.
    """

    # Required: no default value is supplied.
    env_name: str = MISSING

    # Environment options and their defaults.
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.vmas.VmasEnv"

    def __post_init__(self) -> None:
        """Chain to the parent class's post-initialization hook."""
        super().__post_init__()
@dataclass
class MultiThreadedEnvConfig(EnvLibsConfig):
    """Settings describing a ``MultiThreadedEnv`` environment.

    ``env_name`` has no usable default (it is set to ``MISSING``).
    ``_target_`` holds the dotted path
    ``torchrl.envs.libs.envpool.MultiThreadedEnv``.
    """

    # Required: no default value is supplied.
    env_name: str = MISSING

    # Environment options and their defaults.
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.envpool.MultiThreadedEnv"

    def __post_init__(self) -> None:
        """Chain to the parent class's post-initialization hook."""
        super().__post_init__()

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources