Source code for torchtune.training._distributed
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from itertools import chain
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
from torch.distributed._tensor import distribute_tensor, DTensor
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
from torch.distributed.checkpoint.state_dict import _init_optim_state
from torch.distributed.fsdp import ShardingStrategy
from torch.optim import Optimizer
from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4
from torchtune.modules import TransformerDecoder
from torchtune.utils import get_logger
from torchtune.utils._device import get_device
_log: logging.Logger = get_logger()
_valid_distributed_single_node_nnodes = ["1:1", "1"]
def _get_sharding_strategy(strategy: str) -> ShardingStrategy:
"""Helper function to convert sharding strategy strings to ShardingStrategy enum."""
return getattr(ShardingStrategy, strategy)
def is_distributed() -> bool:
"""Check if all environment variables required to initialize torch.distributed are set
and distributed is properly installed. This indicates a distributed run.
https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization
Checks the following conditions:
* torch.distributed is available
* master port and master address environment variables are set
    * world size is >= 1
* rank environment variable is set
Returns:
bool: True if all of the above conditions hold, False otherwise.
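    Example:
        Illustrative only; these variables are normally set by a launcher such as
        ``torchrun`` rather than by hand, and the result assumes torch.distributed
        is available in the build.
        >>> os.environ["MASTER_ADDR"] = "localhost"
        >>> os.environ["MASTER_PORT"] = "29500"
        >>> os.environ["WORLD_SIZE"] = "2"
        >>> os.environ["RANK"] = "0"
        >>> is_distributed()
        True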
"""
port = os.environ.get("MASTER_PORT", "")
addr = os.environ.get("MASTER_ADDR", "")
size = int(os.environ.get("WORLD_SIZE", 1))
rank = int(os.environ.get("RANK", -1))
avlb = dist.is_available()
return bool(port and addr and size >= 1 and rank >= 0 and avlb)
def _broadcast_tensor(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
"""Broadcasts a tensor from a source to all other processes.
Args:
tensor (torch.Tensor): torch.Tensor to broadcast.
src (int, optional): Source rank. Defaults to 0.
Returns:
torch.Tensor: Broadcasted tensor.
"""
if dist.is_available() and dist.is_initialized():
device = tensor.device
if dist.get_backend() == "nccl":
tensor = tensor.to(get_device("cuda"))
dist.broadcast(tensor, src=src, group=None)
return tensor.to(device)
else:
return tensor
def init_distributed(**kwargs: Dict[str, Any]) -> bool:
"""Initialize process group required for ``torch.distributed``.
Args:
**kwargs (Dict[str, Any]): Additional arguments to pass to torch.distributed.init_process_group.
Returns:
bool: True if torch.distributed is initialized.
Raises:
RuntimeError: If torch.distributed is already initialized.
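    Example:
        A minimal sketch, assuming the process was launched with ``torchrun`` so the
        required environment variables are already set; ``backend`` is forwarded to
        ``torch.distributed.init_process_group``.
        >>> initialized = init_distributed(backend="nccl")
        >>> if not initialized:
        ...     print("Not a distributed run, continuing on a single device.")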
"""
if is_distributed():
if dist.is_initialized():
raise RuntimeError("torch.distributed already initialized.")
dist.init_process_group(**kwargs)
return True
else:
return False
def set_torch_num_threads() -> None:
"""
    Sets the number of threads used by torch for intra-op parallelism. Currently,
    this function sets num_threads to the number of CPU cores (as reported by
    ``os.cpu_count()``) divided by the number of GPUs, since we use one process per
    GPU; this avoids CPU oversubscription. Note that this is currently a rough
    approximation and doesn't take into account environments where things like
    CPU affinity are set.
"""
num_threads = os.cpu_count() // (
torch.cuda.device_count() if torch.cuda.is_available() else 1
)
torch.set_num_threads(num_threads)
    _log.info(f"Set intra-op parallelism thread count to {num_threads}")
def get_world_size_and_rank() -> Tuple[int, int]:
"""Function that gets the current world size (aka total number
of ranks) and rank number of the current process in the default process group.
Returns:
Tuple[int, int]: world size, rank
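    Example:
        On a two-process run launched with ``torchrun --nproc_per_node=2``, rank 1
        would see the following; a single-process run returns ``(1, 0)``.
        >>> get_world_size_and_rank()
        (2, 1)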
"""
if dist.is_available() and dist.is_initialized():
return torch.distributed.get_world_size(), torch.distributed.get_rank()
else:
return 1, 0
def validate_no_params_on_meta_device(model: nn.Module) -> None:
"""
Utility to validate that model has no params or buffers on meta device.
If a meta param or buffer is found, an error indicating the param name will
be raised.
Args:
model (nn.Module): model to check for meta params
Raises:
RuntimeError: If meta params or buffers exist in model
"""
for n, p in chain(model.named_parameters(), model.named_buffers()):
if p.is_meta:
raise RuntimeError(f"Unexpected param or buffer {n} on meta device.")
def load_from_full_model_state_dict(
model: "FSDPModule", # noqa
full_sd: Dict[str, Any],
device: torch.device,
is_rank_zero: bool,
strict: bool = False,
cpu_offload: bool = False,
):
"""
    Converts a full state dict into a sharded state dict and loads it into an FSDP model.
    - 'full' means plain tensor
    - 'sharded' means `DTensor` where each rank has a shard of the plain tensor
    - `is_rank_zero` matters when only rank 0 passes in a non-empty `full_sd`
       and we need to broadcast from rank 0
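    Example:
        A sketch of typical usage, assuming ``model`` has already been sharded with
        :func:`shard_model` and ``full_sd`` holds plain (unsharded) tensors; ``rank``
        is a placeholder for the caller's rank.
        >>> load_from_full_model_state_dict(
        ...     model, full_sd, device=torch.device("cuda"), is_rank_zero=(rank == 0)
        ... )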
"""
meta_sharded_sd = model.state_dict()
sharded_sd = {}
for param_name, full_tensor in full_sd.items():
sharded_meta_param = meta_sharded_sd.get(param_name)
full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
if hasattr(sharded_meta_param, "_local_tensor") and isinstance(
sharded_meta_param._local_tensor, NF4Tensor
):
block_size = sharded_meta_param._local_tensor.block_size
scaler_block_size = sharded_meta_param._local_tensor.scaler_block_size
full_tensor = to_nf4(
full_tensor, block_size=block_size, scaler_block_size=scaler_block_size
)
            # Replicating logic from `_fsdp_param.py`'s `_init_sharded_param`:
            # otherwise `distribute_tensor(DTensor(local=NF4))` requires
            # dispatching `c10d.scatter_`. The long-term solution is `swap_tensor`.
mesh = sharded_meta_param.device_mesh
if mesh.ndim > 1:
raise NotImplementedError(f"only support 1D FSDP but got {mesh.ndim=}")
shard_mesh_dim = 0
shard_world_size = mesh.size(shard_mesh_dim)
shard_rank = cast(
torch.distributed.ProcessGroup, mesh.get_group(shard_mesh_dim)
).rank()
chunk = list(torch.chunk(full_tensor, shard_world_size, dim=0))[shard_rank]
sharded_param = full_tensor.new_zeros(chunk.size())
sharded_param[: chunk.size(0)].copy_(chunk)
# TODO: change to from_local API (need to add view support for NF4)
sharded_tensor = DTensor(
local_tensor=sharded_param,
spec=DTensorSpec(
mesh=sharded_meta_param.device_mesh,
placements=sharded_meta_param.placements,
tensor_meta=TensorMeta(
shape=sharded_meta_param.size(),
dtype=sharded_meta_param.dtype,
stride=sharded_meta_param.stride(),
),
),
requires_grad=sharded_meta_param.requires_grad,
)
elif not hasattr(sharded_meta_param, "device_mesh"):
# In cases where parts of the model aren't sharded, some parameters will be plain tensors
sharded_tensor = full_tensor
else:
sharded_tensor = distribute_tensor(
full_tensor,
sharded_meta_param.device_mesh,
sharded_meta_param.placements,
)
if cpu_offload:
sharded_tensor = sharded_tensor.cpu()
sharded_sd[param_name] = nn.Parameter(sharded_tensor)
# choose `assign=True` since we cannot call `copy_` on meta tensor
return model.load_state_dict(sharded_sd, strict=strict, assign=True)
def _gather_nf4_tensor(sharded_param: nn.Parameter) -> nn.Parameter:
"""
Manually gather NF4Tensor parameter since it does not support all_gather
"""
mesh = sharded_param.device_mesh
nf4_tensor = sharded_param._local_tensor
quant_params, metadata = nf4_tensor.fsdp_pre_all_gather(mesh)
full_quant_params = []
for quant_param in quant_params:
d0, *dn = quant_param.shape
shape = (d0 * mesh.get_group().size(), *dn)
full_quant_param = torch.empty(
shape, device=quant_param.device, dtype=quant_param.dtype
)
dist.all_gather_into_tensor(
full_quant_param, quant_param, mesh.get_group(), async_op=False
)
full_quant_params.append(full_quant_param)
full_param, _ = nf4_tensor.fsdp_post_all_gather(
full_quant_params, metadata, nf4_tensor.dtype
)
return full_param
def gather_cpu_state_dict(
sharded_sd: Dict[str, DTensor], # noqa
is_rank_zero: bool,
device: Optional[torch.device] = None,
) -> Dict[str, Any]:
"""
    Converts a sharded state dict into a full state dict on CPU.
    Returns a non-empty result only on rank 0 to avoid peak CPU memory usage on the other ranks.
Args:
sharded_sd (Dict[str, DTensor]): Sharded state dict of DTensors
is_rank_zero (bool): flag to check if the process is on rank 0
device (Optional[torch.device]): device to use for sharded tensors. Default: None
Returns:
Dict[str, Any]: State dict on CPU
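    Example:
        A sketch of gathering a full copy of the sharded model weights on rank 0 for
        checkpointing; ``rank`` and the filename are placeholders.
        >>> cpu_sd = gather_cpu_state_dict(
        ...     model.state_dict(), is_rank_zero=(rank == 0), device=torch.device("cuda")
        ... )
        >>> if rank == 0:
        ...     torch.save(cpu_sd, "model_checkpoint.pt")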
"""
cpu_state_dict = {}
for param_name, param in sharded_sd.items():
if param.is_cpu:
# Move back to device if offloaded to CPU
param = param.to(device)
if hasattr(param, "_local_tensor"):
if isinstance(param._local_tensor, NF4Tensor):
param = _gather_nf4_tensor(param)
else:
# Gather DTensor
param = param.full_tensor()
if isinstance(param, NF4Tensor):
# upcasting NF4 to original dtype
param = param.to(param.dtype)
if is_rank_zero:
cpu_state_dict[param_name] = param.cpu()
torch.distributed.barrier()
return cpu_state_dict
def get_full_optimizer_state_dict(
opt: Optimizer,
is_rank_zero: bool,
device: Optional[torch.device] = None,
) -> Dict[str, Any]:
"""
    Converts the optimizer state dict from sharded to full.
    For example, "exp_avg" in AdamW is a `DTensor`;
    calling "exp_avg.full_tensor()" converts it to a plain tensor.
    Returns a non-empty CPU state dict only on rank 0.
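    Example:
        A sketch of gathering the full optimizer state on rank 0 for checkpointing;
        ``rank`` and the filename are placeholders.
        >>> opt_sd = get_full_optimizer_state_dict(
        ...     optimizer, is_rank_zero=(rank == 0), device=torch.device("cuda")
        ... )
        >>> if rank == 0:
        ...     torch.save(opt_sd, "optimizer_checkpoint.pt")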
"""
sharded_sd = opt.state_dict()
sharded_state = sharded_sd["state"]
full_state = {}
for group_id, sharded_group in sharded_state.items():
group_state = {}
for attr, sharded_tensor in sharded_group.items():
            # Without this barrier, it may hang forever for 70B+ models.
torch.distributed.barrier()
# "exp_avg" in AdamW is `DTensor`
if isinstance(sharded_tensor, DTensor):
if sharded_tensor.is_cpu:
assert device is not None and device.type == "cuda", (
f"Expect cuda but got device={device}. "
"Please call get_full_optimizer_state_dict(..., device=self._device),"
" so DTensor can communicate over NCCL."
)
sharded_tensor = sharded_tensor.to(device)
full_tensor = sharded_tensor.full_tensor()
else:
# "step" in AdamW is plain tensor
full_tensor = sharded_tensor
if is_rank_zero:
group_state[attr] = full_tensor.cpu()
else:
del full_tensor
if is_rank_zero:
full_state[group_id] = group_state
else:
del group_state
if is_rank_zero:
return {
"param_groups": sharded_sd["param_groups"],
"state": full_state,
}
else:
return {}
def load_from_full_optimizer_state_dict(
opt: Optimizer,
full_sd: Dict[str, Any],
device: torch.device,
) -> Dict[str, Any]:
"""
    Converts a full optimizer state dict into a sharded state dict
    and loads it into the optimizer.
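    Example:
        A sketch of restoring optimizer state saved with
        :func:`get_full_optimizer_state_dict`; the filename is a placeholder.
        >>> full_opt_sd = torch.load("optimizer_checkpoint.pt", weights_only=True)
        >>> load_from_full_optimizer_state_dict(
        ...     optimizer, full_opt_sd, device=torch.device("cuda")
        ... )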
"""
PARAMS = "params" # noqa: N806
_init_optim_state(opt)
param_groups = opt.state_dict()["param_groups"]
state = opt.state_dict()["state"]
full_param_groups = full_sd["param_groups"]
full_state = full_sd["state"]
for param_group, full_param_group in zip(param_groups, full_param_groups):
for key, value in full_param_group.items():
if key == PARAMS:
continue
param_group[key] = value
for pid, full_pid in zip(param_group[PARAMS], full_param_group[PARAMS]):
if pid not in state:
continue
param_state = state[pid]
full_param_state = full_state[full_pid]
for attr, full_tensor in full_param_state.items():
sharded_tensor = param_state[attr]
if isinstance(sharded_tensor, DTensor):
# exp_avg is DTensor
param_state[attr] = distribute_tensor(
full_tensor,
sharded_tensor.device_mesh,
sharded_tensor.placements,
)
else:
# step is plain tensor
param_state[attr] = full_tensor
opt.load_state_dict(
{
"param_groups": param_groups,
"state": state,
}
)
def get_shard_conditions(
name: str,
module: nn.Module,
names_to_match: Optional[List[str]] = None,
*args,
**kwargs,
) -> bool:
"""
    Returns True for modules named ``{}.layers.i`` or whose name exactly matches an entry in
    ``names_to_match``; otherwise returns False. This is a helper function for sharding a model with FSDP.
In :func:`~torchtune.training.shard_model`, we iterate over the model's named modules
and apply fully_shard using this condition.
As part of our sharding strategy, we want each layer to be sharded separately, as this is
generally efficient. We may also want to shard certain modules that are not layers, such as
the embedding module.
    TODO: a more robust way would be to shard on the module type, not the name.
Args:
name (str): Name of the module.
module (nn.Module): Module to be sharded.
names_to_match (Optional[List[str]]): List of names to match, if any.
        *args: Variable length argument list, accepted for compatibility and unused.
        **kwargs: Arbitrary keyword arguments, accepted for compatibility and unused.
Returns:
bool: True if the module name matches the condition, False otherwise.
Examples:
        >>> names_to_match = ["embedding"]
        >>> layer_names = ["layers.0", "decoder.layers.1", "encoder.layers.2.attention",
        ...     "my_wrapper.layer.1.something", "embedding"]
        >>> matches = []
        >>> for name in layer_names:
        ...     if get_shard_conditions(name, None, names_to_match=names_to_match):
        ...         matches.append(name)
        >>> print(matches)
        ['layers.0', 'decoder.layers.1', 'embedding']
"""
if names_to_match and name in names_to_match:
return True
name_list = name.split(".")
if len(name_list) >= 2:
return name_list[-2] == "layers" and str.isdigit(name_list[-1])
return False
def shard_model(
model: TransformerDecoder,
shard_conditions: List[Callable[[str, nn.Module], bool]],
*,
cpu_offload: bool,
reshard_after_forward: bool = True,
) -> None:
"""
Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API.
    This method iterates over the model's named modules from the bottom up and shards those
    that meet any of the criteria from ``shard_conditions``.
Args:
model (TransformerDecoder): Model to shard with FSDP.
shard_conditions (List[Callable[[str, nn.Module], bool]]): A list of functions to determine
which modules to shard with FSDP. Each function should take module name (relative to root)
and the module itself, returning True if FSDP should shard the module and False otherwise.
If any of shard_conditions return True for a given module, it will be sharded by FSDP.
cpu_offload (bool): If set to True, FSDP will offload parameters, gradients, and optimizer
states to CPU.
reshard_after_forward (bool): Whether to reshard parameters and buffers after
the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy
from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.
Raises:
ValueError: If no layer modules were sharded, indicating that no shard_condition was triggered.
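    Example:
        A sketch of typical usage, assuming the model's embedding submodule is named
        ``tok_embeddings`` (adjust to your model) and using :func:`get_shard_conditions`
        to select transformer layers.
        >>> from functools import partial
        >>> shard_model(
        ...     model,
        ...     shard_conditions=[partial(get_shard_conditions, names_to_match=["tok_embeddings"])],
        ...     cpu_offload=False,
        ... )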
"""
fsdp_kwargs = {"reshard_after_forward": reshard_after_forward}
if cpu_offload:
fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
# Shard the model with FSDP, iterating in reverse to start with
# lowest-level modules first
num_layers_sharded = 0
for n, m in reversed(list(model.named_modules())):
if any([shard_condition(n, m) for shard_condition in shard_conditions]):
fully_shard(m, **fsdp_kwargs)
num_layers_sharded += 1
if num_layers_sharded == 0:
raise ValueError(
"No layer modules were sharded. Please check if shard conditions are working as expected."
)
# Finally shard the entire model to account for any stragglers
fully_shard(model, **fsdp_kwargs)