TorchRL Configuration System#
TorchRL provides a powerful configuration system built on top of Hydra that enables you to easily configure and run reinforcement learning experiments. This system uses structured dataclass-based configurations that can be composed, overridden, and extended.
The advantages of using a configuration system are:
- Quick and easy to get started: provide your task and let the system handle the rest
- See every available option and its default value in one place: python sota-implementations/ppo_trainer/train.py --help prints the full list
- Easy to override and extend: you can override any option in the configuration file, and you can also extend the configuration file with your own custom configurations
- Easy to share and reproduce: you can share your configuration file with others, and they can reproduce your results by running the same command
- Easy to version control: you can easily version control your configuration file
Quick Start with a Simple Example#
Let’s start with a simple example that creates a Gym environment. Here’s a minimal configuration file:
# config.yaml
defaults:
  - env@training_env: gym
training_env:
  env_name: CartPole-v1
This configuration has two main parts:
1. The defaults section
The defaults section tells Hydra which configuration groups to include. In this case:
env@training_env: gym means “use the ‘gym’ configuration from the ‘env’ group for the ‘training_env’ target”
This is equivalent to including a predefined configuration for Gym environments, which sets up the proper target class and default parameters.
2. The configuration override
The training_env section allows you to override or specify parameters for the selected configuration:
env_name: CartPole-v1 sets the specific environment name
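Conceptually, Hydra does two things here. It places the selected group option at the target node, then merges the primary config's own values on top of it. The toy sketch below models those two steps with plain dictionaries. `GYM_OPTION` and its field values are hypothetical stand-ins, not TorchRL's real `GymEnvConfig`. In Hydra the actual merge is performed by OmegaConf, which also type-checks values.

```python
import copy

# Hypothetical contents of the "gym" option in the "env" group; the real
# GymEnvConfig defines its own _target_ and default fields.
GYM_OPTION = {"_target_": "torchrl.envs.GymEnv", "env_name": None, "from_pixels": False}


def deep_merge(base: dict, override: dict) -> dict:
    """Return a new dict: `base` with `override` merged in recursively."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Step 1: "env@training_env: gym" places the gym option under training_env.
composed = {"training_env": GYM_OPTION}
# Step 2: the training_env section of config.yaml is merged on top of it.
composed = deep_merge(composed, {"training_env": {"env_name": "CartPole-v1"}})
print(composed["training_env"])
```

Keys the user does not mention (here the hypothetical `from_pixels`) keep the option's defaults, which is why a one-line override is enough.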
Configuration Categories and Groups#
TorchRL organizes configurations into several categories using the @ syntax for targeted configuration:
- env@<target>: Environment configurations (Gym, DMControl, Brax, etc.) as well as batched environments
- transform@<target>: Transform configurations (observation/reward processing)
- model@<target>: Model configurations (policy and value networks)
- network@<target>: Neural network configurations (MLP, ConvNet)
- collector@<target>: Data collection configurations
- replay_buffer@<target>: Replay buffer configurations
- storage@<target>: Storage backend configurations
- sampler@<target>: Sampling strategy configurations
- writer@<target>: Writer strategy configurations
- trainer@<target>: Training loop configurations
- hook@<target>: Trainer hook configurations
- optimizer@<target>: Optimizer configurations
- loss@<target>: Loss function configurations
- logger@<target>: Logging configurations
The @<target> syntax allows you to assign configurations to specific locations in your config structure.
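Because the target is a dotted path, an option can land deep inside the tree, as with `training_env.create_env_fn.base_env`. Below is a minimal sketch of that placement step. The `place_at` helper is hypothetical and the option contents are placeholders; Hydra's actual implementation differs.

```python
def place_at(cfg: dict, dotted_path: str, option: dict) -> dict:
    """Store a copy of `option` at `dotted_path`, creating parent nodes as needed."""
    *parents, leaf = dotted_path.split(".")
    node = cfg
    for key in parents:
        node = node.setdefault(key, {})
    node[leaf] = dict(option)
    return cfg


cfg: dict = {}
# Placeholder option contents; the real group options carry their own fields.
place_at(cfg, "training_env", {"num_workers": 1})
place_at(cfg, "training_env.create_env_fn", {"_partial_": True})
place_at(cfg, "training_env.create_env_fn.base_env", {"env_name": None})
print(cfg)
```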
More Complex Example: Parallel Environment with Transforms#
Here’s a more complex example that creates a parallel environment with multiple transforms applied to each worker:
defaults:
  - env@training_env: batched_env
  - env@training_env.create_env_fn: transformed_env
  - env@training_env.create_env_fn.base_env: gym
  - transform@training_env.create_env_fn.transform: compose
  - transform@transform0: noop_reset
  - transform@transform1: step_counter
# Transform configurations
transform0:
  noops: 30
  random: true
transform1:
  max_steps: 200
  step_count_key: "step_count"
# Environment configuration
training_env:
  num_workers: 4
  create_env_fn:
    base_env:
      env_name: Pendulum-v1
    transform:
      transforms:
        - ${transform0}
        - ${transform1}
    _partial_: true
What this configuration creates:
This configuration builds a parallel environment with 4 workers, where each worker runs a Pendulum-v1 environment with two transforms applied:
Parallel Environment Structure:
- batched_env creates a parallel environment that runs multiple environment instances
- num_workers: 4 means 4 parallel environment processes
Individual Environment Construction (repeated for each of the 4 workers):
- Base Environment: gym with env_name: Pendulum-v1 creates a Pendulum environment
- Transform Layer 1: noop_reset performs up to 30 no-op actions at episode start (the exact number is drawn at random because random: true)
- Transform Layer 2: step_counter limits episodes to 200 steps and tracks the step count
- Transform Composition: compose combines both transforms into a single transformation
Final Result: 4 parallel Pendulum environments, each with:
- Random no-op resets (0-30 actions at start)
- Maximum episode length of 200 steps
- Step counting functionality
Key Configuration Concepts:
- Nested targeting: env@training_env.create_env_fn.base_env: gym places a gym config deep inside the structure
- Function factories: _partial_: true creates a function that can be called multiple times (once per worker)
- Transform composition: Multiple transforms are combined and applied to each environment instance
- Variable interpolation: ${transform0} and ${transform1} reference the separately defined transform configurations
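When `_partial_` is true, Hydra's `instantiate` does not call the target. It returns a `functools.partial` instead, so the batched environment receives a factory it can call once per worker. The plain-Python sketch below mimics that result. `make_env` is a hypothetical stand-in for the environment constructor. The two dictionaries stand in for the resolved `${transform0}` and `${transform1}` nodes.

```python
from functools import partial


def make_env(env_name: str, transforms: list) -> dict:
    """Hypothetical env constructor: returns a fresh description per call."""
    return {"env_name": env_name, "transforms": list(transforms)}


transform0 = {"noops": 30, "random": True}
transform1 = {"max_steps": 200, "step_count_key": "step_count"}

# Roughly what `_partial_: true` yields: a callable with its arguments pre-bound.
create_env_fn = partial(make_env, env_name="Pendulum-v1", transforms=[transform0, transform1])

# The batched environment calls the factory once per worker (num_workers: 4).
workers = [create_env_fn() for _ in range(4)]
```

Each call builds a new object, which is why a factory rather than a single instance is passed to the parallel environment.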
Getting Available Options#
To explore all available configurations and their parameters, one can use the --help flag with any TorchRL script:
python sota-implementations/ppo_trainer/train.py --help
This shows all configuration groups and their options, making it easy to discover what’s available.
Complete Training Example#
Here’s a complete configuration for PPO training:
defaults:
  - env@training_env: batched_env
  - env@training_env.create_env_fn: gym
  - model@models.policy_model: tanh_normal
  - model@models.value_model: value
  - network@networks.policy_network: mlp
  - network@networks.value_network: mlp
  - collector: sync
  - replay_buffer: base
  - storage: tensor
  - sampler: without_replacement
  - writer: round_robin
  - trainer: ppo
  - optimizer: adam
  - loss: ppo
  - logger: wandb
# Network configurations
networks:
  policy_network:
    out_features: 2
    in_features: 4
    num_cells: [128, 128]
  value_network:
    out_features: 1
    in_features: 4
    num_cells: [128, 128]
# Model configurations
models:
  policy_model:
    network: ${networks.policy_network}
    in_keys: ["observation"]
    out_keys: ["action"]
  value_model:
    network: ${networks.value_network}
    in_keys: ["observation"]
    out_keys: ["state_value"]
# Environment
training_env:
  num_workers: 2
  create_env_fn:
    env_name: CartPole-v1
    _partial_: true
# Training components
trainer:
  collector: ${collector}
  optimizer: ${optimizer}
  loss_module: ${loss}
  logger: ${logger}
  total_frames: 100000
collector:
  create_env_fn: ${training_env}
  policy: ${models.policy_model}
  frames_per_batch: 1024
optimizer:
  lr: 0.001
loss:
  actor_network: ${models.policy_model}
  critic_network: ${models.value_model}
logger:
  exp_name: my_experiment
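The `${...}` references let one node reuse another. For example, `loss.actor_network` points at `models.policy_model`, which in turn points at `networks.policy_network`, so nothing is defined twice. OmegaConf resolves these references for Hydra, and it does so lazily on access. The toy resolver below is a simplification. It resolves eagerly, handles only values that consist of exactly one `${dotted.path}` reference, and does not detect cycles. It is meant to show the lookup, not to replace OmegaConf.

```python
import re

# Matches a value that is exactly one "${dotted.path}" reference.
INTERPOLATION = re.compile(r"^\$\{([\w.]+)\}$")


def lookup(root: dict, dotted_path: str):
    """Walk `root` along a dotted path such as 'models.policy_model'."""
    node = root
    for key in dotted_path.split("."):
        node = node[key]
    return node


def resolve(node, root: dict):
    """Recursively replace whole-string ${...} references with their targets."""
    if isinstance(node, dict):
        return {key: resolve(value, root) for key, value in node.items()}
    if isinstance(node, list):
        return [resolve(value, root) for value in node]
    if isinstance(node, str):
        match = INTERPOLATION.match(node)
        if match:
            return resolve(lookup(root, match.group(1)), root)
    return node


cfg = {
    "networks": {"policy_network": {"in_features": 4, "out_features": 2, "num_cells": [128, 128]}},
    "models": {"policy_model": {"network": "${networks.policy_network}", "in_keys": ["observation"]}},
    "loss": {"actor_network": "${models.policy_model}"},
}
resolved = resolve(cfg, cfg)
```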
Running Experiments#
Basic Usage#
# Use default configuration
python sota-implementations/ppo_trainer/train.py
# Override specific parameters
python sota-implementations/ppo_trainer/train.py optimizer.lr=0.0001
# Change environment
python sota-implementations/ppo_trainer/train.py training_env.create_env_fn.env_name=Pendulum-v1
# Use different collector
python sota-implementations/ppo_trainer/train.py collector=async
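Each `key.path=value` argument sets one leaf of the composed config. A group selection such as `collector=async` works differently: it swaps which option of the group is loaded. The toy sketch below covers only the value-override case. `apply_override` is hypothetical, and `ast.literal_eval` is only a rough stand-in for Hydra's override grammar (for instance, it would keep `true` as a string).

```python
import ast


def parse_value(text: str):
    """Interpret numbers, lists, etc. as Python literals; fall back to a string."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text


def apply_override(cfg: dict, override: str) -> None:
    """Apply one 'a.b.c=value' override in place; unknown parents raise KeyError."""
    dotted_path, _, raw_value = override.partition("=")
    *parents, leaf = dotted_path.split(".")
    node = cfg
    for key in parents:
        node = node[key]
    node[leaf] = parse_value(raw_value)


cfg = {"optimizer": {"lr": 0.001}, "training_env": {"create_env_fn": {"env_name": "CartPole-v1"}}}
apply_override(cfg, "optimizer.lr=0.0001")
apply_override(cfg, "training_env.create_env_fn.env_name=Pendulum-v1")
```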
Hyperparameter Sweeps#
# Sweep over learning rates
python sota-implementations/ppo_trainer/train.py --multirun optimizer.lr=0.0001,0.001,0.01
# Multiple parameter sweep
python sota-implementations/ppo_trainer/train.py --multirun \
  optimizer.lr=0.0001,0.001 \
  training_env.num_workers=2,4,8
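With `--multirun`, Hydra's default basic sweeper launches one job per combination of the comma-separated values. The multiple-parameter sweep therefore runs 2 × 3 = 6 jobs. The enumeration can be sketched with `itertools.product`; the job order here is illustrative and may differ from Hydra's.

```python
from itertools import product

sweep = {
    "optimizer.lr": [0.0001, 0.001],
    "training_env.num_workers": [2, 4, 8],
}

# One list of override strings per job: the Cartesian product of all values.
jobs = [
    [f"{key}={value}" for key, value in zip(sweep, combo)]
    for combo in product(*sweep.values())
]
for job in jobs:
    print(" ".join(job))
```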
Custom Configuration Files#
# Use custom config file
python sota-implementations/ppo_trainer/train.py --config-name my_custom_config
Configuration Store Implementation Details#
Under the hood, TorchRL uses Hydra’s ConfigStore to register all configuration classes. This provides type safety, validation, and IDE support. The registration happens automatically when you import the configs module:
from hydra.core.config_store import ConfigStore
from torchrl.trainers.algorithms.configs import *
cs = ConfigStore.instance()
# Environments
cs.store(group="env", name="gym", node=GymEnvConfig)
cs.store(group="env", name="batched_env", node=BatchedEnvConfig)
# Models
cs.store(group="model", name="tanh_normal", node=TanhNormalModelConfig)
# ... and many more
Available Configuration Classes#
Base Classes#
Abstract base class for all configuration classes.
Environment Configurations#
- Base configuration class for environments.
- Configuration for batched environments.
- Configuration for transformed environments.
Environment Library Configurations#
- Base configuration class for environment libs.
- Configuration for GymEnv environment.
- Configuration for DMControlEnv environment.
- Configuration for BraxEnv environment.
- Configuration for HabitatEnv environment.
- Configuration for IsaacGymEnv environment.
- Configuration for JumanjiEnv environment.
- Configuration for MeltingpotEnv environment.
- Configuration for MOGymEnv environment.
- Configuration for MultiThreadedEnv environment.
- Configuration for OpenMLEnv environment.
- Configuration for OpenSpielEnv environment.
- Configuration for PettingZooEnv environment.
- Configuration for RoboHiveEnv environment.
- Configuration for SMACv2Env environment.
- Configuration for UnityMLAgentsEnv environment.
- Configuration for VmasEnv environment.
Model and Network Configurations#
- Parent class to configure a model.
- Parent class to configure a network.
- A class to configure a multi-layer perceptron.
- A class to configure a convolutional network.
- A class to configure a TensorDictModule.
- A class to configure a TanhNormal model.
- A class to configure a Value model.
- A class to configure a QValueActor model.
- A class to configure a TanhModule.
- A class to configure a TensorDictSequential.
- A class to configure an AdditiveGaussianModule.
Transform Configurations#
- Base configuration class for transforms.
- Configuration for Compose transform.
- Configuration for NoopResetEnv transform.
- Configuration for StepCounter transform.
- Configuration for DoubleToFloat transform.
- Configuration for ToTensorImage transform.
- Configuration for ClipTransform.
- Configuration for Resize transform.
- Configuration for CenterCrop transform.
- Configuration for Crop transform.
- Configuration for FlattenObservation transform.
- Configuration for GrayScale transform.
- Configuration for ObservationNorm transform.
- Configuration for CatFrames transform.
- Configuration for RewardClipping transform.
- Configuration for RewardScaling transform.
- Configuration for BinarizeReward transform.
- Configuration for TargetReturn transform.
- Configuration for VecNorm transform.
- Configuration for FrameSkipTransform.
- Configuration for DeviceCastTransform.
- Configuration for DTypeCastTransform.
- Configuration for UnsqueezeTransform.
- Configuration for SqueezeTransform.
- Configuration for PermuteTransform.
- Configuration for CatTensors transform.
- Configuration for Stack transform.
- Configuration for DiscreteActionProjection transform.
- Configuration for TensorDictPrimer transform.
- Configuration for PinMemoryTransform.
- Configuration for RewardSum transform.
- Configuration for ExcludeTransform.
- Configuration for SelectTransform.
- Configuration for TimeMaxPool transform.
- Configuration for RandomCropTensorDict transform.
- Configuration for InitTracker transform.
- Configuration for RenameTransform.
- Configuration for Reward2GoTransform.
- Configuration for ActionMask transform.
- Configuration for VecGymEnvTransform.
- Configuration for BurnInTransform.
- Configuration for SignTransform.
- Configuration for RemoveEmptySpecs transform.
- Configuration for BatchSizeTransform.
- Configuration for AutoResetTransform.
- Configuration for ActionDiscretizer transform.
- Configuration for TrajCounter transform.
- Configuration for LineariseRewards transform.
- Configuration for ConditionalSkip transform.
- Configuration for MultiAction transform.
- Configuration for Timer transform.
- Configuration for ConditionalPolicySwitch transform.
- Configuration for FiniteTensorDictCheck transform.
- Configuration for UnaryTransform.
- Configuration for Hash transform.
- Configuration for Tokenizer transform.
- Configuration for EndOfLifeTransform.
- Configuration for MultiStepTransform.
- Configuration for KLRewardTransform.
- Configuration for R3MTransform.
- Configuration for VC1Transform.
- Configuration for VIPTransform.
- Configuration for VIPRewardTransform.
- Configuration for VecNormV2 transform.
Data Collection Configurations#
- Hydra configuration for
- Hydra configuration for
- Hydra configuration for
- Hydra configuration for
Replay Buffer and Storage Configurations#
- Hydra configuration for
- Hydra configuration for
- Configuration for random sampling from replay buffer.
- Configuration for sampling without replacement.
- Configuration for prioritized sampling from replay buffer.
- Configuration for slice sampling from replay buffer.
- Configuration for slice sampling without replacement.
- Hydra configuration for
- Configuration for tensor-based storage in replay buffer.
- Hydra configuration for
- Hydra configuration for
- Configuration for lazy stack storage.
- Configuration for storage ensemble.
- Configuration for round-robin writer that distributes data across multiple storages.
- Configuration for storage ensemble writer.
Training and Optimization Configurations#
- Base configuration class for trainers.
- Hydra configuration for
- Hydra configuration for
- Hydra configuration for
- Hydra configuration for
- Hydra configuration for
- Hydra configuration for
Trainer Hook Configurations#
- Base configuration class for trainer hooks.
- Configuration for the
- Configuration for the
- Configuration for the
- Configuration for the
- Configuration for the
- Configuration for the
- Configuration for the
- Configuration for the
- A class to configure a loss.
- Hydra configuration for the PPO loss family.
- Hydra configuration for
- Hydra configuration for
- Hydra configuration for
- Hydra configuration for
- Hydra configuration for
- Hydra configuration for
- An abstract class to configure target net updaters.
- A class for soft update instantiation.
- A class for hard update instantiation.
- Hydra configuration for
- Hydra configuration for
- Hydra configuration for
- Hydra configuration for
- Hydra configuration for
- Hydra configuration for
- Configuration for LBFGS optimizer.
- Configuration for Lion optimizer.
- Hydra configuration for
- Hydra configuration for
- Hydra configuration for
- Hydra configuration for
- Hydra configuration for
- Hydra configuration for
Logging Configurations#
- A class to configure a logger.
- A class to configure a Wandb logger.
- A class to configure a Tensorboard logger.
- A class to configure a Trackio logger.
- A class to configure a CSV logger.
Creating Custom Configurations#
You can create custom configuration classes by inheriting from the appropriate base classes:
from dataclasses import dataclass
from torchrl.trainers.algorithms.configs.envs_libs import EnvLibsConfig
@dataclass
class MyCustomEnvConfig(EnvLibsConfig):
    _target_: str = "my_module.MyCustomEnv"
    env_name: str = "MyEnv-v1"
    custom_param: float = 1.0
    def __post_init__(self):
        super().__post_init__()
# Register with ConfigStore
from hydra.core.config_store import ConfigStore
cs = ConfigStore.instance()
cs.store(group="env", name="my_custom", node=MyCustomEnvConfig)
Best Practices#
- Start Simple: Begin with basic configurations and gradually add complexity
- Use Defaults: Leverage the defaults section to compose configurations
- Override Sparingly: Only override what you need to change
- Validate Configurations: Test that your configurations instantiate correctly
- Version Control: Keep your configuration files under version control
- Use Variable Interpolation: Use ${variable} syntax to avoid duplication
Supported Algorithms#
TorchRL currently provides configuration-driven trainers for the following algorithms:
- PPO (on-policy): PPOTrainerConfig, PPOLossConfig
- SAC (off-policy, continuous): SACTrainerConfig, SACLossConfig
- DQN (off-policy, discrete): DQNTrainerConfig, DQNLossConfig
- DDPG (off-policy, continuous): DDPGTrainerConfig, DDPGLossConfig
- IQL (offline): IQLTrainerConfig, IQLLossConfig
- CQL (offline): CQLTrainerConfig, CQLLossConfig
The modular design ensures easy integration of additional algorithms while maintaining backward compatibility.