Trainer Basics

Core trainer classes and builder utilities.

Trainer and hooks

Trainer(*args, **kwargs)

A generic Trainer class.

TrainerHookBase()

An abstract hook class for the torchrl Trainer class.
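The hook pattern mentioned above can be illustrated with a minimal, library-free sketch. Everything here (HookSketch, MiniTrainer, register_hook, run_stage, the "batch_process" stage name) is a hypothetical stand-in chosen for illustration, not the torchrl API; consult the Trainer and TrainerHookBase reference pages for the actual method names.

```python
from abc import ABC, abstractmethod


class HookSketch(ABC):
    """Hypothetical abstract hook: subclasses transform a batch."""

    @abstractmethod
    def __call__(self, batch):
        ...


class MiniTrainer:
    """Toy trainer (not torchrl's) that runs hooks registered per stage."""

    def __init__(self):
        self._hooks = {}

    def register_hook(self, stage, hook):
        # Hooks registered for the same stage run in registration order.
        self._hooks.setdefault(stage, []).append(hook)

    def run_stage(self, stage, batch):
        # Each hook receives the output of the previous one.
        for hook in self._hooks.get(stage, []):
            batch = hook(batch)
        return batch


class ScaleReward(HookSketch):
    """Example hook that multiplies the reward entry by a constant."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, batch):
        return {**batch, "reward": batch["reward"] * self.factor}


trainer = MiniTrainer()
trainer.register_hook("batch_process", ScaleReward(0.5))
out = trainer.run_stage("batch_process", {"reward": 2.0})
```

The abstract base class keeps the trainer loop unaware of what each hook does, so new behavior is added by registering another hook rather than by editing the loop.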

Algorithm-specific trainers

PPOTrainer(*args, **kwargs)

PPO (Proximal Policy Optimization) trainer implementation.

SACTrainer(*args, **kwargs)

A trainer class for the Soft Actor-Critic (SAC) algorithm.

Builders

make_collector_offpolicy(make_env, ...[, ...])

Returns a data collector for off-policy sota-implementations.

make_collector_onpolicy(make_env, ...[, ...])

Makes a collector in on-policy settings.

make_dqn_loss(model, cfg)

Builds the DQN loss module.

make_replay_buffer(device, cfg)

Builds a replay buffer using the config built from ReplayArgsConfig.

make_target_updater(cfg, loss_module)

Builds a target network weight update object.

make_trainer(collector, loss_module[, ...])

Creates a Trainer instance given its constituents.

parallel_env_constructor(cfg, **kwargs)

Returns a parallel environment from an argparse.Namespace built with the appropriate parser constructor.

sync_async_collector(env_fns, env_kwargs[, ...])

Runs asynchronous collectors, each running synchronous environments.

sync_sync_collector(env_fns, env_kwargs[, ...])

Runs synchronous collectors, each running synchronous environments.

transformed_env_constructor(cfg[, ...])

Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor.

Utils

correct_for_frame_skip(cfg)

Corrects the arguments for the input frame_skip by dividing every argument that represents a frame count by frame_skip.
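The arithmetic behind this correction can be sketched without torchrl. The function name correct_for_frame_skip_sketch and the field names (total_frames, frames_per_batch, init_random_frames) are assumptions made for illustration; the real function may adjust a different set of fields and may round differently.

```python
from types import SimpleNamespace

# Hypothetical list of fields that count frames; the real config may differ.
FRAME_COUNT_FIELDS = ("total_frames", "frames_per_batch", "init_random_frames")


def correct_for_frame_skip_sketch(cfg, fields=FRAME_COUNT_FIELDS):
    """Divide each frame-count field of cfg by cfg.frame_skip, in place."""
    frame_skip = cfg.frame_skip
    if frame_skip < 1:
        raise ValueError("frame_skip must be a positive integer")
    for name in fields:
        if hasattr(cfg, name):
            # Integer division is an assumption of this sketch.
            setattr(cfg, name, getattr(cfg, name) // frame_skip)
    return cfg


cfg = SimpleNamespace(
    frame_skip=4,
    total_frames=1_000_000,
    frames_per_batch=1000,
    init_random_frames=5000,
)
correct_for_frame_skip_sketch(cfg)
```

With frame_skip=4, each environment step consumes four frames, so a budget of 1,000,000 frames corresponds to 250,000 steps.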

get_stats_random_rollout(cfg[, ...])

Gathers stats (loc and scale) from an environment using random rollouts.
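As a rough illustration of what gathering loc and scale from random rollouts involves, the sketch below computes a per-dimension mean and standard deviation over synthetic observations. The random_rollout helper is a hypothetical stand-in for stepping a real environment with random actions, and reading loc and scale as mean and standard deviation is an assumption of this sketch, not a statement about the library's implementation.

```python
import math
import random


def random_rollout(num_steps, obs_dim, rng):
    """Hypothetical stand-in for observations collected under a random policy."""
    return [[rng.uniform(-1.0, 1.0) for _ in range(obs_dim)] for _ in range(num_steps)]


def loc_scale(observations):
    """Return the per-dimension mean (loc) and population std (scale)."""
    n = len(observations)
    dim = len(observations[0])
    loc = [sum(obs[i] for obs in observations) / n for i in range(dim)]
    scale = [
        math.sqrt(sum((obs[i] - loc[i]) ** 2 for obs in observations) / n)
        for i in range(dim)
    ]
    return loc, scale


rng = random.Random(0)
observations = random_rollout(num_steps=1000, obs_dim=3, rng=rng)
loc, scale = loc_scale(observations)

# Statistics like these are typically used to normalize observations.
normalized = [
    [(obs[i] - loc[i]) / scale[i] for i in range(3)] for obs in observations
]
```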
