# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from itertools import chain
from typing import Any, Callable, cast, Dict, Set, Tuple, Type
import torch
import torch.distributed as dist
from packaging import version
from torch import nn
from torch.distributed._tensor import distribute_tensor, DTensor
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
_CHECKPOINT_WRAPPED_MODULE,
)
from torch.distributed.checkpoint.state_dict import _init_optim_state
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.optim import Optimizer
from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4
from torchtune import modules
from torchtune.modules.peft.lora import (
_lora_a_init_params,
_lora_b_init_params,
LoRALinear,
)
from torchtune.utils._device import get_device
from torchtune.utils.logging import get_logger
_log: logging.Logger = get_logger()
FSDPPolicyType: Type = Callable[[nn.Module, bool, int], bool]
FSDPPolicyType.__doc__ = """
A datatype for a function that can be used as an FSDP wrapping policy.
In particular, this type denotes a function that can accept an nn.Module, a boolean flag, and an integer
and return a boolean indicating whether the module should be wrapped with FSDP. Objects of this type can
be directly passed into PyTorch FSDP's ``auto_wrap_policy`` argument to specify how FSDP wraps submodules.
The below function serves as an example of creating and returning a function that obeys the contract of
``FSDPPolicyType``::

    def get_fsdp_policy(modules_to_wrap: Set[Type]):

        def my_fsdp_policy(module: nn.Module, recurse: bool, min_num_params: int, modules_to_wrap: Set[Type]) -> bool:
            if recurse:
                return True
            # Wrap layers whose type is in ``modules_to_wrap`` or layers with more than ``min_num_params`` parameters
            return isinstance(module, tuple(modules_to_wrap)) or sum(p.numel() for p in module.parameters()) > min_num_params

        return functools.partial(my_fsdp_policy, modules_to_wrap=modules_to_wrap)

Please see documentation of ``auto_wrap_policy`` at https://pytorch.org/docs/stable/fsdp.html for additional details.
"""
_valid_distributed_single_node_nnodes = ["1:1", "1"]
def _get_sharding_strategy(strategy: str) -> ShardingStrategy:
"""Helper function to convert sharding strategy strings to ShardingStrategy enum."""
return getattr(ShardingStrategy, strategy)
def is_distributed() -> bool:
"""Check if all environment variables required to initialize torch.distributed are set
and distributed is properly installed. This indicates a distributed run.
https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization
Checks the following conditions:
* torch.distributed is available
* master port and master address environment variables are set
* world size is at least 1
* rank environment variable is set
Returns:
bool: True if all of the above conditions hold, False otherwise.
"""
port = os.environ.get("MASTER_PORT", "")
addr = os.environ.get("MASTER_ADDR", "")
size = int(os.environ.get("WORLD_SIZE", 1))
rank = int(os.environ.get("RANK", -1))
avlb = dist.is_available()
return bool(port and addr and size >= 1 and rank >= 0 and avlb)
def _broadcast_tensor(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
"""Broadcasts a tensor from a source to all other processes.
Args:
tensor (torch.Tensor): Tensor to broadcast.
src (int, optional): Source rank. Defaults to 0.
Returns:
torch.Tensor: Broadcasted tensor.
"""
if dist.is_available() and dist.is_initialized():
device = tensor.device
if dist.get_backend() == "nccl":
tensor = tensor.to(get_device("cuda"))
dist.broadcast(tensor, src=src, group=None)
return tensor.to(device)
else:
return tensor
def init_distributed(**kwargs: Dict[str, Any]) -> bool:
"""Initialize process group required for ``torch.distributed``.
Args:
**kwargs (Dict): Additional arguments to pass to torch.distributed.init_process_group.
Returns:
bool: True if torch.distributed was initialized, False if the environment does not indicate a distributed run.
Raises:
RuntimeError: If torch.distributed is already initialized.
"""
if is_distributed():
if dist.is_initialized():
raise RuntimeError("torch.distributed already initialized.")
dist.init_process_group(**kwargs)
return True
else:
return False
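# Illustrative sketch (not part of the original module): a minimal entrypoint showing how
# ``is_distributed`` and ``init_distributed`` might be used together under ``torchrun``.
# The "nccl" backend choice and the ``_example_``-prefixed name are assumptions made for
# illustration only.
def _example_init_distributed() -> None:
    # Under torchrun, MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE are set, so
    # is_distributed() returns True and we can initialize an NCCL process group.
    if is_distributed():
        init_distributed(backend="nccl")
        world_size, rank = get_world_size_and_rank()
        _log.info(f"Initialized process group: rank {rank} of {world_size}")
    else:
        _log.info("Not a distributed run; skipping process group init.")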
def set_torch_num_threads() -> None:
"""
Sets the number of threads used by torch to utilize all physical CPU
cores for intra-op parallelism. Currently, this function sets num_threads
to be the number of physical CPU cores divided by the number of GPUs as we
use one process per GPU, and this avoids CPU oversubscription. Note that this is
currently a rough approximation, and doesn't take into account environments where
things like CPU affinity is set.
"""
num_threads = os.cpu_count() // (
torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
)
torch.set_num_threads(num_threads)
_log.info(f"Set intra op parallelism no. of threads to {num_threads}")
def get_world_size_and_rank() -> Tuple[int, int]:
"""Function that gets the current world size (aka total number
of ranks) and rank number of the current process in the default process group.
Returns:
Tuple[int, int]: world size, rank
"""
if dist.is_available() and dist.is_initialized():
return torch.distributed.get_world_size(), torch.distributed.get_rank()
else:
return 1, 0
def validate_no_params_on_meta_device(model: nn.Module) -> None:
"""
Utility to validate that model has no params or buffers on meta device.
If a meta param or buffer is found, an error indicating the param name will
be raised.
Args:
model (nn.Module): model to check for meta params
Raises:
RuntimeError: If meta params or buffers exist in model
"""
for n, p in chain(model.named_parameters(), model.named_buffers()):
if p.is_meta:
raise RuntimeError(f"Unexpected param or buffer {n} on meta device.")
def contains_fsdp(model: nn.Module) -> bool:
"""
Checks if the model contains FSDP.
Args:
model (nn.Module): Model to check.
Returns:
bool: True if the model contains FSDP, False otherwise.
"""
return any(
isinstance(m, torch.distributed.fsdp.FullyShardedDataParallel)
for m in model.modules()
)
def _dummy_reset_params(x: nn.Module) -> None:
"""
Dummy method for patching no-op reset_parameters() when using
FSDP with meta device.
"""
return
def prepare_model_for_fsdp_with_meta_device(model: nn.Module) -> nn.Module:
"""
Dynamically define reset_parameters on every submodule of the model. For LoRA models,
ensure that the FSDP contract of reset_parameters only modifying a module's directly-owned
parameters is satisfied. More details here: https://github.com/pytorch/pytorch/issues/104187.
Args:
model (nn.Module): model to prepare for usage with FSDP and the meta device.
Returns:
nn.Module: Model with reset_parameters defined on every submodule.
In the case of a LoRA model, we override the default reset_parameters of nn.Linear.
Raises:
RuntimeError: if model contains submodule with non-callable attribute reset_parameters
"""
for k, v in model.named_modules():
# If the module does not have reset_parameters defined, we define
# a no-op reset_parameters method to satisfy FSDP's contract.
reset_params = getattr(v, "reset_parameters", None)
if reset_params is not None and not callable(reset_params):
raise RuntimeError(
f"Cannot override existing reset_parameters variable for FSDP init in {k}"
)
if reset_params is None:
v.reset_parameters = _dummy_reset_params.__get__(v)
# This will define reset_parameters for LoRA weight initialization
# directly on any LoRALinear submodules lora_a and lora_b.
if isinstance(v, LoRALinear):
v.lora_a.reset_parameters = _lora_a_init_params.__get__(v.lora_a)
v.lora_b.reset_parameters = _lora_b_init_params.__get__(v.lora_b)
return model
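# Illustrative sketch, assuming a toy ``nn.Sequential`` stands in for a real torchtune
# model: construct the model on the meta device, then patch ``reset_parameters`` on every
# submodule so FSDP's deferred initialization can materialize parameters later.
def _example_prepare_meta_model() -> nn.Module:
    with torch.device("meta"):
        model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
    # Every submodule now has a callable reset_parameters(), satisfying FSDP's contract
    # when parameters are later moved off the meta device.
    return prepare_model_for_fsdp_with_meta_device(model)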
def lora_fsdp_wrap_policy(modules_to_wrap: Set[Type]) -> FSDPPolicyType:
"""
A default policy for wrapping models trained with LoRA using FSDP.
FSDP's default behavior is to allocate gradients at the level of FSDP-wrapped modules.
This means that if any parameter in a given FSDP-wrapped module requires gradients, then memory will be
allocated for gradients for the entire module.
In the case of LoRA, where only LoRA A and B matrices are trainable, this means that
we need to wrap LoRA A and B submodules in their own FSDP units to
maximize memory savings. After this is done, the model will also be hierarchically wrapped
based on nn.Module types specified in ``modules_to_wrap``. This function assumes that
(a) LoRA's A and B matrices are the only trainable weights in the entire model, and
(b) we have already set ``requires_grad = True`` on LoRA params.
Args:
modules_to_wrap (Set[Type]): nn.Module types to recursively wrap
Returns:
FSDPPolicyType: Wrapping policy that can be passed into ``FullyShardedDataParallel``. Please see
documentation for :const:`~torchtune.utils.FSDPPolicyType` for additional details.
"""
def lora_wrap_fsdp(module: nn.Module, recurse: bool, **kwargs):
if recurse:
return True
# Assumes lora_a and lora_b are nn.Linears that are the
# only trainable modules in the entire network. Wraps
# these in separate FSDP units to work around FSDP allocating
# extra gradient memory when wrapped with other modules.
if hasattr(module, "weight") and module.weight.requires_grad:
return True
return isinstance(module, tuple(modules_to_wrap))
return lora_wrap_fsdp
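# Illustrative sketch of passing the LoRA policy to FSDP. The wrapped layer type
# (``modules.TransformerDecoderLayer``) and the helper name are assumptions for
# illustration; in practice the model and layer types come from the recipe config.
def _example_wrap_lora_model(model: nn.Module, device: torch.device) -> nn.Module:
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    policy = lora_fsdp_wrap_policy(modules_to_wrap={modules.TransformerDecoderLayer})
    # Each LoRA A/B linear and each decoder layer becomes its own FSDP unit.
    return FSDP(model, auto_wrap_policy=policy, device_id=device)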
def load_from_full_model_state_dict(
model: "FSDPModule",
full_sd: Dict[str, Any],
device: torch.device,
is_rank_zero: bool,
):
"""
Converting full state dict into a sharded state dict
and loading it into FSDP model
- 'full' means plain tensor
- 'sharded' means `DTensor` where reach rank has a shard of the plain tensor
- `is_rank_zero` matters if only rank 0 pass in non-empty `full_sd` and
we need to broadcast from rank 0
"""
meta_sharded_sd = model.state_dict()
sharded_sd = {}
for param_name, full_tensor in full_sd.items():
sharded_meta_param = meta_sharded_sd.get(param_name)
full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
if isinstance(sharded_meta_param._local_tensor, NF4Tensor):
full_tensor = to_nf4(full_tensor)
# replicating logic from `_fsdp_param.py`'s `_init_sharded_param`;
# otherwise `distribute_tensor(DTensor(local=NF4))`
# requires dispatching `c10d.scatter_`.
# The long-term solution is `swap_tensor`.
mesh = sharded_meta_param.device_mesh
if mesh.ndim > 1:
raise NotImplementedError(f"only support 1D FSDP but got {mesh.ndim=}")
shard_mesh_dim = 0
shard_world_size = mesh.size(shard_mesh_dim)
shard_rank = cast(
torch.distributed.ProcessGroup, mesh.get_group(shard_mesh_dim)
).rank()
chunk = list(torch.chunk(full_tensor, shard_world_size, dim=0))[shard_rank]
sharded_param = full_tensor.new_zeros(chunk.size())
sharded_param[: chunk.size(0)].copy_(chunk)
# BC-breaking change to DTensor API in https://github.com/pytorch/pytorch/pull/128112
# TODO: change to from_local API (need to add view support for NF4)
if version.parse(torch.__version__) >= version.parse("2.4.0.dev20240606"):
sharded_tensor = DTensor(
local_tensor=sharded_param,
spec=DTensorSpec(
mesh=sharded_meta_param.device_mesh,
placements=sharded_meta_param.placements,
tensor_meta=TensorMeta(
shape=sharded_meta_param.size(),
dtype=sharded_meta_param.dtype,
stride=sharded_meta_param.stride(),
),
),
requires_grad=sharded_meta_param.requires_grad,
)
else:
sharded_tensor = DTensor(
sharded_param,
sharded_meta_param.device_mesh,
sharded_meta_param.placements,
shape=sharded_meta_param.size(),
dtype=sharded_meta_param.dtype,
requires_grad=sharded_meta_param.requires_grad,
stride=sharded_meta_param.stride(),
)
else:
sharded_tensor = distribute_tensor(
full_tensor,
sharded_meta_param.device_mesh,
sharded_meta_param.placements,
)
sharded_sd[param_name] = nn.Parameter(sharded_tensor)
# choose `assign=True` since we cannot call `copy_` on meta tensor
return model.load_state_dict(sharded_sd, strict=False, assign=True)
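# Illustrative sketch (the checkpoint path and helper name are hypothetical) of loading a
# full, plain-tensor checkpoint into an FSDP2 (``fully_shard``) model: each rank reads the
# full state dict and the helper above chunks every tensor into the DTensor shard owned by
# that rank before loading with ``assign=True``.
def _example_load_full_checkpoint(model: "FSDPModule", device: torch.device) -> None:
    full_sd = torch.load("/tmp/full_model.pt", map_location="cpu")
    _, rank = get_world_size_and_rank()
    load_from_full_model_state_dict(model, full_sd, device, is_rank_zero=(rank == 0))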
def get_full_model_state_dict(
model: "FSDPModule",
is_rank_zero: bool,
) -> Dict[str, Any]:
"""
Converting sharded state dict into a full state dict on cpu
Returning non-empty result on rank0 to avoid peaking cpu memory
"""
sharded_sd = model.state_dict()
cpu_state_dict = {}
has_nf4 = any(
isinstance(param._local_tensor, NF4Tensor) for param in model.parameters()
)
if has_nf4:
from torch.distributed._composable.fsdp.fully_shard import FSDPModule
# Iterating from lower modules to higher:
# unshard LoRA adapters before unsharding the transformer blocks.
for module_name, module in reversed(list(model.named_modules())):
if not isinstance(module, FSDPModule):
continue
module.unshard(async_op=False)
if is_rank_zero:
module_name = module_name.replace(f".{_CHECKPOINT_WRAPPED_MODULE}", "")
for local_fqn, param in module.named_parameters():
local_fqn = local_fqn.replace(f".{_CHECKPOINT_WRAPPED_MODULE}", "")
if len(module_name) > 0:
full_fqn = module_name + "." + local_fqn
else:
full_fqn = local_fqn
if full_fqn in cpu_state_dict:
# We iterate over every param in every module bottom-up.
# When a lower TransformerBlock gets unsharded,
# we insert (full_fqn, full_tensor) into cpu_state_dict.
# When the higher-level Transformer gets unsharded, we avoid updating
# params from the lower TransformerBlock again; instead, we only update
# params such as tok_embeddings that belong to the Transformer itself.
continue
if isinstance(param, NF4Tensor):
# upcasting NF4 to original dtype
param = param.to(param.dtype)
if isinstance(param, DTensor):
raise AssertionError(
f"Internal error: expect unsharded {full_fqn} in plain torch.Tensor but got DTensor."
" Might be a bug in get_full_model_state_dict"
)
cpu_state_dict[full_fqn] = param.cpu()
module.reshard()
else:
for param_name, sharded_param in sharded_sd.items():
full_param = sharded_param.full_tensor()
if is_rank_zero:
cpu_state_dict[param_name] = full_param.cpu()
else:
del full_param
return cpu_state_dict
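# Illustrative sketch (the output path and helper name are hypothetical) of saving a full
# checkpoint from a sharded model: every rank participates in the unshard/gather, but only
# rank 0 receives a non-empty CPU state dict and writes it to disk.
def _example_save_full_checkpoint(model: "FSDPModule") -> None:
    _, rank = get_world_size_and_rank()
    full_sd = get_full_model_state_dict(model, is_rank_zero=(rank == 0))
    if rank == 0:
        torch.save(full_sd, "/tmp/full_model.pt")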
def get_full_optimizer_state_dict(
opt: Optimizer,
is_rank_zero: bool,
) -> Dict[str, Any]:
"""
Converting optimizer state from sharded to full
For example, "exp_avg" in AdamW is `DTensor`,
"exp_avg.full_tensor()" converts it to plain tensor on rank 0
Returning non-empty cpu state dict on rank 0
"""
sharded_sd = opt.state_dict()
sharded_state = sharded_sd["state"]
full_state = {}
for group_id, sharded_group in sharded_state.items():
group_state = {}
for attr, sharded_tensor in sharded_group.items():
if isinstance(sharded_tensor, DTensor):
# "exp_avg" in AdamW is `DTensor`
full_tensor = sharded_tensor.full_tensor()
else:
# "step" in AdamW is plain tensor
full_tensor = sharded_tensor
if is_rank_zero:
group_state[attr] = full_tensor.cpu()
else:
del full_tensor
if is_rank_zero:
full_state[group_id] = group_state
else:
del group_state
if is_rank_zero:
return {
"param_groups": sharded_sd["param_groups"],
"state": full_state,
}
else:
return {}
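# Illustrative sketch (the output path and helper name are hypothetical) of saving the full
# optimizer state alongside a model checkpoint; as with the model state dict, only rank 0
# gets a non-empty result and writes the file.
def _example_save_full_optimizer_state(opt: Optimizer) -> None:
    _, rank = get_world_size_and_rank()
    full_opt_sd = get_full_optimizer_state_dict(opt, is_rank_zero=(rank == 0))
    if rank == 0:
        torch.save(full_opt_sd, "/tmp/full_optimizer.pt")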
def load_from_full_optimizer_state_dict(
opt: Optimizer,
full_sd: Dict[str, Any],
device: torch.device,
) -> Dict[str, Any]:
"""
Converting full optimizer state to sharded state dict
and loading it into optimizer
"""
PARAMS = "params" # noqa: N806
_init_optim_state(opt)
param_groups = opt.state_dict()["param_groups"]
state = opt.state_dict()["state"]
full_param_groups = full_sd["param_groups"]
full_state = full_sd["state"]
for param_group, full_param_group in zip(param_groups, full_param_groups):
for key, value in full_param_group.items():
if key == PARAMS:
continue
param_group[key] = value
for pid, full_pid in zip(param_group[PARAMS], full_param_group[PARAMS]):
if pid not in state:
continue
param_state = state[pid]
full_param_state = full_state[full_pid]
for attr, full_tensor in full_param_state.items():
sharded_tensor = param_state[attr]
if isinstance(sharded_tensor, DTensor):
# exp_avg is DTensor
param_state[attr] = distribute_tensor(
full_tensor,
sharded_tensor.device_mesh,
sharded_tensor.placements,
)
else:
# step is plain tensor
param_state[attr] = full_tensor
opt.load_state_dict(
{
"param_groups": param_groups,
"state": state,
}
)
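# Illustrative sketch (the checkpoint path and helper name are hypothetical) of resuming
# training: each rank loads the full optimizer state dict, and the helper above re-shards
# every DTensor attribute (e.g. ``exp_avg``) onto the local device mesh.
def _example_resume_optimizer(opt: Optimizer, device: torch.device) -> None:
    full_opt_sd = torch.load("/tmp/full_optimizer.pt", map_location="cpu")
    load_from_full_optimizer_state_dict(opt, full_opt_sd, device)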
def get_full_finetune_fsdp_wrap_policy(
memory_efficient_fsdp_wrap: bool, modules_to_wrap: Set[Type]
) -> FSDPPolicyType:
"""
Retrieves an FSDP wrapping policy based on the specified flags ``memory_efficient_fsdp_wrap`` and
``modules_to_wrap``. Specifically, if ``memory_efficient_fsdp_wrap`` is set to ``True``, the returned
policy will wrap the model's token embedding and output projection in addition to the modules specified
to maximize memory savings.
Args:
memory_efficient_fsdp_wrap (bool): If ``True``, will also wrap embedding and output projection layers with FSDP.
modules_to_wrap (Set[Type]): Set of module types to wrap.
Note:
``memory_efficient_fsdp_wrap`` memory improvements have currently only been verified on llama3 workloads
where they provide ~15% memory improvement (when used alongside AC memory efficient wrapping). Other workloads
have not been verified and may not see the same improvements.
Returns:
FSDPPolicyType: Wrapping policy that can be passed into ``FullyShardedDataParallel`` as the ``auto_wrap_policy``
argument. Please see documentation for :const:`~torchtune.utils.FSDPPolicyType` for additional details.
"""
if memory_efficient_fsdp_wrap:
return _memory_efficient_wrap_policy(modules_to_wrap=modules_to_wrap)
else:
return ModuleWrapPolicy(modules_to_wrap)
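# Illustrative sketch of using the full-finetune policy. The wrapped layer type
# (``modules.TransformerDecoderLayer``) and the helper name are assumptions; enabling
# ``memory_efficient_fsdp_wrap`` additionally puts the token embedding and output
# projection into their own FSDP units.
def _example_wrap_full_finetune_model(model: nn.Module, device: torch.device) -> nn.Module:
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    policy = get_full_finetune_fsdp_wrap_policy(
        memory_efficient_fsdp_wrap=True,
        modules_to_wrap={modules.TransformerDecoderLayer},
    )
    return FSDP(model, auto_wrap_policy=policy, device_id=device)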
def _memory_efficient_wrap_policy(modules_to_wrap: Set[Type]) -> FSDPPolicyType:
"""
A default policy for memory efficient wrapping for full finetuning using FSDP. Specifically,
this will wrap the model's token embedding and output projection into their own FSDP units to
maximize memory savings. This helps especially if these layers are particularly large,
such as due to a large embedding size.
After this is done, the model will also be hierarchically wrapped
based on the nn.Module types specified in ``modules_to_wrap``. This function assumes that the
input model has an attribute ``output`` that is an nn.Linear serving as the model's output projection.
Args:
modules_to_wrap (Set[Type]): nn.Module types to recursively wrap
Returns:
FSDPPolicyType: Wrapping policy that can be passed into ``FullyShardedDataParallel``.
"""
modules_to_wrap.add(torch.nn.Embedding)
def llama3_wrap(module: nn.Module, recurse: bool, **kwargs):
# Label that output_proj should be wrapped individually.
if isinstance(module, modules.TransformerDecoder):
module.output._wrap = True
if recurse:
return True
# Wrap output_proj individually.
if getattr(module, "_wrap", False):
return True
return isinstance(module, tuple(modules_to_wrap))
return llama3_wrap