
Source code for torchtune.utils.precision

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
from typing import (
    ContextManager,
    Dict,
    Generator,
    Iterable,
    List,
    Optional,
    Tuple,
    Union,
)

import torch
import torch.nn as nn
from pkg_resources import packaging

from torch.cuda.amp import GradScaler
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

from torchtune.utils._device import _validate_device_from_env
from torchtune.utils.logging import get_logger

log = get_logger()

PRECISION_STR_TO_DTYPE: Dict[str, torch.dtype] = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
    "fp64": torch.float64,
}


def _set_float32_precision(precision: str = "high") -> None:
    """Sets the precision of float32 matrix multiplications and convolution operations.

    For more information, see the PyTorch docs:
    - https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
    - https://pytorch.org/docs/stable/backends.html#torch.backends.cudnn.allow_tf32

    Args:
        precision (str): The float32 matmul/convolution precision setting; one of ``"highest"``, ``"high"``, or ``"medium"``. Default: ``"high"``.
    """
    if not torch.cuda.is_available():  # Not relevant for non-CUDA devices
        return
    # set precision for matrix multiplications
    torch.set_float32_matmul_precision(precision)
    # set precision for convolution operations
    if precision == "highest":
        torch.backends.cudnn.allow_tf32 = False
    else:
        torch.backends.cudnn.allow_tf32 = True
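
As an illustrative usage sketch (not part of the module source; the call site is an assumption), a training recipe might call this helper once at startup:

# Hypothetical example: allow TF32 for float32 matmuls and cuDNN convolutions on Ampere+ GPUs.
from torchtune.utils.precision import _set_float32_precision

_set_float32_precision("high")     # TF32 enabled for matmuls and convolutions
_set_float32_precision("highest")  # full float32 accuracy, TF32 disabled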


def list_dtypes() -> List[str]:
    """
    Return a list of supported dtypes for finetuning.
    """
    return list(PRECISION_STR_TO_DTYPE.keys())

def verify_bf16_support():
    """Check that CUDA is available and that the CUDA version, NCCL, and the device all support bf16."""
    return (
        torch.cuda.is_available()
        and torch.version.cuda
        and torch.cuda.is_bf16_supported()
        and packaging.version.parse(torch.version.cuda).release >= (11, 0)
        and torch.distributed.is_nccl_available()
        and torch.cuda.nccl.version() >= (2, 10)
    )

def get_dtype(
    dtype: Optional[str] = None, device: Optional[torch.device] = None
) -> torch.dtype:
    """Get the torch.dtype corresponding to the given precision string.

    Note:
        If bf16 precision is requested with a CUDA device, we verify whether the device indeed
        supports bf16 kernels. If not, a ``RuntimeError`` is raised.

    Args:
        dtype (Optional[str]): The precision dtype. Default: ``None``, in which case we default
            to torch.float32.
        device (Optional[torch.device]): Device in use for training. Only CUDA and CPU devices
            are supported. If a CUDA device is passed in, additional checking is done to ensure
            that the device supports the requested precision. Default: ``None``, in which case
            a CUDA device is assumed.

    Raises:
        ValueError: if the precision isn't supported by the precision utils.
        RuntimeError: if bf16 precision is requested but not available on this hardware.

    Returns:
        torch.dtype: The corresponding torch.dtype.
    """
    # None defaults to float32
    if dtype is None:
        return torch.float32

    # Convert to torch.dtype
    torch_dtype = PRECISION_STR_TO_DTYPE.get(dtype, dtype)

    # dtype must be one of the supported precisions
    if torch_dtype not in PRECISION_STR_TO_DTYPE.values():
        raise ValueError(
            f"Dtype {torch_dtype} must be one of {', '.join(list_dtypes())} for finetuning."
        )

    # TODO (rohan-varma): prefer to use get_default_device() here to figure out whether user is
    # training on CPU or GPU, but it is not supported in versions of torch we test.
    if (
        torch_dtype == torch.bfloat16
        and device != torch.device("cpu")
        and not verify_bf16_support()
    ):
        raise RuntimeError(
            "bf16 precision was requested but not available on this hardware. "
            "Please use fp32 precision instead."
        )

    return torch_dtype
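
A minimal usage sketch (illustrative only, assuming the module path shown on this page):

import torch
from torchtune.utils.precision import get_dtype, list_dtypes

print(list_dtypes())   # ['fp16', 'bf16', 'fp32', 'fp64']
print(get_dtype())     # defaults to torch.float32
# On CPU the bf16 hardware check is skipped, so this simply returns torch.bfloat16.
print(get_dtype("bf16", device=torch.device("cpu")))
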
def get_gradient_scaler(fsdp: bool = False) -> Union[GradScaler, ShardedGradScaler]:
    """
    Returns a gradient scaler for mixed-precision training.

    Args:
        fsdp (bool): Whether FSDP is being used for training, in which case a
            shard-aware gradient scaler is returned.

    Returns:
        Union[GradScaler, ShardedGradScaler]: Gradient scaler object
    """
    return GradScaler(enabled=True) if not fsdp else ShardedGradScaler(enabled=True)


def get_autocast(dtype: torch.dtype, device: torch.device) -> ContextManager:
    """
    Returns the builtin ``torch.autocast`` context manager if the given dtype enables
    mixed precision training, and a null context otherwise.
    Reference: https://pytorch.org/docs/stable/amp.html#torch.autocast

    Args:
        dtype (torch.dtype): dtype used to determine if mixed precision training is used.
        device (torch.device): PyTorch device.

    Returns:
        ContextManager: Autocast manager object if using half precision, otherwise a null context.
    """
    if dtype in (torch.float16, torch.bfloat16):
        # Note: some devices do not support autocasting and will raise an error.
        _validate_device_from_env(device)
        return torch.autocast(
            device_type=device.type,
            dtype=dtype,
        )
    else:
        return contextlib.nullcontext()


@contextlib.contextmanager
def set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
    """
    Context manager that temporarily sets the default torch dtype to ``dtype`` and
    restores the previous default on exit.
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(old_dtype)


def validate_expected_param_dtype(
    named_params: Iterable[Tuple[str, nn.Parameter]], dtype: torch.dtype
) -> None:
    """
    Validates that all input parameters have the expected dtype.

    Args:
        named_params (Iterable[Tuple[str, nn.Parameter]]): Iterable of named parameters.
        dtype (torch.dtype): Expected dtype.

    Raises:
        ValueError: If any parameter has a different dtype than ``dtype``.
    """
    for name, param in named_params:
        if param.dtype != dtype:
            raise ValueError(
                f"Parameter {name} has dtype {param.dtype}, but expected {dtype}"
            )
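
For illustration, a minimal sketch of how these helpers compose (the toy model and import path are assumptions, not part of the source):

import torch
import torch.nn as nn
from torchtune.utils.precision import (
    get_autocast,
    set_default_dtype,
    validate_expected_param_dtype,
)

# Build parameters directly in bf16, then verify they landed in the expected dtype.
with set_default_dtype(torch.bfloat16):
    model = nn.Linear(16, 16)  # hypothetical toy model
validate_expected_param_dtype(model.named_parameters(), torch.bfloat16)

# get_autocast returns torch.autocast for half precision and a null context for fp32.
with get_autocast(torch.bfloat16, torch.device("cpu")):
    out = model(torch.randn(2, 16, dtype=torch.bfloat16))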
