
Trainer Basics

Core trainer classes and builder utilities.

Trainer and hooks

Trainer(*args, **kwargs)

A generic Trainer class.

TrainerHookBase()

An abstract hooking class for torchrl Trainer class.
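The hook mechanism can be illustrated with a minimal sketch. This is plain Python, not torchrl's actual implementation; the `register` method, the `MiniTrainer` class, and the loop structure are assumptions made for illustration of the pattern an abstract hook base class enables:

```python
from abc import ABC, abstractmethod

class HookBase(ABC):
    """Abstract hook, in the spirit of TrainerHookBase (illustrative only)."""
    @abstractmethod
    def __call__(self, batch):
        ...

class FrameCounter(HookBase):
    """Concrete hook: counts frames seen across batches."""
    def __init__(self):
        self.frames = 0
    def __call__(self, batch):
        self.frames += len(batch)

class MiniTrainer:
    """Toy trainer that fires every registered hook after each batch."""
    def __init__(self):
        self.hooks = []
    def register(self, hook):
        self.hooks.append(hook)
    def train(self, batches):
        for batch in batches:
            for hook in self.hooks:
                hook(batch)

trainer = MiniTrainer()
counter = FrameCounter()
trainer.register(counter)
trainer.train([[0, 1, 2], [3, 4]])
print(counter.frames)  # 5
```

The design point is that hooks keep bookkeeping (logging, counting, checkpointing) out of the training loop itself.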

Algorithm-specific trainers

PPOTrainer(*args, **kwargs)

PPO (Proximal Policy Optimization) trainer implementation.

SACTrainer(*args, **kwargs)

A trainer class for Soft Actor-Critic (SAC) algorithm.

Builders

make_collector_offpolicy(make_env, ...[, ...])

Returns a data collector for off-policy algorithms.

make_collector_onpolicy(make_env, ...[, ...])

Makes a collector in on-policy settings.

make_dqn_loss(model, cfg)

Builds the DQN loss module.
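At its core, the loss that a DQN loss builder assembles is a TD error between the online Q-value and a bootstrapped target. A dependency-free sketch of that computation (the function name and plain-list inputs are illustrative, not torchrl's API):

```python
def dqn_td_error(q_sa, reward, done, gamma, q_next_target):
    """Squared TD error for a single transition:
    target = r + gamma * max_a' Q_target(s', a'), with no bootstrap on
    terminal states."""
    bootstrap = 0.0 if done else gamma * max(q_next_target)
    target = reward + bootstrap
    return (q_sa - target) ** 2

# Non-terminal transition: target = 1.0 + 0.9 * 2.0 = 2.8
err = dqn_td_error(q_sa=2.5, reward=1.0, done=False, gamma=0.9,
                   q_next_target=[1.5, 2.0])
print(round(err, 4))  # (2.5 - 2.8)^2 = 0.09
```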

make_replay_buffer(device, cfg)

Builds a replay buffer using the config built from ReplayArgsConfig.
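The essential behavior a replay-buffer builder provides is fixed-capacity FIFO storage with uniform sampling. A minimal sketch (the class name is made up; real buffers store tensor transitions and support prioritized sampling):

```python
import random
from collections import deque

class RingReplayBuffer:
    """Fixed-capacity FIFO buffer with uniform random sampling."""
    def __init__(self, capacity, seed=None):
        self.storage = deque(maxlen=capacity)
        self.rng = random.Random(seed)

    def add(self, transition):
        self.storage.append(transition)  # oldest entry is evicted when full

    def sample(self, batch_size):
        return self.rng.sample(list(self.storage), batch_size)

    def __len__(self):
        return len(self.storage)

buf = RingReplayBuffer(capacity=3)
for t in range(5):
    buf.add(t)
print(len(buf), sorted(buf.storage))  # 3 [2, 3, 4]
```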

make_target_updater(cfg, loss_module)

Builds a target network weight update object.
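A target-network updater typically performs either hard copies or Polyak (soft) averaging of the online weights into the target weights. A sketch of the soft update over plain floats (real updaters walk tensor state dicts; the function name is illustrative):

```python
def soft_update(target, online, tau):
    """Polyak averaging: target <- (1 - tau) * target + tau * online."""
    return {k: (1 - tau) * target[k] + tau * online[k] for k in target}

target = {"w": 0.0, "b": 1.0}
online = {"w": 1.0, "b": 3.0}
target = soft_update(target, online, tau=0.1)
print({k: round(v, 6) for k, v in target.items()})  # {'w': 0.1, 'b': 1.2}
```

With a small `tau`, the target network trails the online network slowly, which stabilizes bootstrapped value targets.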

make_trainer(collector, loss_module[, ...])

Creates a Trainer instance given its constituents.

parallel_env_constructor(cfg, **kwargs)

Returns a parallel environment from an argparse.Namespace built with the appropriate parser constructor.

sync_async_collector(env_fns, env_kwargs[, ...])

Runs asynchronous collectors, each running synchronous environments.

sync_sync_collector(env_fns, env_kwargs[, ...])

Runs synchronous collectors, each running synchronous environments.

transformed_env_constructor(cfg[, ...])

Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor.

Utils

correct_for_frame_skip(cfg)

Corrects the arguments for the input frame_skip by dividing every argument that reflects a frame count by frame_skip.
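The correction is simple arithmetic: with a frame skip of 4, a budget of 1,000,000 raw frames corresponds to 250,000 environment steps. A sketch over a plain dict (the function name and config keys are illustrative, not torchrl's actual config schema):

```python
def correct_frame_counts(cfg, frame_count_keys):
    """Divide every frame-count argument by frame_skip so that budgets
    expressed in raw frames match the frame-skipped environment."""
    skip = cfg.get("frame_skip", 1)
    if skip <= 1:
        return dict(cfg)
    out = dict(cfg)
    for key in frame_count_keys:
        if key in out:
            out[key] = out[key] // skip
    return out

cfg = {"frame_skip": 4, "total_frames": 1_000_000, "init_random_frames": 10_000}
fixed = correct_frame_counts(cfg, ["total_frames", "init_random_frames"])
print(fixed["total_frames"], fixed["init_random_frames"])  # 250000 2500
```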

get_stats_random_rollout(cfg[, ...])

Gathers stats (loc and scale) from an environment using random rollouts.
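The idea behind gathering loc/scale statistics is to run a random policy, record observations, and compute their mean and standard deviation for later normalization. A scalar-observation sketch (real code does this per observation dimension; the function name and toy environment are made up):

```python
import math
import random

def random_rollout_stats(step_fn, n_steps, seed=0):
    """Estimate loc (mean) and scale (std) of scalar observations
    produced by `step_fn` over a random rollout."""
    rng = random.Random(seed)
    obs = [step_fn(rng) for _ in range(n_steps)]
    loc = sum(obs) / n_steps
    var = sum((o - loc) ** 2 for o in obs) / n_steps
    return loc, math.sqrt(var)

# Toy "environment" emitting uniform observations in [0, 10)
loc, scale = random_rollout_stats(lambda rng: rng.uniform(0, 10),
                                  n_steps=10_000)
# These statistics can then normalize observations: (obs - loc) / scale
```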
