Model Parallel¶
DistributedModelParallel is the main API for distributed training with TorchRec optimizations.
- class torchrec.distributed.model_parallel.DistributedModelParallel(module: Module, env: Optional[ShardingEnv] = None, device: Optional[device] = None, plan: Optional[ShardingPlan] = None, sharders: Optional[List[ModuleSharder[Module]]] = None, init_data_parallel: bool = True, init_parameters: bool = True, data_parallel_wrapper: Optional[DataParallelWrapper] = None, model_tracker_config: Optional[ModelTrackerConfig] = None)¶
Entry point to model parallelism.
- Parameters:
module (nn.Module) – module to wrap.
env (Optional[ShardingEnv]) – sharding environment that has the process group.
device (Optional[torch.device]) – compute device, defaults to cpu.
plan (Optional[ShardingPlan]) – plan to use when sharding, defaults to EmbeddingShardingPlanner.collective_plan().
sharders (Optional[List[ModuleSharder[nn.Module]]]) – ModuleSharders available to shard with, defaults to EmbeddingBagCollectionSharder().
init_data_parallel (bool) – data-parallel modules can be lazy, i.e. they delay parameter initialization until the first forward pass. Pass False to delay initialization of data-parallel modules; then run the first forward pass and call DistributedModelParallel.init_data_parallel().
init_parameters (bool) – initialize parameters for modules still on meta device.
data_parallel_wrapper (Optional[DataParallelWrapper]) – custom wrapper for data parallel modules.
model_tracker_config (Optional[ModelTrackerConfig]) – config for the model tracker.
Example:
@torch.no_grad()
def init_weights(m):
    if isinstance(m, nn.Linear):
        m.weight.fill_(1.0)
    elif isinstance(m, EmbeddingBagCollection):
        for param in m.parameters():
            init.kaiming_normal_(param)

m = MyModel(device='meta')
m = DistributedModelParallel(m)
m.apply(init_weights)
- copy(device: device) DistributedModelParallel ¶
Recursively copy submodules to a new device by calling each module's customized copy process, since some modules need to use the original references (like ShardedModule for inference).
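A minimal sketch of using copy(), assuming `model` is an already-constructed DistributedModelParallel and a CUDA device is available (both are assumptions here):
```
import torch

# Hypothetical: `model` is an existing DistributedModelParallel instance.
# copy() runs each module's customized copy logic rather than a plain .to(),
# so modules such as ShardedModule keep the references they need for inference.
inference_model = model.copy(torch.device("cuda:0"))
```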
- forward(*args, **kwargs) Any ¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
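In other words, call the DMP instance itself rather than forward() so that registered hooks run; a minimal sketch (the `model` and `batch` names are illustrative):
```
# Preferred: invoking the module instance runs registered forward hooks.
output = model(batch)

# Avoid: calling forward() directly silently skips the hooks.
# output = model.forward(batch)
```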
- get_delta(consumer: Optional[str] = None) Dict[str, DeltaRows] ¶
Returns the delta rows for the given consumer.
- get_model_tracker() ModelDeltaTracker ¶
Returns the model tracker if it exists.
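A sketch of reading tracked deltas, assuming the DMP was constructed with a model_tracker_config (the exact ModelTrackerConfig fields and import path are not shown and are assumptions here):
```
# Assumes: model = DistributedModelParallel(module, model_tracker_config=ModelTrackerConfig(...))
tracker = model.get_model_tracker()  # the attached ModelDeltaTracker
delta_rows = model.get_delta()       # Dict[str, DeltaRows], keyed per tracked parameter
for fqn, rows in delta_rows.items():
    print(fqn, rows)
```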
- init_data_parallel() None ¶
See the init_data_parallel constructor argument for usage. It's safe to call this method multiple times.
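A sketch of the delayed-initialization flow, reusing the illustrative MyModel from the constructor example above:
```
# Delay wrapping of data-parallel modules until after the first forward pass.
model = DistributedModelParallel(MyModel(device="meta"), init_data_parallel=False)
model(first_batch)          # lazy data-parallel modules materialize parameters here
model.init_data_parallel()  # safe to call more than once
```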
- load_state_dict(state_dict: OrderedDict[str, Tensor], prefix: str = '', strict: bool = True) _IncompatibleKeys ¶
Copy parameters and buffers from state_dict into this module and its descendants.
If strict is True, then the keys of state_dict must exactly match the keys returned by this module's state_dict() function.
Warning
If assign is True, the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True.
- Parameters:
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict() function. Default: True
assign (bool, optional) – When set to False, the properties of the tensors in the current module are preserved, whereas setting it to True preserves the properties of the Tensors in the state dict. The only exception is the requires_grad field of Parameter, for which the value from the module is preserved. Default: False
- Returns:
missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.
unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
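For example, a sketch of restoring a checkpoint into the DMP and inspecting the result (the checkpoint path is illustrative):
```
import torch

# Hypothetical checkpoint restore; "checkpoint.pt" is an illustrative path.
state_dict = torch.load("checkpoint.pt", map_location="cpu")
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing or unexpected:
    print("missing keys:", missing)
    print("unexpected keys:", unexpected)
```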
- property module: Module¶
Property to directly access sharded module, which will not be wrapped in DDP, FSDP, DMP, or any other parallelism wrappers.
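A short sketch of reaching the unwrapped sharded module through this property (the submodule path "sparse.ebc" is illustrative):
```
# Bypass the DDP/FSDP/DMP wrappers and get the underlying sharded module.
unwrapped = model.module
# e.g. inspect a sharded submodule at an illustrative path:
print(type(unwrapped.sparse.ebc))
```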
- named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]] ¶
Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- Parameters:
prefix (str) – prefix to prepend to all buffer names.
recurse (bool, optional) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.
remove_duplicate (bool, optional) – whether to remove the duplicated buffers in the result. Defaults to True.
- Yields:
(str, torch.Tensor) – Tuple containing the name and buffer
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
- named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Parameter]] ¶
Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
- Parameters:
prefix (str) – prefix to prepend to all parameter names.
recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
remove_duplicate (bool, optional) – whether to remove the duplicated parameters in the result. Defaults to True.
- Yields:
(str, Parameter) – Tuple containing the name and parameter
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
- reshard(sharded_module_fqn: str, changed_shard_to_params: Dict[str, ParameterSharding]) None ¶
Reshards an already-sharded module in the DMP given a set of ParameterShardings to change placements.
This method allows you to dynamically change the sharding strategy for a specific module without recreating the entire DMP. It's particularly useful for:
1. Adapting to changing requirements during training
2. Implementing progressive sharding strategies
3. Rebalancing load across devices
4. A/B testing different sharding plans
- Parameters:
sharded_module_fqn (str) – The path to the sharded module in the DMP. For example, "sparse.ebc".
changed_shard_to_params (Dict[str, ParameterSharding]) – A dictionary mapping parameter names to their new ParameterSharding configurations. Includes only the shards that need to be moved.
Example
```
# Original sharding plan might have the table sharded across 2 GPUs
original_plan = {
    "table_0": ParameterSharding(
        sharding_type="table_wise",
        ranks=[0, 1],
        sharding_spec=EnumerableShardingSpec(...),
    )
}

# New sharding plan to shard across 4 GPUs
new_plan = {
    "weight": ParameterSharding(
        sharding_type="table_wise",
        ranks=[0, 1, 2, 3],
        sharding_spec=EnumerableShardingSpec(...),
    )
}

# Helper function for only selecting the delta between the original and new plan
changed_sharding_params = output_sharding_plan_delta(new_plan)

# Reshard the module and redistribute the tensors
model.reshard("embedding_module", changed_sharding_params)
```
Notes
The sharder for the module must implement a reshard method
Resharding involves redistributing tensor data across devices, which can be expensive
After resharding, the optimizer state is maintained for the module
The sharding plan is updated to reflect the new configuration
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any] ¶
Return a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Note
The returned object is a shallow copy. It contains references to the module's parameters and buffers.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of the argument destination as it is not designed for end-users.
- Parameters:
destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it's set to True, detaching will not be performed. Default: False.
- Returns:
a dictionary containing a whole state of the module
- Return type:
dict
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
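And a short sketch of saving the full state for later restoration (the file path is illustrative and pairs with the load_state_dict() example above):
```
import torch

# Save the DMP's state so it can be restored later with load_state_dict().
torch.save(model.state_dict(), "checkpoint.pt")
```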