
Source code for torchao.float8.fsdp_utils

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Any, List, Optional, Set, Tuple

import torch
import torch.nn as nn
import torch.utils._pytree as pytree
from torch._prims_common import suggest_memory_format

from torchao.float8.float8_scaling_utils import (
    hp_tensor_to_float8_dynamic,
)
from torchao.float8.float8_tensor import (
    Float8Tensor,
    GemmInputRole,
    LinearMMConfig,
    hp_tensor_and_scale_to_float8,
)
from torchao.float8.float8_utils import EPS


@torch.no_grad()
def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
    """
    Calculate scale dynamically for all float8 parameters.
    This should be run after the optimizer step. It performs a single all-reduce
    to compute the scales for all float8 weights.

    Example usage:
        model(input).sum().backward()
        optim.step()
        precompute_float8_dynamic_scale_for_fsdp(model)
    """
    from torch.distributed._tensor import DTensor

    from torchao.float8.float8_linear import Float8Linear

    float8_linears: List[Float8Linear] = [
        m
        for m in module.modules()
        if isinstance(m, Float8Linear)
        and isinstance(m.weight, DTensor)
        and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor)
    ]
    weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]
    target_dtypes: Set[torch.dtype] = {
        float8_linear.config.cast_config_weight.target_dtype
        for float8_linear in float8_linears
    }

    if not weights:
        return
    (target_dtype,) = target_dtypes

    # inf-norm is equivalent to max(abs(w))
    max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
    amax_tensor = torch.stack(max_weights)  # Partial
    # clamp is dispatched through DTensor
    # it will issue a single all-reduce
    amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
    # keep consistent with float8_utils.amax_to_scale
    # torch.compile and eager show different numerics for 1.0 / float32,
    # upcast to float64 to ensure same numerics between compile and eager
    origin_dtype = amax_tensor.dtype
    amax_tensor = amax_tensor.to(torch.float64)
    scale_tensor = torch.finfo(target_dtype).max / amax_tensor  # Replicate
    if origin_dtype is torch.float16:
        scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
    local_scale_tensor = scale_tensor.to_local().to(torch.float32)
    for i, float8_linear in enumerate(float8_linears):
        float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]
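

# An illustrative usage sketch (an assumption, not part of the module above):
# the docstring shows the required call order; this expands it into a full
# training step. `_example_train_step`, `model`, `optim`, and `batch` are
# hypothetical names, and the sketch assumes `model` was converted with
# torchao.float8's `convert_to_float8_training` and sharded with FSDP2's
# `fully_shard`, so Float8Linear weights are DTensors wrapping
# WeightWithDynamicFloat8CastTensor.
def _example_train_step(
    model: nn.Module, optim: torch.optim.Optimizer, batch: torch.Tensor
) -> None:
    loss = model(batch).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()
    # after the optimizer step, precompute next iteration's weight scales with
    # a single all-reduce instead of one all-reduce per Float8Linear
    precompute_float8_dynamic_scale_for_fsdp(model)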


# FSDP pads its local tensor on dim-0. The subclass should be preserved such
# that the padded local tensor (and any transformations like copying to GPU)
# is of the subclass as well.
_ops_to_preserve_subclass = {
    torch.ops.aten.empty_like.default,
    torch.ops.aten.new_zeros.default,
    torch.ops.aten.slice.Tensor,
    torch.ops.aten.copy_.default,
    torch.ops.aten.view.default,
    torch.ops.aten.as_strided.default,
    torch.ops.aten._to_copy.default,
    torch.ops.aten._pin_memory.default,
    torch.ops.aten.split.Tensor,
    torch.ops.aten.clone.default,
}


# How Tensor Parallel (TP) and FSDP2 work
#
# Initialization: apply TP first, then FSDP2
# nn.Linear(weight=torch.Tensor)
#      |
#      | apply float8 linear, `convert_to_float8_training`
#      |
# Float8Linear(weight=WeightWithDynamicFloat8CastTensor)
#      |
#      | apply tensor parallel, `parallelize_module` shards rowwise/colwise
#      |
# Float8Linear(weight=DTensor(local_tensor=WeightWithDynamicFloat8CastTensor,
#                             device_mesh=DeviceMesh([0, 1], mesh_dim_names=('tp',)),
#                             placements=(Shard(dim=0),)))
#      |
#      | apply FSDP2, `fully_shard` shards rowwise (dim=0)
#      |
# Float8Linear(weight=DTensor(local_tensor=WeightWithDynamicFloat8CastTensor,
#                             device_mesh=DeviceMesh([[0, 1], [2, 3]], mesh_dim_names=('dp', 'tp')),
#                             placements=(Shard(dim=0), Shard(dim=0))))
#
# Forward and backward: FSDP runs first, then TP
# Float8Linear(weight=DTensor(local_tensor=WeightWithDynamicFloat8CastTensor,
#                             device_mesh=DeviceMesh([[0, 1], [2, 3]], mesh_dim_names=('dp', 'tp')),
#                             placements=(Shard(dim=0), Shard(dim=0))))
#      |
#      | FSDP unshards parameters within dp mesh
#      |
# Float8Linear(weight=DTensor(local_tensor=WeightWithDynamicFloat8CastTensor,
#                             device_mesh=DeviceMesh([0, 1], mesh_dim_names=('tp',)),
#                             placements=(Shard(dim=0),)))
#      |
#      | TP compute with torch.mm(input, weight)
class WeightWithDynamicFloat8CastTensor(torch.Tensor):
    @staticmethod
    def __new__(
        cls,
        tensor: torch.Tensor,
        linear_mm_config: LinearMMConfig,
        dtype: torch.dtype,
        precomputed_scale: Optional[torch.Tensor] = None,
    ):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.size(),
            strides=tensor.stride(),
            storage_offset=tensor.storage_offset(),
            memory_format=suggest_memory_format(tensor),
            dtype=tensor.dtype,
            layout=tensor.layout,
            device=tensor.device,
            pin_memory=tensor.is_pinned(),
            requires_grad=tensor.requires_grad,
        )

    def __init__(
        self,
        tensor: torch.Tensor,
        linear_mm_config: LinearMMConfig,
        dtype: torch.dtype,
        precomputed_scale: Optional[torch.Tensor] = None,
    ):
        self._tensor = tensor
        self._linear_mm_config = linear_mm_config
        self._dtype = dtype
        # for dynamic scaling
        # `precompute_float8_dynamic_scale_for_fsdp` calculates scales
        # for all float8 parameters after optimizer step
        self._precomputed_scale = precomputed_scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        if func == torch.ops.aten.detach.default:
            return WeightWithDynamicFloat8CastTensor(
                args[0]._tensor, args[0]._linear_mm_config, args[0]._dtype
            )
        mm_config: Optional[LinearMMConfig] = None
        dtype: Optional[torch.dtype] = None

        def unwrap(t):
            nonlocal mm_config
            if mm_config is None:
                mm_config = t._linear_mm_config
            else:
                assert t._linear_mm_config == mm_config
            nonlocal dtype
            if dtype is None:
                dtype = t._dtype
            else:
                assert t._dtype == dtype
            return t._tensor

        args, kwargs = pytree.tree_map_only(
            WeightWithDynamicFloat8CastTensor, unwrap, (args, kwargs or {})
        )
        out = func(*args, **kwargs)
        if func not in _ops_to_preserve_subclass:
            return out
        return pytree.tree_map_only(
            torch.Tensor,
            lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config, dtype),
            out,
        )

    def __tensor_flatten__(self):
        tensors = ["_tensor"]
        if self._precomputed_scale:
            tensors.append("_precomputed_scale")
        return tensors, {"mm_config": self._linear_mm_config, "dtype": self._dtype}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
        return WeightWithDynamicFloat8CastTensor(
            inner_tensors["_tensor"],
            flatten_spec["mm_config"],
            flatten_spec["dtype"],
            getattr(inner_tensors, "_precomputed_scale", None),
        )

    def __repr__(self):
        return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, linear_mm_config={self._linear_mm_config}, dtype={self._dtype})"

    def fsdp_pre_all_gather(self, mesh):
        if self._precomputed_scale is not None:
            float8_tensor = hp_tensor_and_scale_to_float8(
                self._tensor,
                self._precomputed_scale,
                self._dtype,
                self._linear_mm_config,
                GemmInputRole.WEIGHT,
            )
        else:
            float8_tensor = hp_tensor_to_float8_dynamic(
                self._tensor,
                self._dtype,
                self._linear_mm_config,
                reduce_amax=True,
                gemm_input_role=GemmInputRole.WEIGHT,
                device_mesh=mesh,
            )
        return (float8_tensor._data,), (float8_tensor._scale,)

    def fsdp_post_all_gather(
        self,
        all_gather_outputs: Tuple[torch.Tensor, ...],
        metadata: Any,
        param_dtype: torch.dtype,
        *,
        out: Optional[torch.Tensor] = None,
    ):
        (data,) = all_gather_outputs
        (scale,) = metadata
        if out is not None:
            from torch.distributed._tensor import DTensor

            if isinstance(out, Float8Tensor):
                out._scale = scale
            elif isinstance(out, DTensor) and isinstance(
                out._local_tensor, Float8Tensor
            ):
                out._local_tensor._scale = scale
            else:
                raise RuntimeError(
                    f"out must be a Float8Tensor or DTensor(_local_tensor=Float8Tensor), but got {out}"
                )
            return
        return Float8Tensor(
            data,
            scale,
            param_dtype,
            self._linear_mm_config,
            gemm_input_role=GemmInputRole.WEIGHT,
        ), (data,)
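

# An illustrative sketch (an assumption, not part of the module above) of the
# initialization order described in the comment block before the class: float8
# conversion first, then tensor parallel, then FSDP2. The 2x2 ('dp', 'tp') mesh
# and the two-layer model are hypothetical, and the import path for
# `fully_shard` may differ across PyTorch versions.
def _example_compose_tp_and_fsdp2() -> nn.Module:
    from torch.distributed._composable.fsdp import fully_shard
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    from torchao.float8 import convert_to_float8_training

    mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
    model = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024))

    # nn.Linear -> Float8Linear(weight=WeightWithDynamicFloat8CastTensor)
    convert_to_float8_training(model)
    # TP wraps each local WeightWithDynamicFloat8CastTensor in a DTensor on the 'tp' mesh
    parallelize_module(
        model, mesh["tp"], {"0": ColwiseParallel(), "1": RowwiseParallel()}
    )
    # FSDP2 adds a dim-0 shard on the 'dp' mesh on top of the TP placement
    fully_shard(model, mesh=mesh["dp"])
    return model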
