Training Hooks

Hooks for customizing the training loop at various points.

BatchSubSampler(batch_size[, sub_traj_len, ...])

Data subsampler for online RL sota-implementations.

ClearCudaCache(interval)

Clears cuda cache at a given interval.
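Hooks like this one fire only once every `interval` calls. A minimal plain-Python sketch of that interval gating (the class and callback names here are hypothetical, not the torchrl API):

```python
class IntervalHook:
    """Run a side-effect callback once every `interval` invocations.

    Illustrative only: mirrors the gating used by hooks such as
    ClearCudaCache(interval); this class is not part of torchrl.
    """

    def __init__(self, interval: int, fn):
        self.interval = interval
        self.fn = fn  # e.g. torch.cuda.empty_cache in the real hook
        self.count = 0

    def __call__(self) -> bool:
        # Count every call; only act when the counter hits the interval.
        self.count += 1
        if self.count % self.interval == 0:
            self.fn()
            return True
        return False
```

In the real `ClearCudaCache` hook, the gated side effect would be emptying the CUDA allocator cache rather than an arbitrary callback.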

CountFramesLog(*args, **kwargs)

A frame counter hook.

LogScalar([key, logname, log_pbar, ...])

Generic scalar logger hook for any tensor values in the batch.

OptimizerHook(optimizer[, loss_components])

Add an optimizer for one or more loss components.

LogValidationReward(*, record_interval, ...)

Recorder hook for Trainer.

ReplayBufferTrainer(replay_buffer[, ...])

A hook that provides replay buffer storage and sampling for the trainer.

RewardNormalizer([decay, scale, eps, ...])

Reward normalizer hook.
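A reward normalizer of this kind typically tracks running statistics with an exponential decay. A minimal sketch, where the `decay`, `scale`, and `eps` parameter names echo the signature above but the update rule itself is an illustrative assumption, not the torchrl implementation:

```python
class RunningRewardNormalizer:
    """Normalize rewards with exponentially decayed running statistics.

    Sketch only: parameter names follow RewardNormalizer's signature,
    but this update rule is an assumption, not torchrl's code.
    """

    def __init__(self, decay: float = 0.999, scale: float = 1.0,
                 eps: float = 1e-8):
        self.decay = decay
        self.scale = scale
        self.eps = eps
        self.mean = 0.0
        self.var = 1.0

    def update(self, reward: float) -> None:
        # Exponential moving estimates of mean and variance.
        delta = reward - self.mean
        self.mean += (1.0 - self.decay) * delta
        self.var = self.decay * self.var + (1.0 - self.decay) * delta * delta

    def normalize(self, reward: float) -> float:
        # eps keeps the division stable when the variance is near zero.
        std = (self.var + self.eps) ** 0.5
        return self.scale * (reward - self.mean) / std
```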

SelectKeys(keys)

Selects keys in a TensorDict batch.

UpdateWeights(collector, update_weights_interval)

A hook that updates the data collector's policy weights at a given interval.

TargetNetUpdaterHook(target_params_updater)

A hook that updates the target network parameters.

UTDRHook(trainer)

Hook for logging Update-to-Data (UTD) ratio during async collection.
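All of the hooks above share one pattern: callbacks are registered at named points of the training loop and invoked there on each iteration, optionally transforming the batch as it passes through. A plain-Python sketch of that registry pattern (the point names, class, and methods are hypothetical, not the torchrl `Trainer` API):

```python
from collections import defaultdict


class HookRegistry:
    """Dispatch registered hooks at named points of a training loop.

    Illustrative sketch of the hook pattern; not the torchrl Trainer.
    """

    def __init__(self):
        self._hooks = defaultdict(list)

    def register(self, point: str, hook) -> None:
        self._hooks[point].append(hook)

    def run(self, point: str, batch):
        # A hook may transform the batch (e.g. key selection, subsampling);
        # hooks that return None (e.g. pure logging) leave it unchanged.
        for hook in self._hooks[point]:
            result = hook(batch)
            if result is not None:
                batch = result
        return batch


def train_step(registry: HookRegistry, batch):
    """One training iteration with hook points before and after the update."""
    batch = registry.run("pre_optim", batch)
    # ... optimizer step would go here ...
    return registry.run("post_optim", batch)
```

Under this pattern, a key-selection hook returns a reduced batch, while a logging hook observes the batch and returns nothing.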
