Trainer Basics

Core trainer classes and builder utilities.

Trainer and hooks

Trainer(*args, **kwargs)

A generic Trainer class.

TrainerHookBase()

An abstract hook class for the TorchRL Trainer class.
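To illustrate how a trainer and its hooks fit together, here is a minimal pure-Python sketch. It assumes that a hook is a callable the trainer invokes at named stages of each training step, and that a hook may replace the value it receives. `MiniTrainer`, `HookBase`, the stage names and `register_op` are hypothetical names that only mimic the general pattern; they are not the torchrl.trainers API.

```python
from abc import ABC, abstractmethod
from typing import Callable, Dict, List


class MiniTrainer:
    """Hypothetical trainer: runs registered hooks at fixed stages of each step."""

    STAGES = ("batch_process", "post_loss", "post_optim")

    def __init__(self, loss_fn: Callable[[List[float]], float]):
        self.loss_fn = loss_fn
        self.ops: Dict[str, List[Callable]] = {stage: [] for stage in self.STAGES}
        self.logs: List[float] = []

    def register_op(self, stage: str, op: Callable) -> None:
        if stage not in self.ops:
            raise KeyError(f"unknown stage: {stage!r}")
        self.ops[stage].append(op)

    def _run_stage(self, stage: str, value):
        for op in self.ops[stage]:
            out = op(value)
            if out is not None:  # a hook may replace the value it receives
                value = out
        return value

    def train(self, batches: List[List[float]]) -> None:
        for batch in batches:
            batch = self._run_stage("batch_process", batch)
            loss = self._run_stage("post_loss", self.loss_fn(batch))
            # A real trainer would backpropagate and step an optimizer here.
            self._run_stage("post_optim", loss)


class HookBase(ABC):
    """Hypothetical abstract hook: subclasses define __call__ and choose a stage."""

    @abstractmethod
    def __call__(self, value): ...

    def register(self, trainer: MiniTrainer, stage: str) -> None:
        trainer.register_op(stage, self)


class ScaleBatch(HookBase):
    def __init__(self, factor: float):
        self.factor = factor

    def __call__(self, batch):
        return [x * self.factor for x in batch]


class RecordLoss(HookBase):
    def __init__(self, trainer: MiniTrainer):
        self.trainer = trainer

    def __call__(self, loss):
        self.trainer.logs.append(loss)  # returns None, so the value passes through


trainer = MiniTrainer(loss_fn=lambda batch: sum(batch) / len(batch))
ScaleBatch(2.0).register(trainer, "batch_process")
RecordLoss(trainer).register(trainer, "post_optim")
trainer.train([[1.0, 3.0], [2.0, 2.0, 2.0]])
print(trainer.logs)  # [4.0, 4.0]
```

The design choice here is that the trainer owns the loop and the hooks own the side concerns (preprocessing, logging), so new behavior can be added by registering a hook instead of subclassing the trainer.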

Algorithm-specific trainers

PPOTrainer(*args, **kwargs)

PPO (Proximal Policy Optimization) trainer implementation.
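As background for what a PPO trainer optimizes: PPO maximizes the clipped surrogate objective L = E[min(r A, clip(r, 1 - ε, 1 + ε) A)], where r is the probability ratio between the new and old policy for the taken action and A is the advantage estimate. The sketch below computes the negated batch mean of that objective in plain Python. `ppo_clip_loss` is a hypothetical helper written for illustration; it is not TorchRL's PPO loss module.

```python
import math
from typing import Sequence


def ppo_clip_loss(
    log_probs_new: Sequence[float],  # log pi_new(a|s) for each sample
    log_probs_old: Sequence[float],  # log pi_old(a|s), from the behavior policy
    advantages: Sequence[float],
    clip_eps: float = 0.2,
) -> float:
    """Negative mean of the PPO clipped surrogate objective (a loss to minimize)."""
    total = 0.0
    for lp_new, lp_old, adv in zip(log_probs_new, log_probs_old, advantages):
        ratio = math.exp(lp_new - lp_old)
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # Taking the minimum gives a pessimistic bound on the improvement.
        total += min(ratio * adv, clipped * adv)
    return -total / len(advantages)
```

With a ratio of 2 and a positive advantage, the clipped term (1.2 for ε = 0.2) wins; with a negative advantage, the unclipped term wins, so large policy changes are never rewarded.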

SACTrainer(*args, **kwargs)

A trainer class for the Soft Actor-Critic (SAC) algorithm.

DQNTrainer(*args, **kwargs)

A trainer class for the Deep Q-Network (DQN) algorithm.

DDPGTrainer(*args, **kwargs)

A trainer class for the Deep Deterministic Policy Gradient (DDPG) algorithm.

IQLTrainer(*args, **kwargs)

A trainer class for the Implicit Q-Learning (IQL) algorithm.

CQLTrainer(*args, **kwargs)

A trainer class for the Conservative Q-Learning (CQL) algorithm.

Builders

make_collector_offpolicy(make_env, ...[, ...])

Returns a data collector for off-policy sota-implementations.

make_collector_onpolicy(make_env, ...[, ...])

Makes a data collector for on-policy settings.

make_dqn_loss(model, cfg)

Builds the DQN loss module.
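For reference, the quantity a DQN loss module is built to compute is the temporal-difference error between Q(s, a) and the bootstrapped target r + γ max_a' Q_target(s', a'), with the bootstrap term dropped at terminal states. The sketch below uses a squared error on plain Python lists; `dqn_loss` is a hypothetical helper, and practical implementations often use a Huber loss instead of the squared error.

```python
from typing import Sequence


def dqn_loss(
    q_values: Sequence[Sequence[float]],       # Q(s, .) from the online network
    actions: Sequence[int],                    # action taken in each transition
    rewards: Sequence[float],
    next_q_target: Sequence[Sequence[float]],  # Q_target(s', .) from the target network
    dones: Sequence[bool],
    gamma: float = 0.99,
) -> float:
    """Mean squared TD error over a batch of transitions (illustrative sketch)."""
    errors = []
    for q, a, r, q_next, done in zip(q_values, actions, rewards, next_q_target, dones):
        # Bootstrapped target; no future value is added after a terminal step.
        target = r + (0.0 if done else gamma * max(q_next))
        errors.append((q[a] - target) ** 2)
    return sum(errors) / len(errors)
```

Using a separate target network for `next_q_target` is what makes the target-updater builder relevant to DQN.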

make_replay_buffer(device, cfg)

Builds a replay buffer using the config built from ReplayArgsConfig.
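To make concrete what such a builder produces, here is a minimal sketch of a replay buffer with bounded FIFO storage and uniform sampling. `UniformReplayBuffer` is a hypothetical class; the configuration options the real builder reads from ReplayArgsConfig are not modeled here.

```python
import random
from collections import deque
from typing import Any, Iterable, List, Optional


class UniformReplayBuffer:
    """Hypothetical buffer: fixed capacity, oldest items evicted first, uniform sampling."""

    def __init__(self, capacity: int, seed: Optional[int] = None):
        self._storage: deque = deque(maxlen=capacity)  # deque drops the oldest item when full
        self._rng = random.Random(seed)

    def extend(self, transitions: Iterable[Any]) -> None:
        self._storage.extend(transitions)

    def sample(self, batch_size: int) -> List[Any]:
        # Sampling with replacement; raises IndexError if the buffer is empty.
        return [self._rng.choice(self._storage) for _ in range(batch_size)]

    def __len__(self) -> int:
        return len(self._storage)
```

Sampling with replacement keeps the sketch simple; prioritized sampling is a common alternative that this sketch does not cover.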

make_target_updater(cfg, loss_module)

Builds a target network weight update object.
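Target network updates usually take one of two forms: a soft (Polyak) update that moves the target weights a fraction τ toward the online weights at every step, or a hard update that copies the online weights every fixed number of steps. The sketch below shows both on flat lists of floats; `soft_update` and `HardUpdater` are hypothetical names, not the objects returned by the builder.

```python
from typing import List


def soft_update(target: List[float], online: List[float], tau: float) -> List[float]:
    """Polyak averaging: target <- (1 - tau) * target + tau * online."""
    return [(1.0 - tau) * t + tau * o for t, o in zip(target, online)]


class HardUpdater:
    """Copies the online weights into the target every `period` calls to step()."""

    def __init__(self, period: int):
        self.period = period
        self._count = 0

    def step(self, target: List[float], online: List[float]) -> List[float]:
        self._count += 1
        if self._count % self.period == 0:
            return list(online)  # full copy of the online weights
        return target  # otherwise the target is left unchanged
```

A small τ makes the target change slowly and smoothly, while the hard variant keeps the target fixed between periodic copies.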

make_trainer(collector, loss_module[, ...])

Creates a Trainer instance given its constituents.

parallel_env_constructor(cfg, **kwargs)

Returns a parallel environment from an argparse.Namespace built with the appropriate parser constructor.

sync_async_collector(env_fns, env_kwargs[, ...])

Runs asynchronous collectors, each running synchronous environments.

sync_sync_collector(env_fns, env_kwargs[, ...])

Runs synchronous collectors, each running synchronous environments.
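The collector builders above can run several collectors, each stepping several environments. The single-process sketch below only illustrates the synchronous idea: every environment steps once per cycle in a fixed order, and transitions are yielded in fixed-size batches until a frame budget is spent. `CountingEnv`, `sync_collector` and the tuple layout of a transition are hypothetical.

```python
from typing import Callable, Iterator, List, Sequence, Tuple

Transition = Tuple[int, int, int, float, bool]  # (env index, obs, action, reward, done)


class CountingEnv:
    """Hypothetical toy env: observation is the step index, reward equals the action."""

    def __init__(self, horizon: int):
        self.horizon = horizon
        self.t = 0

    def reset(self) -> int:
        self.t = 0
        return self.t

    def step(self, action: int) -> Tuple[int, float, bool]:
        self.t += 1
        return self.t, float(action), self.t >= self.horizon


def sync_collector(
    env_fns: Sequence[Callable[[], CountingEnv]],
    policy: Callable[[int], int],
    frames_per_batch: int,
    total_frames: int,
) -> Iterator[List[Transition]]:
    envs = [make_env() for make_env in env_fns]
    obs = [env.reset() for env in envs]
    batch: List[Transition] = []
    for frame in range(total_frames):
        i = frame % len(envs)  # round-robin: each env steps once per cycle
        action = policy(obs[i])
        next_obs, reward, done = envs[i].step(action)
        batch.append((i, obs[i], action, reward, done))
        obs[i] = envs[i].reset() if done else next_obs  # auto-reset finished episodes
        if len(batch) == frames_per_batch:
            yield batch
            batch = []
    if batch:
        yield batch  # final, possibly partial, batch


batches = list(sync_collector([lambda: CountingEnv(2)] * 2, lambda o: o + 1, 4, 10))
print([len(b) for b in batches])  # [4, 4, 2]
```

An asynchronous variant would instead hand out each batch as soon as any worker has one ready, trading a deterministic ordering for less idle time.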

transformed_env_constructor(cfg[, ...])

Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor.

Utils

correct_for_frame_skip(cfg)

Corrects the arguments for the input frame_skip by dividing every argument that represents a count of frames by frame_skip.
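The reasoning behind the correction: with a frame skip of k, each environment step consumes k raw frames, so a budget of N frames corresponds to N / k steps. The sketch below applies that division to a config namespace. The list of frame-count fields in `FRAME_COUNT_FIELDS` is an assumption for illustration and may not match the fields the library function actually corrects; the sketch also returns a copy rather than modifying the input.

```python
from types import SimpleNamespace

# Assumed frame-count fields (hypothetical; the real function's list may differ).
FRAME_COUNT_FIELDS = ("total_frames", "frames_per_batch", "init_random_frames", "max_frames_per_traj")


def correct_for_frame_skip_sketch(cfg: SimpleNamespace) -> SimpleNamespace:
    """Return a copy of cfg with frame-count fields divided by cfg.frame_skip."""
    fields = dict(vars(cfg))  # copy so the input config is left untouched
    for name in FRAME_COUNT_FIELDS:
        if name in fields:
            fields[name] = fields[name] // cfg.frame_skip
    return SimpleNamespace(**fields)


cfg = SimpleNamespace(frame_skip=4, total_frames=1000, frames_per_batch=200, lr=3e-4)
print(correct_for_frame_skip_sketch(cfg).total_frames)  # 250
```

Fields that do not count frames, such as the learning rate here, pass through unchanged.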

get_stats_random_rollout(cfg[, ...])

Gathers stats (loc and scale) from an environment using random rollouts.
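Such statistics are typically used to normalize observations as (obs - loc) / scale. The sketch below collects observations while taking uniformly random actions and returns the per-dimension mean and population standard deviation. `LinearObsEnv` and `random_rollout_stats` are hypothetical; the toy environment emits (a, 3 + 2a) for the last action a, so the second dimension's loc and scale are predictable from the first's.

```python
import random
import statistics
from typing import List, Tuple

Obs = Tuple[float, float]


class LinearObsEnv:
    """Hypothetical toy env: observation is (a, 3 + 2a) for the last action a."""

    def __init__(self, horizon: int = 10):
        self.horizon = horizon
        self.t = 0

    def reset(self) -> Obs:
        self.t = 0
        return (0.0, 3.0)

    def step(self, action: float) -> Tuple[Obs, bool]:
        self.t += 1
        return (action, 3.0 + 2.0 * action), self.t >= self.horizon


def random_rollout_stats(env: LinearObsEnv, n_steps: int, seed: int = 0) -> Tuple[List[float], List[float]]:
    rng = random.Random(seed)
    samples = [env.reset()]
    for _ in range(n_steps):
        obs, done = env.step(rng.uniform(-1.0, 1.0))  # random policy
        samples.append(obs)
        if done:
            samples.append(env.reset())
    columns = list(zip(*samples))  # one tuple of values per observation dimension
    loc = [statistics.fmean(col) for col in columns]
    scale = [statistics.pstdev(col) for col in columns]
    return loc, scale


loc, scale = random_rollout_stats(LinearObsEnv(horizon=5), n_steps=200, seed=0)
```

Because the second dimension is an affine function of the first, its loc equals 3 + 2 * loc[0] and its scale equals 2 * scale[0], up to floating-point error.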