TorchRL Configuration System#

TorchRL provides a powerful configuration system built on top of Hydra that enables you to easily configure and run reinforcement learning experiments. This system uses structured dataclass-based configurations that can be composed, overridden, and extended.

The advantages of using a configuration system are:

  • Quick and easy to get started: provide your task and let the system handle the rest

  • Get a glimpse of the available options and their default values in one go: python sota-implementations/ppo_trainer/train.py --help will show you all the available options and their default values

  • Easy to override and extend: you can override any option in the configuration file, and you can also extend the configuration file with your own custom configurations

  • Easy to share and reproduce: you can share your configuration file with others, and they can reproduce your results by simply running the same command

  • Easy to version control: you can easily version control your configuration file

Quick Start with a Simple Example#

Let’s start with a simple example that creates a Gym environment. Here’s a minimal configuration file:

# config.yaml
defaults:
  - env@training_env: gym

training_env:
  env_name: CartPole-v1

This configuration has two main parts:

1. The defaults section

The defaults section tells Hydra which configuration groups to include. In this case:

  • env@training_env: gym means “use the ‘gym’ configuration from the ‘env’ group for the ‘training_env’ target”

This is equivalent to including a predefined configuration for Gym environments, which sets up the proper target class and default parameters.

2. The configuration override

The training_env section allows you to override or specify parameters for the selected configuration:

  • env_name: CartPole-v1 sets the specific environment name

Configuration Categories and Groups#

TorchRL organizes configurations into several categories using the @ syntax for targeted configuration:

  • env@<target>: Environment configurations (Gym, DMControl, Brax, etc.) as well as batched environments

  • transform@<target>: Transform configurations (observation/reward processing)

  • model@<target>: Model configurations (policy and value networks)

  • network@<target>: Neural network configurations (MLP, ConvNet)

  • collector@<target>: Data collection configurations

  • replay_buffer@<target>: Replay buffer configurations

  • storage@<target>: Storage backend configurations

  • sampler@<target>: Sampling strategy configurations

  • writer@<target>: Writer strategy configurations

  • trainer@<target>: Training loop configurations

  • hook@<target>: Trainer hook configurations

  • optimizer@<target>: Optimizer configurations

  • loss@<target>: Loss function configurations

  • logger@<target>: Logging configurations

The @<target> syntax allows you to assign configurations to specific locations in your config structure.
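
As a rough mental model of the @<target> syntax, the sketch below places a group option at a dotted path inside a plain dictionary. It is a toy stand-in, not Hydra's implementation: the GROUPS table, the place helper, and the default field values are assumptions made for illustration.

```python
from copy import deepcopy

# Hypothetical stand-in for the registered config groups. The real entries
# live in Hydra's ConfigStore and carry many more fields; the default values
# shown here are assumptions.
GROUPS = {
    "env": {
        "gym": {"env_name": None, "_partial_": False},
    },
}


def place(config: dict, entry: str, option: str) -> dict:
    """Merge GROUPS[group][option] into `config` at the dotted target path."""
    group, _, target = entry.partition("@")
    node = config
    for key in target.split("."):
        # Create intermediate levels on the way down to the target.
        node = node.setdefault(key, {})
    node.update(deepcopy(GROUPS[group][option]))
    return config


cfg = place({}, "env@training_env.create_env_fn.base_env", "gym")
print(cfg)
```

The part before @ selects the group, the part after it says where the selected defaults end up. Later sections of the file can then override individual fields at that location.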

More Complex Example: Parallel Environment with Transforms#

Here’s a more complex example that creates a parallel environment with multiple transforms applied to each worker:

defaults:
  - env@training_env: batched_env
  - env@training_env.create_env_fn: transformed_env
  - env@training_env.create_env_fn.base_env: gym
  - transform@training_env.create_env_fn.transform: compose
  - transform@transform0: noop_reset
  - transform@transform1: step_counter

# Transform configurations
transform0:
  noops: 30
  random: true

transform1:
  max_steps: 200
  step_count_key: "step_count"

# Environment configuration
training_env:
  num_workers: 4
  create_env_fn:
    base_env:
      env_name: Pendulum-v1
    transform:
      transforms:
        - ${transform0}
        - ${transform1}
    _partial_: true

What this configuration creates:

This configuration builds a parallel environment with 4 workers, where each worker runs a Pendulum-v1 environment with two transforms applied:

  1. Parallel Environment Structure:

     • batched_env creates a parallel environment that runs multiple environment instances

     • num_workers: 4 means 4 parallel environment processes

  2. Individual Environment Construction (repeated for each of the 4 workers):

     • Base Environment: gym with env_name: Pendulum-v1 creates a Pendulum environment

     • Transform Layer 1: noop_reset performs a random number of no-op actions (0-30) at episode start

     • Transform Layer 2: step_counter limits episodes to 200 steps and tracks step count

     • Transform Composition: compose combines both transforms into a single transformation

  3. Final Result: 4 parallel Pendulum environments, each with:

     • Random no-op resets (0-30 actions at start)

     • Maximum episode length of 200 steps

     • Step counting functionality

Key Configuration Concepts:

  1. Nested targeting: env@training_env.create_env_fn.base_env: gym places a gym config deep inside the structure

  2. Function factories: _partial_: true creates a function that can be called multiple times (once per worker)

  3. Transform composition: Multiple transforms are combined and applied to each environment instance

  4. Variable interpolation: ${transform0} and ${transform1} reference the separately defined transform configurations
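
The function-factory idea can be sketched with functools.partial, which is what Hydra's instantiate returns when _partial_: true is set. ToyEnv is a hypothetical stand-in for an environment class, not a TorchRL API.

```python
from functools import partial


class ToyEnv:
    """Hypothetical stand-in for an environment class (not a TorchRL API)."""

    def __init__(self, env_name: str):
        self.env_name = env_name


# Roughly what `_partial_: true` asks for: bind the arguments now,
# construct the object later.
create_env_fn = partial(ToyEnv, env_name="Pendulum-v1")

# A batched environment can then call the factory once per worker,
# so every worker gets its own independent instance.
num_workers = 4
envs = [create_env_fn() for _ in range(num_workers)]

print(len(envs), envs[0].env_name)
```

Without the partial, the configuration would describe a single, already-built environment that could not be replicated across workers.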

Getting Available Options#

To explore all available configurations and their parameters, use the --help flag with any TorchRL script:

python sota-implementations/ppo_trainer/train.py --help

This shows all configuration groups and their options, making it easy to discover what’s available.


Complete Training Example#

Here’s a complete configuration for PPO training:

defaults:
  - env@training_env: batched_env
  - env@training_env.create_env_fn: gym
  - model@models.policy_model: tanh_normal
  - model@models.value_model: value
  - network@networks.policy_network: mlp
  - network@networks.value_network: mlp
  - collector: sync
  - replay_buffer: base
  - storage: tensor
  - sampler: without_replacement
  - writer: round_robin
  - trainer: ppo
  - optimizer: adam
  - loss: ppo
  - logger: wandb

# Network configurations
networks:
  policy_network:
    out_features: 2
    in_features: 4
    num_cells: [128, 128]

  value_network:
    out_features: 1
    in_features: 4
    num_cells: [128, 128]

# Model configurations
models:
  policy_model:
    network: ${networks.policy_network}
    in_keys: ["observation"]
    out_keys: ["action"]

  value_model:
    network: ${networks.value_network}
    in_keys: ["observation"]
    out_keys: ["state_value"]

# Environment
training_env:
  num_workers: 2
  create_env_fn:
    env_name: CartPole-v1
    _partial_: true

# Training components
trainer:
  collector: ${collector}
  optimizer: ${optimizer}
  loss_module: ${loss}
  logger: ${logger}
  total_frames: 100000

collector:
  create_env_fn: ${training_env}
  policy: ${models.policy_model}
  frames_per_batch: 1024

optimizer:
  lr: 0.001

loss:
  actor_network: ${models.policy_model}
  critic_network: ${models.value_model}

logger:
  exp_name: my_experiment
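
The ${...} references in the configuration above point at other nodes of the same file. The sketch below mimics that lookup on plain dictionaries. It is a toy resolver written for illustration, not OmegaConf's implementation: it only handles values that consist of a single ${a.b} reference, and the in_features and out_features values are copied from the example.

```python
import re

config = {
    "networks": {"policy_network": {"in_features": 4, "out_features": 2}},
    "models": {"policy_model": {"network": "${networks.policy_network}"}},
}

# Matches a value that is exactly one reference such as "${a.b.c}".
REFERENCE = re.compile(r"^\$\{([\w.]+)\}$")


def lookup(root: dict, dotted_path: str):
    """Follow a dotted path from the root of the configuration."""
    node = root
    for key in dotted_path.split("."):
        node = node[key]
    return node


def resolve(node, root):
    """Return a copy of `node` with whole-value references replaced."""
    if isinstance(node, dict):
        return {key: resolve(value, root) for key, value in node.items()}
    if isinstance(node, list):
        return [resolve(item, root) for item in node]
    if isinstance(node, str):
        match = REFERENCE.match(node)
        if match:
            return resolve(lookup(root, match.group(1)), root)
    return node


resolved = resolve(config, config)
print(resolved["models"]["policy_model"]["network"])
```

Because the network settings are defined once and referenced elsewhere, changing networks.policy_network updates every component that points at it.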

Running Experiments#

Basic Usage#

# Use default configuration
python sota-implementations/ppo_trainer/train.py

# Override specific parameters
python sota-implementations/ppo_trainer/train.py optimizer.lr=0.0001

# Change environment
python sota-implementations/ppo_trainer/train.py training_env.create_env_fn.env_name=Pendulum-v1

# Use different collector
python sota-implementations/ppo_trainer/train.py collector=async

Hyperparameter Sweeps#

# Sweep over learning rates
python sota-implementations/ppo_trainer/train.py --multirun optimizer.lr=0.0001,0.001,0.01

# Multiple parameter sweep
python sota-implementations/ppo_trainer/train.py --multirun \
  optimizer.lr=0.0001,0.001 \
  training_env.num_workers=2,4,8
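
Assuming Hydra's default (basic) sweeper, a multi-parameter sweep launches one job per combination of the listed values. The sketch below only enumerates those combinations for the command above; it does not launch anything.

```python
from itertools import product

# Sweep values copied from the command above.
sweep = {
    "optimizer.lr": [0.0001, 0.001],
    "training_env.num_workers": [2, 4, 8],
}

# One job per combination of values (Cartesian product).
jobs = [dict(zip(sweep, values)) for values in product(*sweep.values())]

for index, overrides in enumerate(jobs):
    print(index, overrides)
```

Two learning rates and three worker counts give six jobs, so sweep sizes grow multiplicatively with each added parameter.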

Custom Configuration Files#

# Use custom config file
python sota-implementations/ppo_trainer/train.py --config-name my_custom_config

Configuration Store Implementation Details#

Under the hood, TorchRL uses Hydra’s ConfigStore to register all configuration classes. This provides type safety, validation, and IDE support. The registration happens automatically when you import the configs module:

from hydra.core.config_store import ConfigStore
from torchrl.trainers.algorithms.configs import *

cs = ConfigStore.instance()

# Environments
cs.store(group="env", name="gym", node=GymEnvConfig)
cs.store(group="env", name="batched_env", node=BatchedEnvConfig)

# Models
cs.store(group="model", name="tanh_normal", node=TanhNormalModelConfig)
# ... and many more

Available Configuration Classes#

Base Classes#

ConfigBase()

Abstract base class for all configuration classes.

Environment Configurations#

EnvConfig([_partial_])

Base configuration class for environments.

BatchedEnvConfig(_partial_, create_env_fn, ...)

Configuration for batched environments.

TransformedEnvConfig([_partial_, base_env, ...])

Configuration for transformed environments.

Environment Library Configurations#

EnvLibsConfig([_partial_])

Base configuration class for environment libs.

GymEnvConfig([_partial_, env_name, ...])

Configuration for GymEnv environment.

DMControlEnvConfig([_partial_, env_name, ...])

Configuration for DMControlEnv environment.

BraxEnvConfig([_partial_, env_name, ...])

Configuration for BraxEnv environment.

HabitatEnvConfig([_partial_, env_name, ...])

Configuration for HabitatEnv environment.

IsaacGymEnvConfig([_partial_, env_name, ...])

Configuration for IsaacGymEnv environment.

JumanjiEnvConfig([_partial_, env_name, ...])

Configuration for JumanjiEnv environment.

MeltingpotEnvConfig([_partial_, env_name, ...])

Configuration for MeltingpotEnv environment.

MOGymEnvConfig([_partial_, env_name, ...])

Configuration for MOGymEnv environment.

MultiThreadedEnvConfig([_partial_, ...])

Configuration for MultiThreadedEnv environment.

OpenMLEnvConfig([_partial_, env_name, ...])

Configuration for OpenMLEnv environment.

OpenSpielEnvConfig([_partial_, env_name, ...])

Configuration for OpenSpielEnv environment.

PettingZooEnvConfig([_partial_, env_name, ...])

Configuration for PettingZooEnv environment.

RoboHiveEnvConfig([_partial_, env_name, ...])

Configuration for RoboHiveEnv environment.

SMACv2EnvConfig([_partial_, env_name, ...])

Configuration for SMACv2Env environment.

UnityMLAgentsEnvConfig([_partial_, ...])

Configuration for UnityMLAgentsEnv environment.

VmasEnvConfig(_partial_, scenario, num_envs, ...)

Configuration for VmasEnv environment.

Model and Network Configurations#

ModelConfig([_partial_, in_keys, out_keys, ...])

Parent class to configure a model.

NetworkConfig([_partial_])

Parent class to configure a network.

MLPConfig(_partial_, in_features, ...)

A class to configure a multi-layer perceptron.

ConvNetConfig(_partial_, in_features, depth, ...)

A class to configure a convolutional network.

TensorDictModuleConfig([_partial_, in_keys, ...])

A class to configure a TensorDictModule.

TanhNormalModelConfig([_partial_, in_keys, ...])

A class to configure a TanhNormal model.

ValueModelConfig([_partial_, in_keys, ...])

A class to configure a Value model.

QValueModelConfig([_partial_, in_keys, ...])

A class to configure a QValueActor model.

TanhModuleConfig([_partial_, in_keys, ...])

A class to configure a TanhModule.

TensorDictSequentialConfig([_partial_, ...])

A class to configure a TensorDictSequential.

AdditiveGaussianModuleConfig([_partial_, ...])

A class to configure an AdditiveGaussianModule.

Transform Configurations#

TransformConfig()

Base configuration class for transforms.

ComposeConfig([transforms, _target_])

Configuration for Compose transform.

NoopResetEnvConfig([noops, random, _target_])

Configuration for NoopResetEnv transform.

StepCounterConfig([max_steps, ...])

Configuration for StepCounter transform.

DoubleToFloatConfig([in_keys, out_keys, ...])

Configuration for DoubleToFloat transform.

ToTensorImageConfig([from_int, unsqueeze, ...])

Configuration for ToTensorImage transform.

ClipTransformConfig([in_keys, out_keys, ...])

Configuration for ClipTransform.

ResizeConfig([w, h, interpolation, in_keys, ...])

Configuration for Resize transform.

CenterCropConfig([height, width, in_keys, ...])

Configuration for CenterCrop transform.

CropConfig([top, left, height, width, ...])

Configuration for Crop transform.

FlattenObservationConfig([in_keys, ...])

Configuration for FlattenObservation transform.

GrayScaleConfig([in_keys, out_keys, _target_])

Configuration for GrayScale transform.

ObservationNormConfig([loc, scale, in_keys, ...])

Configuration for ObservationNorm transform.

CatFramesConfig([N, dim, in_keys, out_keys, ...])

Configuration for CatFrames transform.

RewardClippingConfig([clamp_min, clamp_max, ...])

Configuration for RewardClipping transform.

RewardScalingConfig([loc, scale, in_keys, ...])

Configuration for RewardScaling transform.

BinarizeRewardConfig([in_keys, out_keys, ...])

Configuration for BinarizeReward transform.

TargetReturnConfig([target_return, mode, ...])

Configuration for TargetReturn transform.

VecNormConfig([in_keys, out_keys, decay, ...])

Configuration for VecNorm transform.

FrameSkipTransformConfig([frame_skip, ...])

Configuration for FrameSkipTransform.

DeviceCastTransformConfig([device, in_keys, ...])

Configuration for DeviceCastTransform.

DTypeCastTransformConfig([dtype, in_keys, ...])

Configuration for DTypeCastTransform.

UnsqueezeTransformConfig([dim, in_keys, ...])

Configuration for UnsqueezeTransform.

SqueezeTransformConfig([dim, in_keys, ...])

Configuration for SqueezeTransform.

PermuteTransformConfig([dims, in_keys, ...])

Configuration for PermuteTransform.

CatTensorsConfig([dim, in_keys, out_keys, ...])

Configuration for CatTensors transform.

StackConfig([dim, in_keys, out_keys, _target_])

Configuration for Stack transform.

DiscreteActionProjectionConfig([...])

Configuration for DiscreteActionProjection transform.

TensorDictPrimerConfig([primer_spec, ...])

Configuration for TensorDictPrimer transform.

PinMemoryTransformConfig([in_keys, ...])

Configuration for PinMemoryTransform.

RewardSumConfig([in_keys, out_keys, ...])

Configuration for RewardSum transform.

ExcludeTransformConfig([exclude_keys, _target_])

Configuration for ExcludeTransform.

SelectTransformConfig([include_keys, _target_])

Configuration for SelectTransform.

TimeMaxPoolConfig([dim, in_keys, out_keys, ...])

Configuration for TimeMaxPool transform.

RandomCropTensorDictConfig([crop_size, ...])

Configuration for RandomCropTensorDict transform.

InitTrackerConfig([init_key, _target_])

Configuration for InitTracker transform.

RenameTransformConfig([key_mapping, _target_])

Configuration for RenameTransform.

Reward2GoTransformConfig([gamma, in_keys, ...])

Configuration for Reward2GoTransform.

ActionMaskConfig([mask_key, in_keys, ...])

Configuration for ActionMask transform.

VecGymEnvTransformConfig([in_keys, ...])

Configuration for VecGymEnvTransform.

BurnInTransformConfig([burn_in, in_keys, ...])

Configuration for BurnInTransform.

SignTransformConfig([in_keys, out_keys, ...])

Configuration for SignTransform.

RemoveEmptySpecsConfig([_target_])

Configuration for RemoveEmptySpecs transform.

BatchSizeTransformConfig([batch_size, ...])

Configuration for BatchSizeTransform.

AutoResetTransformConfig([replace, ...])

Configuration for AutoResetTransform.

ActionDiscretizerConfig([num_intervals, ...])

Configuration for ActionDiscretizer transform.

TrajCounterConfig([out_key, repeats, _target_])

Configuration for TrajCounter transform.

LineariseRewardsConfig([in_keys, out_keys, ...])

Configuration for LineariseRewards transform.

ConditionalSkipConfig([cond, _target_])

Configuration for ConditionalSkip transform.

MultiActionConfig([dim, stack_rewards, ...])

Configuration for MultiAction transform.

TimerConfig([out_keys, time_key, _target_])

Configuration for Timer transform.

ConditionalPolicySwitchConfig([policy, ...])

Configuration for ConditionalPolicySwitch transform.

FiniteTensorDictCheckConfig([in_keys, ...])

Configuration for FiniteTensorDictCheck transform.

UnaryTransformConfig([fn, in_keys, ...])

Configuration for UnaryTransform.

HashConfig([in_keys, out_keys, _target_])

Configuration for Hash transform.

TokenizerConfig([vocab_size, in_keys, ...])

Configuration for Tokenizer transform.

EndOfLifeTransformConfig([eol_key, ...])

Configuration for EndOfLifeTransform.

MultiStepTransformConfig([n_steps, gamma, ...])

Configuration for MultiStepTransform.

KLRewardTransformConfig([in_keys, out_keys, ...])

Configuration for KLRewardTransform.

R3MTransformConfig([in_keys, out_keys, ...])

Configuration for R3MTransform.

VC1TransformConfig([in_keys, out_keys, ...])

Configuration for VC1Transform.

VIPTransformConfig([in_keys, out_keys, ...])

Configuration for VIPTransform.

VIPRewardTransformConfig([in_keys, ...])

Configuration for VIPRewardTransform.

VecNormV2Config([in_keys, out_keys, decay, ...])

Configuration for VecNormV2 transform.

Data Collection Configurations#

CollectorConfig([create_env_fn, policy, ...])

Hydra configuration for Collector.

AsyncCollectorConfig(create_env_fn, policy, ...)

Hydra configuration for AsyncCollector.

MultiSyncCollectorConfig([create_env_fn, ...])

Hydra configuration for MultiSyncCollector.

MultiAsyncCollectorConfig([create_env_fn, ...])

Hydra configuration for MultiAsyncCollector.

Replay Buffer and Storage Configurations#

ReplayBufferConfig([_partial_, _target_, ...])

Hydra configuration for ReplayBuffer.

TensorDictReplayBufferConfig([_partial_, ...])

Hydra configuration for TensorDictReplayBuffer.

RandomSamplerConfig([_target_])

Configuration for random sampling from replay buffer.

SamplerWithoutReplacementConfig([_target_, ...])

Configuration for sampling without replacement.

PrioritizedSamplerConfig([_target_, ...])

Configuration for prioritized sampling from replay buffer.

SliceSamplerConfig([_target_, num_slices, ...])

Configuration for slice sampling from replay buffer.

SliceSamplerWithoutReplacementConfig([...])

Configuration for slice sampling without replacement.

ListStorageConfig([_partial_, _target_, ...])

Hydra configuration for ListStorage.

TensorStorageConfig([_partial_, _target_, ...])

Configuration for tensor-based storage in replay buffer.

LazyTensorStorageConfig([_partial_, ...])

Hydra configuration for LazyTensorStorage.

LazyMemmapStorageConfig([_partial_, ...])

Hydra configuration for LazyMemmapStorage.

LazyStackStorageConfig([_partial_, ...])

Configuration for lazy stack storage.

StorageEnsembleConfig([_partial_, _target_, ...])

Configuration for storage ensemble.

RoundRobinWriterConfig([_target_, compilable])

Configuration for round-robin writer that distributes data across multiple storages.

StorageEnsembleWriterConfig([_partial_, ...])

Configuration for storage ensemble writer.

Training and Optimization Configurations#

TrainerConfig()

Base configuration class for trainers.

PPOTrainerConfig(collector, total_frames, ...)

Hydra configuration for PPOTrainer.

SACTrainerConfig(collector, total_frames, ...)

Hydra configuration for SACTrainer.

DQNTrainerConfig(collector, total_frames, ...)

Hydra configuration for DQNTrainer.

DDPGTrainerConfig(collector, total_frames, ...)

Hydra configuration for DDPGTrainer.

IQLTrainerConfig(collector, total_frames, ...)

Hydra configuration for IQLTrainer.

CQLTrainerConfig(collector, total_frames, ...)

Hydra configuration for CQLTrainer.

Trainer Hook Configurations#

HookConfig()

Base configuration class for trainer hooks.

BatchSubSamplerConfig([batch_size, ...])

Configuration for the BatchSubSampler hook.

ClearCudaCacheConfig([interval, _target_])

Configuration for the ClearCudaCache hook.

CountFramesLogConfig([frame_skip, log_pbar, ...])

Configuration for the CountFramesLog hook.

EarlyStoppingConfig([monitor, mode, ...])

Configuration for the EarlyStopping hook.

LogScalarConfig([key, logname, log_pbar, ...])

Configuration for the LogScalar hook.

LogTimingConfig([prefix, percall, erase, ...])

Configuration for the LogTiming hook.

RewardNormalizerConfig([decay, scale, eps, ...])

Configuration for the RewardNormalizer hook.

SelectKeysConfig(keys, _target_)

Configuration for the SelectKeys hook.

LossConfig([_partial_])

A class to configure a loss.

PPOLossConfig([_partial_, actor_network, ...])

Hydra configuration for the PPO loss family.

SACLossConfig([_partial_, actor_network, ...])

Hydra configuration for SACLoss (and DiscreteSACLoss when discrete=True).

DQNLossConfig([_partial_, value_network, ...])

Hydra configuration for DQNLoss.

DDPGLossConfig([_partial_, actor_network, ...])

Hydra configuration for DDPGLoss.

IQLLossConfig([_partial_, actor_network, ...])

Hydra configuration for IQLLoss (and DiscreteIQLLoss when discrete=True).

CQLLossConfig([_partial_, actor_network, ...])

Hydra configuration for CQLLoss.

GAEConfig([_partial_, gamma, lmbda, ...])

Hydra configuration for GAE.

TargetNetUpdaterConfig(loss_module[, _partial_])

An abstract class to configure target net updaters.

SoftUpdateConfig(loss_module[, _partial_, ...])

A class for soft update instantiation.

HardUpdateConfig(loss_module[, _partial_, ...])

A class for hard update instantiation.

AdamConfig([lr, betas, eps, weight_decay, ...])

Hydra configuration for torch.optim.Adam.

AdamWConfig([lr, betas, eps, weight_decay, ...])

Hydra configuration for torch.optim.AdamW.

AdamaxConfig([lr, betas, eps, weight_decay, ...])

Hydra configuration for torch.optim.Adamax.

AdadeltaConfig([lr, rho, eps, weight_decay, ...])

Hydra configuration for torch.optim.Adadelta.

AdagradConfig([lr, lr_decay, weight_decay, ...])

Hydra configuration for torch.optim.Adagrad.

ASGDConfig([lr, lambd, alpha, t0, ...])

Hydra configuration for torch.optim.ASGD.

LBFGSConfig([lr, max_iter, max_eval, ...])

Configuration for LBFGS optimizer.

LionConfig([lr, betas, weight_decay, ...])

Configuration for Lion optimizer.

NAdamConfig([lr, betas, eps, weight_decay, ...])

Hydra configuration for torch.optim.NAdam.

RAdamConfig([lr, betas, eps, weight_decay, ...])

Hydra configuration for torch.optim.RAdam.

RMSpropConfig([lr, alpha, eps, ...])

Hydra configuration for torch.optim.RMSprop.

RpropConfig([lr, etas, step_sizes, ...])

Hydra configuration for torch.optim.Rprop.

SGDConfig([lr, momentum, dampening, ...])

Hydra configuration for torch.optim.SGD.

SparseAdamConfig([lr, betas, eps, maximize, ...])

Hydra configuration for torch.optim.SparseAdam.

Logging Configurations#

LoggerConfig()

A class to configure a logger.

WandbLoggerConfig(exp_name, offline, ...)

A class to configure a Wandb logger.

TensorboardLoggerConfig(exp_name[, log_dir, ...])

A class to configure a Tensorboard logger.

TrackioLoggerConfig(exp_name, project, ...)

A class to configure a Trackio logger.

CSVLoggerConfig(exp_name[, log_dir, ...])

A class to configure a CSV logger.

Creating Custom Configurations#

You can create custom configuration classes by inheriting from the appropriate base classes:

from dataclasses import dataclass
from torchrl.trainers.algorithms.configs.envs_libs import EnvLibsConfig

@dataclass
class MyCustomEnvConfig(EnvLibsConfig):
    _target_: str = "my_module.MyCustomEnv"
    env_name: str = "MyEnv-v1"
    custom_param: float = 1.0

    def __post_init__(self):
        super().__post_init__()

# Register with ConfigStore
from hydra.core.config_store import ConfigStore
cs = ConfigStore.instance()
cs.store(group="env", name="my_custom", node=MyCustomEnvConfig)
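
To see what the _target_ field is for, the sketch below imports the dotted path and calls it with the remaining fields. It is a simplified stand-in for hydra.utils.instantiate, not its implementation. FractionConfig and toy_instantiate are hypothetical names, and a standard-library class is used as the target so the example runs without TorchRL installed.

```python
from dataclasses import asdict, dataclass
from importlib import import_module


@dataclass
class FractionConfig:
    # Dotted path of the callable to build (a standard-library class here).
    _target_: str = "fractions.Fraction"
    numerator: int = 3
    denominator: int = 4


def toy_instantiate(config):
    """Import the `_target_` callable and call it with the other fields."""
    kwargs = asdict(config)
    module_name, _, attr = kwargs.pop("_target_").rpartition(".")
    target = getattr(import_module(module_name), attr)
    return target(**kwargs)


value = toy_instantiate(FractionConfig())
print(value)
```

A quick check of this kind, run against the real instantiate call, is one way to confirm that a custom configuration class builds the object you expect before using it in a training run.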

Best Practices#

  1. Start Simple: Begin with basic configurations and gradually add complexity

  2. Use Defaults: Leverage the defaults section to compose configurations

  3. Override Sparingly: Only override what you need to change

  4. Validate Configurations: Test that your configurations instantiate correctly

  5. Version Control: Keep your configuration files under version control

  6. Use Variable Interpolation: Use ${variable} syntax to avoid duplication

Supported Algorithms#

TorchRL currently provides configuration-driven trainers for the following algorithms:

  • PPO (on-policy): PPOTrainerConfig, PPOLossConfig

  • SAC (off-policy, continuous): SACTrainerConfig, SACLossConfig

  • DQN (off-policy, discrete): DQNTrainerConfig, DQNLossConfig

  • DDPG (off-policy, continuous): DDPGTrainerConfig, DDPGLossConfig

  • IQL (offline): IQLTrainerConfig, IQLLossConfig

  • CQL (offline): CQLTrainerConfig, CQLLossConfig

The modular design is meant to make it straightforward to add further algorithms without breaking existing configurations.