Shortcuts

Source code for torchrl.trainers.algorithms.configs.envs

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

from omegaconf import MISSING

from torchrl.envs.common import EnvBase
from torchrl.trainers.algorithms.configs.common import ConfigBase


[docs]@dataclass class EnvConfig(ConfigBase): """Base configuration class for environments.""" _partial_: bool = False def __post_init__(self) -> None: """Post-initialization hook for environment configurations.""" self._partial_ = False
[docs]@dataclass class BatchedEnvConfig(EnvConfig): """Configuration for batched environments.""" create_env_fn: Any = MISSING num_workers: int = 1 create_env_kwargs: dict = field(default_factory=dict) batched_env_type: str = "parallel" device: str | None = None # batched_env_type: Literal["parallel", "serial", "async"] = "parallel" _target_: str = "torchrl.trainers.algorithms.configs.envs.make_batched_env" def __post_init__(self) -> None: """Post-initialization hook for batched environment configurations.""" super().__post_init__() if hasattr(self.create_env_fn, "_partial_"): self.create_env_fn._partial_ = True
[docs]@dataclass class TransformedEnvConfig(EnvConfig): """Configuration for transformed environments.""" base_env: Any = MISSING transform: Any = None cache_specs: bool = True auto_unwrap: bool | None = None _target_: str = "torchrl.envs.TransformedEnv"
def make_batched_env( create_env_fn, num_workers, batched_env_type="parallel", device=None, **kwargs ): """Create a batched environment. Args: create_env_fn: Function to create individual environments or environment instance. num_workers: Number of worker environments. batched_env_type: Type of batched environment (parallel, serial, async). device: Device to place the batched environment on. **kwargs: Additional keyword arguments. Returns: The created batched environment instance. """ from torchrl.envs import AsyncEnvPool, ParallelEnv, SerialEnv if create_env_fn is None: raise ValueError("create_env_fn must be provided") if num_workers is None: raise ValueError("num_workers must be provided") # If create_env_fn is a config object, create a lambda that instantiates it each time if isinstance(create_env_fn, EnvBase): # Already an instance (either instantiated config or actual env), wrap in lambda env_instance = create_env_fn def env_fn(env_instance=env_instance): return env_instance else: env_fn = create_env_fn assert callable(env_fn), env_fn # Add device to kwargs if provided if device is not None: kwargs["device"] = device if batched_env_type == "parallel": return ParallelEnv(num_workers, env_fn, **kwargs) elif batched_env_type == "serial": return SerialEnv(num_workers, env_fn, **kwargs) elif batched_env_type == "async": return AsyncEnvPool([env_fn] * num_workers, **kwargs) else: raise ValueError(f"Unknown batched_env_type: {batched_env_type}")

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources