OfflineToOnlineTrainer
- class torchrl.trainers.algorithms.OfflineToOnlineTrainer(*args, **kwargs)
A SAC trainer for the offline-pretrain -> online-finetune transition.
See also
OfflineToOnlineTrainerConfig for the Hydra configuration counterpart.

Builds on SACTrainer, swapping the plain replay buffer for an OfflineToOnlineReplayBuffer. Each collected batch is routed to the online buffer, while optimization samples a mixed batch whose offline fraction is linearly annealed to zero over anneal_frames frames. This warm-starts the policy on offline data and smoothly hands it over to its own online experience. All other SAC behaviour (target-net updates, weight sync, logging) is inherited.
- Parameters:
collector (BaseCollector) – the data collector for online interactions.
total_frames (int) – total number of frames to collect.
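frame_skip (int) – frame skip used by the environment, i.e. the number of frames that elapse per policy action; lets the trainer count collected frames correctly.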
optim_steps_per_batch (int) – optimization steps per collected batch.
loss_module (LossModule) – the SAC loss module.
replay_buffer (OfflineToOnlineReplayBuffer) – the offline-to-online buffer.
- Keyword Arguments:
anneal_frames (int, optional) – number of frames over which offline_fraction decays to 0. Defaults to total_frames; pass <= 0 to keep the fraction fixed.
batch_size (int, optional) – replay-buffer sampling batch size.
See SACTrainer for the remaining keyword arguments.

Note
Experimental/prototype feature; the API may change.
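The annealing schedule and mixed-batch sampling described for this class can be sketched in plain Python. This is a toy model under stated assumptions, not torchrl's implementation: `offline_fraction`, `split_batch`, and `ToyMixedBuffer` are hypothetical names, the starting fraction is passed in as a hypothetical `initial_fraction` argument (this page does not state its default), and sampling is done with replacement for simplicity.

```python
from __future__ import annotations

import random


def offline_fraction(frames_collected: int, anneal_frames: int, initial_fraction: float) -> float:
    """Hypothetical helper: linearly decay the offline fraction to 0 over anneal_frames."""
    if anneal_frames <= 0:
        # Non-positive anneal_frames keeps the fraction fixed.
        return initial_fraction
    progress = min(frames_collected / anneal_frames, 1.0)
    return initial_fraction * (1.0 - progress)


def split_batch(batch_size: int, fraction: float) -> tuple[int, int]:
    """Hypothetical helper: return (n_offline, n_online) counts for one mixed batch."""
    n_offline = round(batch_size * fraction)
    return n_offline, batch_size - n_offline


class ToyMixedBuffer:
    """Hypothetical two-part buffer: fixed offline data plus growing online data."""

    def __init__(self, offline_data: list, seed: int = 0) -> None:
        self.offline = list(offline_data)
        self.online: list = []
        self._rng = random.Random(seed)

    def extend(self, batch: list) -> None:
        # Newly collected transitions are always routed to the online part.
        self.online.extend(batch)

    def sample(self, batch_size: int, fraction: float) -> list:
        # Assumes each part is non-empty whenever its share of the batch is positive.
        n_offline, n_online = split_batch(batch_size, fraction)
        return self._rng.choices(self.offline, k=n_offline) + self._rng.choices(self.online, k=n_online)


buf = ToyMixedBuffer([("offline", i) for i in range(10)])
buf.extend([("online", i) for i in range(5)])
for frames in (0, 50_000, 100_000):
    frac = offline_fraction(frames, anneal_frames=100_000, initial_fraction=0.5)
    print(frames, frac, split_batch(256, frac))
```

With a starting fraction of 0.5 and anneal_frames=100_000, the offline share of a 256-sample batch falls from 128 to 64 at the halfway point and reaches 0 once 100_000 frames have been collected.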
- load_from_file(file: str | Path, **kwargs) → Trainer
Loads a file and its state dict into the trainer.
Keyword arguments are passed to the load() function. They are ignored when CKPT_BACKEND=memmap.

Note
When CKPT_BACKEND=torch, weights_only=True is set by default for safer deserialization. Pass weights_only=False explicitly only if you have custom (non-stdlib) objects in your state dict.
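The weights_only behaviour in the note corresponds to the flag of the same name on torch.load. The snippet below is a minimal sketch in plain PyTorch rather than a call into the trainer: it round-trips a state dict containing only tensors and stdlib values through an in-memory buffer and loads it back with weights_only=True. The keys are illustrative and are not the trainer's actual checkpoint layout.

```python
import io

import torch

# Illustrative state dict: tensors plus plain Python scalars and dicts only.
state = {
    "step": 3,
    "collected_frames": 1_000,
    "weights": {"linear.weight": torch.ones(2, 2)},
}

buffer = io.BytesIO()
torch.save(state, buffer)
buffer.seek(0)

# weights_only=True restricts unpickling to tensors and an allowlist of
# primitive types, so loading an untrusted file cannot execute arbitrary code.
restored = torch.load(buffer, weights_only=True)
print(restored["step"], restored["weights"]["linear.weight"].sum().item())
```

A state dict holding instances of your own classes would instead need weights_only=False, which should be reserved for files you trust.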
- request_stop(reason: str | None = None) → None
Signal that training should stop at the next loop boundary.
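"Stop at the next loop boundary" describes a cooperative stop: the call only records the request, and the training loop checks it before starting further work. The sketch below illustrates that pattern with a hypothetical ToyTrainer class and on_iter callback; it is not torchrl's Trainer, and the real method may record or log the reason differently.

```python
from __future__ import annotations

from typing import Callable


class ToyTrainer:
    """Hypothetical trainer illustrating a cooperative stop flag."""

    def __init__(self, total_iters: int, on_iter: Callable[[ToyTrainer], None] | None = None) -> None:
        self.total_iters = total_iters
        self.on_iter = on_iter
        self.completed_iters = 0
        self.stop_reason: str | None = None
        self._stop_requested = False

    def request_stop(self, reason: str | None = None) -> None:
        # Only records the request; the current iteration is not interrupted.
        self._stop_requested = True
        self.stop_reason = reason

    def train(self) -> None:
        for _ in range(self.total_iters):
            # Loop boundary: honour a pending stop request before doing more work.
            if self._stop_requested:
                break
            self.completed_iters += 1  # stand-in for one collect + optimize step
            if self.on_iter is not None:
                self.on_iter(self)


def stop_after_three(trainer: ToyTrainer) -> None:
    if trainer.completed_iters == 3:
        trainer.request_stop("reward threshold reached")


trainer = ToyTrainer(total_iters=10, on_iter=stop_after_three)
trainer.train()
print(trainer.completed_iters, trainer.stop_reason)
```

Because the flag is only read between iterations, a stop requested mid-iteration never leaves an optimization step half-finished.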