Source code for torchtune.training.lr_schedulers

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Union

import torch
from torch.optim.lr_scheduler import LambdaLR
from torchtune.training.memory import OptimizerInBackwardWrapper


def get_cosine_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
) -> LambdaLR:
    """
    Create a learning rate schedule that linearly increases the learning rate from
    0.0 to lr over ``num_warmup_steps``, then decreases to 0.0 on a cosine schedule over
    the remaining ``num_training_steps - num_warmup_steps`` (assuming ``num_cycles`` = 0.5).

    This is based on the Hugging Face implementation
    https://github.com/huggingface/transformers/blob/v4.23.1/src/transformers/optimization.py#L104.

    Args:
        optimizer (torch.optim.Optimizer): The optimizer for which to
            schedule the learning rate.
        num_warmup_steps (int): The number of steps for the warmup phase.
        num_training_steps (int): The total number of training steps.
        num_cycles (float): The number of waves in the cosine schedule. Defaults to 0.5
            (decrease from the max value to 0 following a half-cosine).
        last_epoch (int): The index of the last epoch when resuming training. Defaults to -1.

    Returns:
        torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.
    """

    def lr_lambda(current_step: int) -> float:
        # linear warmup phase
        if current_step < num_warmup_steps:
            return current_step / max(1, num_warmup_steps)

        # cosine
        progress = (current_step - num_warmup_steps) / max(
            1, num_training_steps - num_warmup_steps
        )

        cosine_lr_multiple = 0.5 * (
            1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)
        )
        return max(0.0, cosine_lr_multiple)

    return LambdaLR(optimizer, lr_lambda, last_epoch)
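
As a usage illustration (not part of the module source), the sketch below wires the scheduler to a toy model and optimizer; the layer sizes, learning rate, and step counts are assumptions chosen only for demonstration.

# Illustrative usage sketch; not part of torchtune.training.lr_schedulers.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=1000
)
for step in range(1000):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    optimizer.step()
    scheduler.step()  # LR warms up linearly for 100 steps, then decays to 0 on a cosine
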
def get_lr(
    optimizer: Union[torch.optim.Optimizer, OptimizerInBackwardWrapper],
) -> float:
    """
    The full_finetune_distributed and full_finetune_single_device recipes assume that
    all param groups share the same learning rate. This helper validates that assumption
    and returns the common learning rate.

    Args:
        optimizer (Union[torch.optim.Optimizer, OptimizerInBackwardWrapper]): The
            optimizer, which may be a standard optimizer or an optimizer wrapper
            used with optimizer-in-backward.

    Returns:
        lr (float): The learning rate of the input optimizer.

    Raises:
        RuntimeError: If the learning rates of the input optimizer are not the same.
    """
    if isinstance(optimizer, OptimizerInBackwardWrapper):
        param_groups = []
        for param in optimizer.state_dict().values():
            param_groups.append(param["param_groups"][0])
    else:
        param_groups = optimizer.param_groups
    if len(param_groups) < 1:
        raise RuntimeError(
            f"Invalid optimizer param groups with len of: {len(param_groups)}"
        )

    # LR schedulers are the same across all param groups for full_finetune right now
    lr = param_groups[0]["lr"]
    for group in param_groups:
        if group["lr"] != lr:
            raise RuntimeError("LR is not the same across all param groups")
    return lr
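
As a usage illustration (not part of the module source), the sketch below calls get_lr on a plain optimizer with a single param group; the model and learning rate are assumptions for demonstration.

# Illustrative usage sketch; not part of torchtune.training.lr_schedulers.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
current_lr = get_lr(optimizer)  # returns 3e-4; raises RuntimeError if param groups disagree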
