TorchRL Configuration System

TorchRL provides a powerful configuration system built on top of Hydra that enables you to easily configure and run reinforcement learning experiments. This system uses structured dataclass-based configurations that can be composed, overridden, and extended.

The advantages of using a configuration system are:

  • Quick and easy to get started: provide your task and let the system handle the rest.

  • See the available options and their default values in one go: python sota-implementations/ppo_trainer/train.py --help will show you all the available options and their default values.

  • Easy to override and extend: you can override any option in the configuration file, and you can extend the configuration with your own custom entries.

  • Easy to share and reproduce: you can share your configuration file with others, and they can reproduce your results by running the same command.

  • Easy to version control: configuration files are plain text and can be tracked like any other source file.

Quick Start with a Simple Example

Let’s start with a simple example that creates a Gym environment. Here’s a minimal configuration file:

# config.yaml
defaults:
  - env@training_env: gym

training_env:
  env_name: CartPole-v1

This configuration has two main parts:

1. The defaults section

The defaults section tells Hydra which configuration groups to include. In this case:

  • env@training_env: gym means “use the ‘gym’ configuration from the ‘env’ group for the ‘training_env’ target”

This is equivalent to including a predefined configuration for Gym environments, which sets up the proper target class and default parameters.

2. The configuration override

The training_env section allows you to override or specify parameters for the selected configuration:

  • env_name: CartPole-v1 sets the specific environment name
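
Once written, this configuration can be consumed from Python via Hydra. Here is a minimal sketch, assuming the config.yaml above sits in a local configs/ directory (the paths and version_base are illustrative):

import hydra
from hydra.utils import instantiate

@hydra.main(config_path="configs", config_name="config", version_base="1.3")
def main(cfg):
    # cfg.training_env carries the composed gym config; instantiate builds it
    env = instantiate(cfg.training_env)
    print(env.reset())

if __name__ == "__main__":
    main()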

Configuration Categories and Groups

TorchRL organizes configurations into several categories using the @ syntax for targeted configuration:

  • env@<target>: Environment configurations (Gym, DMControl, Brax, etc.) as well as batched environments

  • transform@<target>: Transform configurations (observation/reward processing)

  • model@<target>: Model configurations (policy and value networks)

  • network@<target>: Neural network configurations (MLP, ConvNet)

  • collector@<target>: Data collection configurations

  • replay_buffer@<target>: Replay buffer configurations

  • storage@<target>: Storage backend configurations

  • sampler@<target>: Sampling strategy configurations

  • writer@<target>: Writer strategy configurations

  • trainer@<target>: Training loop configurations

  • optimizer@<target>: Optimizer configurations

  • loss@<target>: Loss function configurations

  • logger@<target>: Logging configurations

The @<target> syntax allows you to assign configurations to specific locations in your config structure.
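
For example, the same configuration group can be routed to two different targets. A small sketch, mirroring the complete training example later in this document:

defaults:
  - network@networks.policy_network: mlp
  - network@networks.value_network: mlp

networks:
  policy_network:
    num_cells: [128, 128]
  value_network:
    num_cells: [64, 64]

Both entries come from the same network group; the @<target> suffix decides where each resolved config lands in the final structure.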

More Complex Example: Parallel Environment with Transforms

Here’s a more complex example that creates a parallel environment with multiple transforms applied to each worker:

defaults:
  - env@training_env: batched_env
  - env@training_env.create_env_fn: transformed_env
  - env@training_env.create_env_fn.base_env: gym
  - transform@training_env.create_env_fn.transform: compose
  - transform@transform0: noop_reset
  - transform@transform1: step_counter

# Transform configurations
transform0:
  noops: 30
  random: true

transform1:
  max_steps: 200
  step_count_key: "step_count"

# Environment configuration
training_env:
  num_workers: 4
  create_env_fn:
    base_env:
      env_name: Pendulum-v1
    transform:
      transforms:
        - ${transform0}
        - ${transform1}
    _partial_: true

What this configuration creates:

This configuration builds a parallel environment with 4 workers, where each worker runs a Pendulum-v1 environment with two transforms applied:

  1. Parallel Environment Structure:
     - batched_env creates a parallel environment that runs multiple environment instances
     - num_workers: 4 means 4 parallel environment processes

  2. Individual Environment Construction (repeated for each of the 4 workers):
     - Base Environment: gym with env_name: Pendulum-v1 creates a Pendulum environment
     - Transform Layer 1: noop_reset performs up to 30 random no-op actions at episode start
     - Transform Layer 2: step_counter limits episodes to 200 steps and tracks the step count
     - Transform Composition: compose combines both transforms into a single transformation

  3. Final Result: 4 parallel Pendulum environments (see the equivalent hand-written construction below), each with:
     - random no-op resets (up to 30 actions at the start of each episode)
     - a maximum episode length of 200 steps
     - step counting functionality
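
For reference, here is roughly the object graph this configuration builds, written out by hand. This is a sketch, not the config system's generated code; it assumes batched_env maps onto ParallelEnv and uses TorchRL's public classes:

from torchrl.envs import GymEnv, ParallelEnv, TransformedEnv
from torchrl.envs.transforms import Compose, NoopResetEnv, StepCounter

def make_env():
    # One worker: a Pendulum environment wrapped in the composed transforms
    return TransformedEnv(
        GymEnv("Pendulum-v1"),
        Compose(
            NoopResetEnv(noops=30, random=True),
            StepCounter(max_steps=200, step_count_key="step_count"),
        ),
    )

# _partial_: true makes create_env_fn a factory; ParallelEnv calls it per worker
env = ParallelEnv(4, make_env)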

Key Configuration Concepts:

  1. Nested targeting: env@training_env.create_env_fn.base_env: gym places a gym config deep inside the structure

  2. Function factories: _partial_: true creates a function that can be called multiple times (once per worker), as the sketch after this list shows

  3. Transform composition: Multiple transforms are combined and applied to each environment instance

  4. Variable interpolation: ${transform0} and ${transform1} reference the separately defined transform configurations
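
Point 2 above is worth seeing in code: with _partial_: true, Hydra's instantiate returns a callable factory rather than a constructed object. A minimal sketch, assuming cfg is the composed config from the example above:

from hydra.utils import instantiate

# With _partial_: true, instantiate returns a functools.partial-style callable
make_env = instantiate(cfg.training_env.create_env_fn)
env_a = make_env()  # each call builds a fresh transformed environment
env_b = make_env()  # e.g. one per worker of the batched environment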

Getting Available Options

To explore all available configurations and their parameters, use the --help flag with any TorchRL script:

python sota-implementations/ppo_trainer/train.py --help

This shows all configuration groups and their options, making it easy to discover what's available.


Complete Training Example

Here’s a complete configuration for PPO training:

defaults:
  - env@training_env: batched_env
  - env@training_env.create_env_fn: gym
  - model@models.policy_model: tanh_normal
  - model@models.value_model: value
  - network@networks.policy_network: mlp
  - network@networks.value_network: mlp
  - collector: sync
  - replay_buffer: base
  - storage: tensor
  - sampler: without_replacement
  - writer: round_robin
  - trainer: ppo
  - optimizer: adam
  - loss: ppo
  - logger: wandb

# Network configurations
networks:
  policy_network:
    out_features: 2
    in_features: 4
    num_cells: [128, 128]

  value_network:
    out_features: 1
    in_features: 4
    num_cells: [128, 128]

# Model configurations
models:
  policy_model:
    network: ${networks.policy_network}
    in_keys: ["observation"]
    out_keys: ["action"]

  value_model:
    network: ${networks.value_network}
    in_keys: ["observation"]
    out_keys: ["state_value"]

# Environment
training_env:
  num_workers: 2
  create_env_fn:
    env_name: CartPole-v1
    _partial_: true

# Training components
trainer:
  collector: ${collector}
  optimizer: ${optimizer}
  loss_module: ${loss}
  logger: ${logger}
  total_frames: 100000

collector:
  create_env_fn: ${training_env}
  policy: ${models.policy_model}
  frames_per_batch: 1024

optimizer:
  lr: 0.001

loss:
  actor_network: ${models.policy_model}
  critic_network: ${models.value_model}

logger:
  exp_name: my_experiment

Running Experiments

Basic Usage

# Use default configuration
python sota-implementations/ppo_trainer/train.py

# Override specific parameters
python sota-implementations/ppo_trainer/train.py optimizer.lr=0.0001

# Change environment
python sota-implementations/ppo_trainer/train.py training_env.create_env_fn.env_name=Pendulum-v1

# Use different collector
python sota-implementations/ppo_trainer/train.py collector=async

Hyperparameter Sweeps

# Sweep over learning rates
python sota-implementations/ppo_trainer/train.py --multirun optimizer.lr=0.0001,0.001,0.01

# Multiple parameter sweep
python sota-implementations/ppo_trainer/train.py --multirun \
  optimizer.lr=0.0001,0.001 \
  training_env.num_workers=2,4,8

Custom Configuration Files

# Use custom config file
python sota-implementations/ppo_trainer/train.py --config-name my_custom_config

Configuration Store Implementation Details

Under the hood, TorchRL uses Hydra's ConfigStore to register all configuration classes. This provides type safety, validation, and IDE support. The registration happens automatically when you import the configs module; conceptually, it looks like this:

from hydra.core.config_store import ConfigStore
from torchrl.trainers.algorithms.configs import *

cs = ConfigStore.instance()

# Environments
cs.store(group="env", name="gym", node=GymEnvConfig)
cs.store(group="env", name="batched_env", node=BatchedEnvConfig)

# Models
cs.store(group="model", name="tanh_normal", node=TanhNormalModelConfig)
# ... and many more

Available Configuration Classes

Base Classes

ConfigBase()

Abstract base class for all configuration classes.

Environment Configurations

EnvConfig([_partial_])

Base configuration class for environments.

BatchedEnvConfig(_partial_, create_env_fn, ...)

Configuration for batched environments.

TransformedEnvConfig([_partial_, base_env, ...])

Configuration for transformed environments.

Environment Library Configurations

EnvLibsConfig([_partial_])

Base configuration class for environment libs.

GymEnvConfig([_partial_, env_name, ...])

Configuration for GymEnv environment.

DMControlEnvConfig([_partial_, env_name, ...])

Configuration for DMControlEnv environment.

BraxEnvConfig([_partial_, env_name, ...])

Configuration for BraxEnv environment.

HabitatEnvConfig([_partial_, env_name, ...])

Configuration for HabitatEnv environment.

IsaacGymEnvConfig([_partial_, env_name, ...])

Configuration for IsaacGymEnv environment.

JumanjiEnvConfig([_partial_, env_name, ...])

Configuration for JumanjiEnv environment.

MeltingpotEnvConfig([_partial_, env_name, ...])

Configuration for MeltingpotEnv environment.

MOGymEnvConfig([_partial_, env_name, ...])

Configuration for MOGymEnv environment.

MultiThreadedEnvConfig([_partial_, ...])

Configuration for MultiThreadedEnv environment.

OpenMLEnvConfig([_partial_, env_name, ...])

Configuration for OpenMLEnv environment.

OpenSpielEnvConfig([_partial_, env_name, ...])

Configuration for OpenSpielEnv environment.

PettingZooEnvConfig([_partial_, env_name, ...])

Configuration for PettingZooEnv environment.

RoboHiveEnvConfig([_partial_, env_name, ...])

Configuration for RoboHiveEnv environment.

SMACv2EnvConfig([_partial_, env_name, ...])

Configuration for SMACv2Env environment.

UnityMLAgentsEnvConfig([_partial_, ...])

Configuration for UnityMLAgentsEnv environment.

VmasEnvConfig([_partial_, env_name, ...])

Configuration for VmasEnv environment.

Model and Network Configurations

ModelConfig([_partial_, in_keys, out_keys])

Parent class to configure a model.

NetworkConfig([_partial_])

Parent class to configure a network.

MLPConfig(_partial_, in_features, ...)

A class to configure a multi-layer perceptron.

ConvNetConfig(_partial_, in_features, depth, ...)

A class to configure a convolutional network.

TensorDictModuleConfig([_partial_, in_keys, ...])

A class to configure a TensorDictModule.

TanhNormalModelConfig([_partial_, in_keys, ...])

A class to configure a TanhNormal model.

ValueModelConfig([_partial_, in_keys, ...])

A class to configure a Value model.

Transform Configurations

TransformConfig()

Base configuration class for transforms.

ComposeConfig([transforms, _target_])

Configuration for Compose transform.

NoopResetEnvConfig([noops, random, _target_])

Configuration for NoopResetEnv transform.

StepCounterConfig([max_steps, ...])

Configuration for StepCounter transform.

DoubleToFloatConfig([in_keys, out_keys, ...])

Configuration for DoubleToFloat transform.

ToTensorImageConfig([from_int, unsqueeze, ...])

Configuration for ToTensorImage transform.

ClipTransformConfig([in_keys, out_keys, ...])

Configuration for ClipTransform.

ResizeConfig([w, h, interpolation, in_keys, ...])

Configuration for Resize transform.

CenterCropConfig([height, width, in_keys, ...])

Configuration for CenterCrop transform.

CropConfig([top, left, height, width, ...])

Configuration for Crop transform.

FlattenObservationConfig([in_keys, ...])

Configuration for FlattenObservation transform.

GrayScaleConfig([in_keys, out_keys, _target_])

Configuration for GrayScale transform.

ObservationNormConfig([loc, scale, in_keys, ...])

Configuration for ObservationNorm transform.

CatFramesConfig([N, dim, in_keys, out_keys, ...])

Configuration for CatFrames transform.

RewardClippingConfig([clamp_min, clamp_max, ...])

Configuration for RewardClipping transform.

RewardScalingConfig([loc, scale, in_keys, ...])

Configuration for RewardScaling transform.

BinarizeRewardConfig([in_keys, out_keys, ...])

Configuration for BinarizeReward transform.

TargetReturnConfig([target_return, mode, ...])

Configuration for TargetReturn transform.

VecNormConfig([in_keys, out_keys, decay, ...])

Configuration for VecNorm transform.

FrameSkipTransformConfig([frame_skip, ...])

Configuration for FrameSkipTransform.

DeviceCastTransformConfig([device, in_keys, ...])

Configuration for DeviceCastTransform.

DTypeCastTransformConfig([dtype, in_keys, ...])

Configuration for DTypeCastTransform.

UnsqueezeTransformConfig([dim, in_keys, ...])

Configuration for UnsqueezeTransform.

SqueezeTransformConfig([dim, in_keys, ...])

Configuration for SqueezeTransform.

PermuteTransformConfig([dims, in_keys, ...])

Configuration for PermuteTransform.

CatTensorsConfig([dim, in_keys, out_keys, ...])

Configuration for CatTensors transform.

StackConfig([dim, in_keys, out_keys, _target_])

Configuration for Stack transform.

DiscreteActionProjectionConfig([...])

Configuration for DiscreteActionProjection transform.

TensorDictPrimerConfig([primer_spec, ...])

Configuration for TensorDictPrimer transform.

PinMemoryTransformConfig([in_keys, ...])

Configuration for PinMemoryTransform.

RewardSumConfig([in_keys, out_keys, _target_])

Configuration for RewardSum transform.

ExcludeTransformConfig([exclude_keys, _target_])

Configuration for ExcludeTransform.

SelectTransformConfig([include_keys, _target_])

Configuration for SelectTransform.

TimeMaxPoolConfig([dim, in_keys, out_keys, ...])

Configuration for TimeMaxPool transform.

RandomCropTensorDictConfig([crop_size, ...])

Configuration for RandomCropTensorDict transform.

InitTrackerConfig([in_keys, out_keys, _target_])

Configuration for InitTracker transform.

RenameTransformConfig([key_mapping, _target_])

Configuration for RenameTransform.

Reward2GoTransformConfig([gamma, in_keys, ...])

Configuration for Reward2GoTransform.

ActionMaskConfig([mask_key, in_keys, ...])

Configuration for ActionMask transform.

VecGymEnvTransformConfig([in_keys, ...])

Configuration for VecGymEnvTransform.

BurnInTransformConfig([burn_in, in_keys, ...])

Configuration for BurnInTransform.

SignTransformConfig([in_keys, out_keys, ...])

Configuration for SignTransform.

RemoveEmptySpecsConfig([_target_])

Configuration for RemoveEmptySpecs transform.

BatchSizeTransformConfig([batch_size, ...])

Configuration for BatchSizeTransform.

AutoResetTransformConfig([replace, ...])

Configuration for AutoResetTransform.

ActionDiscretizerConfig([num_intervals, ...])

Configuration for ActionDiscretizer transform.

TrajCounterConfig([out_key, repeats, _target_])

Configuration for TrajCounter transform.

LineariseRewardsConfig([in_keys, out_keys, ...])

Configuration for LineariseRewards transform.

ConditionalSkipConfig([cond, _target_])

Configuration for ConditionalSkip transform.

MultiActionConfig([dim, stack_rewards, ...])

Configuration for MultiAction transform.

TimerConfig([out_keys, time_key, _target_])

Configuration for Timer transform.

ConditionalPolicySwitchConfig([policy, ...])

Configuration for ConditionalPolicySwitch transform.

FiniteTensorDictCheckConfig([in_keys, ...])

Configuration for FiniteTensorDictCheck transform.

UnaryTransformConfig([fn, in_keys, ...])

Configuration for UnaryTransform.

HashConfig([in_keys, out_keys, _target_])

Configuration for Hash transform.

TokenizerConfig([vocab_size, in_keys, ...])

Configuration for Tokenizer transform.

EndOfLifeTransformConfig([eol_key, ...])

Configuration for EndOfLifeTransform.

MultiStepTransformConfig([n_steps, gamma, ...])

Configuration for MultiStepTransform.

KLRewardTransformConfig([in_keys, out_keys, ...])

Configuration for KLRewardTransform.

R3MTransformConfig([in_keys, out_keys, ...])

Configuration for R3MTransform.

VC1TransformConfig([in_keys, out_keys, ...])

Configuration for VC1Transform.

VIPTransformConfig([in_keys, out_keys, ...])

Configuration for VIPTransform.

VIPRewardTransformConfig([in_keys, ...])

Configuration for VIPRewardTransform.

VecNormV2Config([in_keys, out_keys, decay, ...])

Configuration for VecNormV2 transform.

Data Collection Configurations

DataCollectorConfig()

Parent class to configure a data collector.

SyncDataCollectorConfig([create_env_fn, ...])

A class to configure a synchronous data collector.

AsyncDataCollectorConfig(create_env_fn, ...)

Configuration for asynchronous data collector.

MultiSyncDataCollectorConfig([...])

Configuration for multi-synchronous data collector.

MultiaSyncDataCollectorConfig([...])

Configuration for multi-asynchronous data collector.

Replay Buffer and Storage Configurations

ReplayBufferConfig([_partial_, _target_, ...])

Configuration for generic replay buffer.

TensorDictReplayBufferConfig([_partial_, ...])

Configuration for TensorDict-based replay buffer.

RandomSamplerConfig([_target_])

Configuration for random sampling from replay buffer.

SamplerWithoutReplacementConfig([_target_, ...])

Configuration for sampling without replacement.

PrioritizedSamplerConfig([_target_, ...])

Configuration for prioritized sampling from replay buffer.

SliceSamplerConfig([_target_, num_slices, ...])

Configuration for slice sampling from replay buffer.

SliceSamplerWithoutReplacementConfig([...])

Configuration for slice sampling without replacement.

ListStorageConfig([_partial_, _target_, ...])

Configuration for list-based storage in replay buffer.

TensorStorageConfig([_partial_, _target_, ...])

Configuration for tensor-based storage in replay buffer.

LazyTensorStorageConfig([_partial_, ...])

Configuration for lazy tensor storage.

LazyMemmapStorageConfig([_partial_, ...])

Configuration for lazy memory-mapped storage.

LazyStackStorageConfig([_partial_, ...])

Configuration for lazy stack storage.

StorageEnsembleConfig([_partial_, _target_, ...])

Configuration for storage ensemble.

RoundRobinWriterConfig([_target_, compilable])

Configuration for the round-robin writer, which writes incoming data to the replay buffer storage in a circular (round-robin) fashion.

StorageEnsembleWriterConfig([_partial_, ...])

Configuration for storage ensemble writer.

Training and Optimization Configurations

TrainerConfig()

Base configuration class for trainers.

PPOTrainerConfig(collector, total_frames, ...)

Configuration class for PPO (Proximal Policy Optimization) trainer.

LossConfig([_partial_])

A class to configure a loss.

PPOLossConfig([_partial_, actor_network, ...])

A class to configure a PPO loss.

AdamConfig([lr, betas, eps, weight_decay, ...])

Configuration for Adam optimizer.

AdamWConfig([lr, betas, eps, weight_decay, ...])

Configuration for AdamW optimizer.

AdamaxConfig([lr, betas, eps, weight_decay, ...])

Configuration for Adamax optimizer.

AdadeltaConfig([lr, rho, eps, weight_decay, ...])

Configuration for Adadelta optimizer.

AdagradConfig([lr, lr_decay, weight_decay, ...])

Configuration for Adagrad optimizer.

ASGDConfig([lr, lambd, alpha, t0, ...])

Configuration for ASGD optimizer.

LBFGSConfig([lr, max_iter, max_eval, ...])

Configuration for LBFGS optimizer.

LionConfig([lr, betas, weight_decay, ...])

Configuration for Lion optimizer.

NAdamConfig([lr, betas, eps, weight_decay, ...])

Configuration for NAdam optimizer.

RAdamConfig([lr, betas, eps, weight_decay, ...])

Configuration for RAdam optimizer.

RMSpropConfig([lr, alpha, eps, ...])

Configuration for RMSprop optimizer.

RpropConfig([lr, etas, step_sizes, foreach, ...])

Configuration for Rprop optimizer.

SGDConfig([lr, momentum, dampening, ...])

Configuration for SGD optimizer.

SparseAdamConfig([lr, betas, eps, _target_, ...])

Configuration for SparseAdam optimizer.

Logging Configurations

LoggerConfig()

A class to configure a logger.

WandbLoggerConfig(exp_name[, offline, ...])

A class to configure a Wandb logger.

TensorboardLoggerConfig(exp_name[, log_dir, ...])

A class to configure a Tensorboard logger.

CSVLoggerConfig(exp_name[, log_dir, ...])

A class to configure a CSV logger.

Creating Custom Configurations

You can create custom configuration classes by inheriting from the appropriate base classes:

from dataclasses import dataclass
from torchrl.trainers.algorithms.configs.envs_libs import EnvLibsConfig

@dataclass
class MyCustomEnvConfig(EnvLibsConfig):
    _target_: str = "my_module.MyCustomEnv"
    env_name: str = "MyEnv-v1"
    custom_param: float = 1.0

    def __post_init__(self):
        super().__post_init__()

# Register with ConfigStore
from hydra.core.config_store import ConfigStore
cs = ConfigStore.instance()
cs.store(group="env", name="my_custom", node=MyCustomEnvConfig)

Best Practices

  1. Start Simple: Begin with basic configurations and gradually add complexity

  2. Use Defaults: Leverage the defaults section to compose configurations

  3. Override Sparingly: Only override what you need to change

  4. Validate Configurations: Test that your configurations instantiate correctly (see the sketch after this list)

  5. Version Control: Keep your configuration files under version control

  6. Use Variable Interpolation: Use ${variable} syntax to avoid duplication
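
For point 4, a configuration can be composed and instantiated programmatically before launching a full run. A minimal validation sketch, assuming the Quick Start config.yaml lives in a local configs/ directory:

from hydra import compose, initialize
from hydra.utils import instantiate

with initialize(config_path="configs", version_base="1.3"):
    cfg = compose(config_name="config")
    env = instantiate(cfg.training_env)  # raises if the config is malformed
    env.reset()  # quick smoke test of the instantiated environment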

Future Extensions

As TorchRL adds more algorithms beyond PPO (such as SAC, TD3, DQN), the configuration system will expand with:

  • New trainer configurations (e.g., SACTrainerConfig, TD3TrainerConfig)

  • Algorithm-specific loss configurations

  • Specialized collector configurations for different algorithms

  • Additional environment and model configurations

The modular design ensures easy integration while maintaining backward compatibility.
