Distributed Training Integration
Created On: May 18, 2026 | Last Updated On: May 18, 2026
Background
Distributed training scales workloads across multiple accelerator devices and nodes by coordinating collective communication (e.g., allreduce, broadcast, allgather) through a ProcessGroup backend. PyTorch ships built-in backends such as NCCL for CUDA and Gloo for CPU. It also exposes a registration mechanism that lets out-of-tree accelerator vendors plug in their own collective communication library without modifying upstream code.
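The semantics of these collectives can be shown with a toy, single-process model in pure Python. Each inner list stands for one rank's tensor, and each function returns what every rank holds after the operation. This is an illustration only, not how PyTorch or any CCL implements these collectives.

```python
from typing import List

Tensor = List[float]  # one rank's data, modeled as a flat list


def allreduce(per_rank: List[Tensor]) -> List[Tensor]:
    """Every rank ends up with the element-wise sum over all ranks."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]


def broadcast(per_rank: List[Tensor], src: int) -> List[Tensor]:
    """Every rank ends up with a copy of rank `src`'s data."""
    return [list(per_rank[src]) for _ in per_rank]


def allgather(per_rank: List[Tensor]) -> List[List[Tensor]]:
    """Every rank ends up with all ranks' data, ordered by rank."""
    gathered = [list(t) for t in per_rank]
    return [[list(t) for t in gathered] for _ in per_rank]


if __name__ == "__main__":
    data = [[1.0, 2.0], [3.0, 4.0]]  # rank 0 and rank 1
    print(allreduce(data))     # [[4.0, 6.0], [4.0, 6.0]]
    print(broadcast(data, 0))  # [[1.0, 2.0], [1.0, 2.0]]
    print(allgather(data))     # [[[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.0], [3.0, 4.0]]]
```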
The integration surface can be broken down into three layers:
1. C++ backend implementation – A subclass of c10d::Backend that implements the collective, point-to-point, and synchronization operations.
2. Python bindings – Expose the C++ backend class to Python via pybind11.
3. Backend registration – Register the backend with torch.distributed.Backend.register_backend() so that init_process_group can discover and instantiate it.
Note
OpenReg (torch_openreg) is PyTorch’s official reference implementation for out-of-tree accelerator integration. It ships a minimal distributed backend called OCCL (OpenReg Collective Communications Library) that demonstrates the full ProcessGroup integration. All code examples in this chapter reference the OCCL implementation.
Before You Start
This guide covers ProcessGroup backend integration only – how to register a custom collective communication backend with torch.distributed. It does not cover full-stack integration with higher-level APIs such as DDP, FSDP, or other distributed training strategies.
Before following this guide, make sure you have:
- An importable torch_xxx extension package that registers your device via PrivateUse1. See the earlier chapters in this guide for device registration, operators, and runtime hooks.
- A collective communication library (CCL) that provides implementations of basic collectives such as allreduce and broadcast for your device. The CCL can be vendor-provided (e.g., NCCL for NVIDIA, HCCL for Huawei) or a custom implementation.
Design
This section describes the interfaces and concepts involved in backend registration.
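Registration API
The primary entry point for out-of-tree (OOT) backend registration is Backend.register_backend(name, func, extended_api=False, devices=None):
| Parameter | Type | Description |
|---|---|---|
| name | str | Backend name, e.g. "occl". |
| func | callable | Factory function that creates a backend instance (see signature below). |
| extended_api | bool | If True, the factory is called with the extended arguments described below instead of the standard ones. Defaults to False. |
| devices | str or list of str | Device types supported by this backend, e.g. ["openreg"]. |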
When devices is specified, the backend is automatically associated with those device types. This means init_process_group() can resolve the correct backend when the user passes a device_id argument without explicitly naming a backend.
Factory Function Signature
The factory function receives different arguments depending on extended_api:
| Mode | Signature |
|---|---|
| Standard (default) | func(store, rank, size, timeout) |
| Extended API | func(dist_backend_opts, backend_options) |
The standard mode is sufficient for most backends. The extended API provides additional context such as group_id and global_ranks_in_group.
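The two calling conventions can be sketched as plain Python factories. HypotheticalBackend is a made-up stand-in for a pybind11-exposed class such as ProcessGroupOCCL. The extended-mode argument is modeled with a SimpleNamespace carrying the fields that c10d's DistributedBackendOptions exposes (store, group_rank, group_size, timeout, group_id, global_ranks_in_group). In real use PyTorch builds that object itself, and the factory would be passed to register_backend(..., extended_api=True).

```python
from dataclasses import dataclass, field
from datetime import timedelta
from types import SimpleNamespace
from typing import List


@dataclass
class HypotheticalBackend:
    """Stand-in for a C++ backend exposed through pybind11 (hypothetical)."""
    rank: int
    size: int
    group_id: str = ""
    global_ranks: List[int] = field(default_factory=list)


def create_standard(store, rank, size, timeout):
    # Standard mode: PyTorch passes (store, rank, size, timeout).
    return HypotheticalBackend(rank=rank, size=size)


def create_extended(dist_backend_opts, backend_options):
    # Extended mode: one options object carries the group context,
    # plus a backend-specific options object (unused in this sketch).
    return HypotheticalBackend(
        rank=dist_backend_opts.group_rank,
        size=dist_backend_opts.group_size,
        group_id=dist_backend_opts.group_id,
        global_ranks=list(dist_backend_opts.global_ranks_in_group),
    )


if __name__ == "__main__":
    std = create_standard(store=None, rank=1, size=4, timeout=timedelta(minutes=30))
    opts = SimpleNamespace(  # hand-built stand-in for PyTorch's options object
        store=None, group_rank=0, group_size=2, timeout=timedelta(minutes=30),
        group_id="1", global_ranks_in_group=[2, 3],
    )
    print(std)
    print(create_extended(opts, backend_options=None))
```

The extended form is only worth it when the backend needs to know which global ranks form its group, for example to set up a sub-communicator.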
Backend Operations
The c10d::Backend base class defines virtual methods for collective, point-to-point, and synchronization operations. Each operation returns a c10::intrusive_ptr<Work> that represents the asynchronous operation. A backend that performs an operation synchronously can return a Work object that is already marked complete.
Minimal Required Operations
To get a working backend that supports basic distributed training, implement at least the following operations:
| Category | Operations |
|---|---|
| Collective | |
| Synchronization | barrier |
These cover the core communication patterns used by DDP and other common distributed workflows.
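The way collectives serve DDP-style training can be seen in a toy, pure-Python model of one data-parallel step. Parameters are broadcast from rank 0 so that all replicas start identical. Then each rank's gradients are summed with an allreduce and divided by the world size. This mirrors the idea behind DDP's gradient averaging but is not its implementation: real DDP buckets gradients and overlaps communication with the backward pass.

```python
from typing import List


def broadcast_params(params_per_rank: List[List[float]], src: int = 0) -> List[List[float]]:
    # Replicas start from identical weights: copy rank `src`'s parameters everywhere.
    return [list(params_per_rank[src]) for _ in params_per_rank]


def average_gradients(grads_per_rank: List[List[float]]) -> List[List[float]]:
    # allreduce (sum) followed by division by the world size.
    world_size = len(grads_per_rank)
    summed = [sum(vals) for vals in zip(*grads_per_rank)]
    averaged = [s / world_size for s in summed]
    return [list(averaged) for _ in grads_per_rank]


def sgd_step(params: List[float], grads: List[float], lr: float) -> List[float]:
    return [p - lr * g for p, g in zip(params, grads)]


if __name__ == "__main__":
    # Rank 1's stale weights are overwritten by rank 0's.
    params = broadcast_params([[1.0, 1.0], [5.0, 5.0]])
    grads = average_gradients([[0.2, 0.4], [0.6, 0.8]])
    # Every rank applies the same averaged gradient, so replicas stay in sync.
    print([sgd_step(p, g, lr=1.0) for p, g in zip(params, grads)])
```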
Extended Operations
For broader compatibility with advanced distributed strategies (e.g., FSDP, model parallelism, pipeline parallelism), implement the full set of operations:
| Category | Operations |
|---|---|
| Collective | |
| Point-to-Point | send, recv, recvAnysource |
See Backend.hpp for the full list of virtual methods and their signatures.
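Point-to-point operations pair a send on one rank with a matching recv on another. Pipeline parallelism uses this pattern to pass activations between stages. In the toy in-process model below, one FIFO mailbox per (src, dst) pair stands in for the CCL's transport. ToyP2P is hypothetical, and it raises where a real recv would block.

```python
from collections import defaultdict, deque
from typing import Any, Deque, Dict, Tuple


class ToyP2P:
    """In-process mailbox model of send/recv between ranks (illustration only)."""

    def __init__(self) -> None:
        # One FIFO queue per ordered (src, dst) pair keeps messages in order per channel.
        self._mailboxes: Dict[Tuple[int, int], Deque[Any]] = defaultdict(deque)

    def send(self, payload: Any, src: int, dst: int) -> None:
        self._mailboxes[(src, dst)].append(payload)

    def recv(self, src: int, dst: int) -> Any:
        box = self._mailboxes[(src, dst)]
        if not box:
            # A real recv would block until the matching send arrives.
            raise RuntimeError(f"no pending message from rank {src} to rank {dst}")
        return box.popleft()


if __name__ == "__main__":
    p2p = ToyP2P()
    # Stage 0 (rank 0) hands its activations to stage 1 (rank 1).
    p2p.send([0.1, 0.2], src=0, dst=1)
    print(p2p.recv(src=0, dst=1))
```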
Optional Capabilities
Backends can advertise optional capabilities by overriding the following methods:
| Method | Default | Description |
|---|---|---|
| supportsSplitting() | false | Process group splitting support |
| supportsCoalescing() | false | Coalesced collective operations |
Implementation
This section walks through the concrete steps to implement and register a backend, using the OCCL reference implementation as an example. The implementation follows three steps:
1. Implement the C++ backend
2. Create Python bindings
3. Register the backend in Python
Step 1: Implement the C++ Backend
Create a class that inherits from c10d::Backend and implements the required collective operations. The backend must also define:
- A Work subclass that tracks asynchronous operation state.
- An Options subclass (inheriting from Backend::Options) for backend-specific configuration.
Work Object
The Work subclass manages the lifecycle of an asynchronous collective operation. For a minimal (synchronous) implementation, the work can be marked complete in its constructor. The OCCL reference declares such a class as follows:
class DummyWork : public Work {
 public:
  // Marks the work as complete on construction (synchronous backend).
  DummyWork();

  virtual ~DummyWork();
  bool isCompleted() override;
  bool isSuccess() const override;
  bool wait(std::chrono::milliseconds timeout) override;
  void synchronize() override;
  void abort() override;
  c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;

 protected:
  // Grants ProcessGroupOCCL access to DummyWork's internals.
  friend class ProcessGroupOCCL;

 private:
  // Future handed to callers through getFuture().
  c10::intrusive_ptr<c10::ivalue::Future> future_;
};
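The behavior this declaration implies can be modeled in Python. The model is a Work whose future is already resolved when the object is constructed, so isCompleted() is true immediately and wait() returns at once. ToyImmediateWork is a hypothetical stand-in built on concurrent.futures.Future, not the OCCL code.

```python
from concurrent.futures import Future
from typing import Any, Optional


class ToyImmediateWork:
    """Models a synchronous backend's Work: complete as soon as it exists."""

    def __init__(self, result: Any = None) -> None:
        self._future: Future = Future()
        self._future.set_result(result)  # completed in the constructor

    def is_completed(self) -> bool:
        return self._future.done()

    def is_success(self) -> bool:
        return self._future.done() and self._future.exception() is None

    def wait(self, timeout: Optional[float] = None) -> bool:
        self._future.result(timeout=timeout)  # already resolved, so this never blocks
        return True

    def get_future(self) -> Future:
        return self._future


if __name__ == "__main__":
    work = ToyImmediateWork()
    print(work.is_completed(), work.wait(timeout=0))
```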
For production backends, Work typically wraps an asynchronous handle from the vendor’s communication library (e.g., a stream event or request handle), and wait() blocks until the operation completes on the device.
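For the asynchronous case, a toy model can wrap a pending handle and block in wait() until the handle resolves or a timeout expires. In this sketch a background thread stands in for the device-side collective, and a concurrent.futures.Future stands in for the vendor's stream event or request handle. Both are assumptions for illustration, as are the names ToyAsyncWork and launch_fake_collective.

```python
import threading
import time
from concurrent.futures import Future, TimeoutError as FutureTimeout
from typing import Optional


class ToyAsyncWork:
    """Models a Work object that wraps an in-flight operation's handle."""

    def __init__(self, handle: Future) -> None:
        self._handle = handle

    def is_completed(self) -> bool:
        return self._handle.done()

    def wait(self, timeout: Optional[float] = None) -> bool:
        # Block until the operation finishes. The timeout is in seconds here;
        # the C++ API takes std::chrono::milliseconds.
        try:
            self._handle.result(timeout=timeout)
        except FutureTimeout:
            raise RuntimeError(f"collective did not complete within {timeout} s") from None
        return True


def launch_fake_collective(delay_s: float) -> ToyAsyncWork:
    """Start a pretend collective on a background thread and return its Work."""
    handle: Future = Future()

    def run() -> None:
        time.sleep(delay_s)      # pretend the device is busy
        handle.set_result(None)  # the "event" fires

    threading.Thread(target=run, daemon=True).start()
    return ToyAsyncWork(handle)


if __name__ == "__main__":
    work = launch_fake_collective(0.1)
    print(work.is_completed())  # most likely False right after launch
    work.wait(timeout=5.0)
    print(work.is_completed())  # True once wait() has returned
```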
Backend Class
The backend class inherits from c10d::Backend and overrides the collective operations. Each method should validate that input tensors reside on the expected device type (e.g., PrivateUse1) and then dispatch to the vendor's communication library. Key implementation details:
- Backend name – getBackendName() must return the same string used during Python registration (e.g., "occl").
- Input validation – Each collective should verify tensor device types. The OCCL reference uses CHECK_TENSOR and CHECK_TENSOR_LIST macros for this.
- Return value – All collectives return a c10::intrusive_ptr<Work>.
See ProcessGroupOCCL.hpp and ProcessGroupOCCL.cpp for the full reference implementation.
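The validate-then-dispatch pattern can be sketched in Python. FakeTensor, ToyBackend, and the check_tensor/check_tensor_list helpers are hypothetical. The helpers play the role the text gives to CHECK_TENSOR and CHECK_TENSOR_LIST, but they do not reproduce those macros' exact behavior. The dispatch step returns a placeholder instead of calling a CCL.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FakeTensor:
    device_type: str  # e.g. "openreg" or "cpu"
    data: List[float]


def check_tensor(t: FakeTensor, expected: str) -> None:
    if t.device_type != expected:
        raise ValueError(f"expected a tensor on {expected!r}, got {t.device_type!r}")


def check_tensor_list(ts: List[FakeTensor], expected: str) -> None:
    if not ts:
        raise ValueError("expected a non-empty tensor list")
    for t in ts:
        check_tensor(t, expected)


class ToyBackend:
    DEVICE_TYPE = "openreg"

    def get_backend_name(self) -> str:
        # Must match the name passed to register_backend() on the Python side.
        return "occl"

    def allreduce(self, tensors: List[FakeTensor]) -> bool:
        check_tensor_list(tensors, self.DEVICE_TYPE)  # validate first...
        # ...then a real backend would call into its CCL and return a Work.
        return True  # placeholder for an already-completed Work


if __name__ == "__main__":
    backend = ToyBackend()
    print(backend.allreduce([FakeTensor("openreg", [1.0, 2.0])]))
```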
Step 2: Python Bindings
Expose the backend class to Python using pybind11. The OCCL reference places bindings in a dedicated init.cpp file, separate from the main extension module, and calls initProcessGroupBindings() from the module's entry point:
void initProcessGroupBindings(py::module& m) {
  py::class_<c10d::ProcessGroupOCCL, c10d::Backend, c10::intrusive_ptr<c10d::ProcessGroupOCCL>>(m, "ProcessGroupOCCL")
      .def(
          py::init([](const c10::intrusive_ptr<::c10d::Store>& /*store*/,
                      int rank,
                      int size,
                      std::chrono::milliseconds /*timeout*/) {
            // OCCL ignores the store and timeout; only rank and size are forwarded.
            return c10::make_intrusive<::c10d::ProcessGroupOCCL>(rank, size);
          }),
          py::arg("store"),
          py::arg("rank"),
          py::arg("size"),
          py::arg("timeout") = std::chrono::milliseconds(30 * 60 * 1000));  // 30 minutes
}
Important considerations:
- The py::class_ template must list c10d::Backend as a base class and use c10::intrusive_ptr as the holder, so that PyTorch recognizes the backend in its internal registry.
- The constructor is exposed directly via py::init with a lambda that forwards to the C++ constructor. This avoids the need for a separate C++ factory function.
- Guard the bindings with #if USE_DISTRIBUTED to handle builds where distributed is disabled.
Step 3: Register the Backend in Python
In the extension package's __init__.py, register the backend with torch.distributed:
import torch

if torch.distributed.is_available():
    try:
        from torch_openreg._C import ProcessGroupOCCL

        # Thin factory matching the standard (store, rank, size, timeout) signature.
        def _create_occl_backend(store, rank, size, timeout):
            return ProcessGroupOCCL(store, rank, size, timeout)

        torch.distributed.Backend.register_backend(
            "occl", _create_occl_backend, devices=["openreg"]
        )
    except Exception as e:
        raise RuntimeError("Failed to register 'occl' process group backend.") from e
The Python side imports the pybind11-exposed ProcessGroupOCCL class and wraps it in a thin factory function that matches the signature expected by register_backend(). The call to Backend.register_backend() does the following:
- Adds "occl" to Backend.backend_list, making it a recognized backend name.
- Maps the "openreg" device type to the "occl" backend in Backend.default_device_backend_map.
- Stores the factory function so that init_process_group() can call it when backend="occl" is specified.
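These three effects can be mirrored in a toy registry that shows how names, device types, and factories relate. ToyRegistry is hypothetical and far simpler than torch.distributed.Backend: it does no reserved-name checks and has no extended API. It follows the same two lookups, by explicit backend name or, when only a device type is known, through the device-to-backend map.

```python
from typing import Callable, Dict, List, Optional


class ToyRegistry:
    """Simplified model of the bookkeeping described above (not PyTorch's code)."""

    def __init__(self) -> None:
        self.backend_list: List[str] = []
        self.default_device_backend_map: Dict[str, str] = {}
        self._factories: Dict[str, Callable] = {}

    def register_backend(self, name: str, func: Callable,
                         devices: Optional[List[str]] = None) -> None:
        name = name.lower()
        if name in self.backend_list:
            raise ValueError(f"backend {name!r} is already registered")
        self.backend_list.append(name)                       # 1. recognized name
        for device in devices or []:
            self.default_device_backend_map[device] = name   # 2. device -> backend
        self._factories[name] = func                         # 3. factory kept for later

    def create(self, store, rank: int, size: int, timeout,
               backend: Optional[str] = None,
               device_type: Optional[str] = None):
        """Resolve a backend by name, or by device type, then call its factory."""
        if backend is None:
            if device_type is None:
                raise ValueError("need a backend name or a device type")
            backend = self.default_device_backend_map[device_type]
        return self._factories[backend.lower()](store, rank, size, timeout)


if __name__ == "__main__":
    reg = ToyRegistry()
    reg.register_backend("occl", lambda store, rank, size, timeout: ("occl", rank, size),
                         devices=["openreg"])
    print(reg.create(None, 0, 2, None, device_type="openreg"))
```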
Usage
After registration, the backend works with the standard torch.distributed APIs:
import torch
import torch.distributed as dist

# Importing torch_openreg runs its __init__.py, which registers the "occl" backend
import torch_openreg

# Initialize the process group with the OCCL backend.
# Launch one process per rank (rank=0 in one, rank=1 in the other);
# env:// also needs MASTER_ADDR and MASTER_PORT set in the environment.
dist.init_process_group(
    backend="occl",
    init_method="env://",
    world_size=2,
    rank=0,
)

# Use standard distributed APIs
tensor = torch.randn(4, device="openreg")
dist.all_reduce(tensor)

dist.destroy_process_group()
Alternatively, the backend name can be omitted if a device_id is provided – PyTorch resolves the backend from the device-to-backend mapping:
dist.init_process_group(
    device_id=torch.device("openreg:0"),
    init_method="env://",
    world_size=2,
    rank=0,
)
Multi-device Backend Strings
PyTorch supports specifying different backends for different device types in a single process group using the "device:backend" format:
dist.init_process_group(
    backend="cpu:gloo,openreg:occl",
    init_method="env://",
    world_size=2,
    rank=0,
)
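The "device:backend" format is a comma-separated list of pairs. The hypothetical parser below only illustrates the format. PyTorch's own parsing also accepts a plain backend name such as "occl", which this sketch rejects.

```python
from typing import Dict


def parse_backend_string(spec: str) -> Dict[str, str]:
    """Parse "cpu:gloo,openreg:occl" into {"cpu": "gloo", "openreg": "occl"}."""
    mapping: Dict[str, str] = {}
    for pair in spec.split(","):
        device, sep, backend = pair.partition(":")
        device, backend = device.strip(), backend.strip().lower()
        if not sep or not device or not backend:
            raise ValueError(f"expected 'device:backend', got {pair.strip()!r}")
        if device in mapping:
            raise ValueError(f"device {device!r} is listed more than once")
        mapping[device] = backend
    return mapping


if __name__ == "__main__":
    backends = parse_backend_string("cpu:gloo,openreg:occl")
    # A collective on an "openreg" tensor would be routed to backends["openreg"].
    print(backends["openreg"])  # occl
```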
Testing
Key testing considerations:
- Verify that the backend appears in dist.Backend.backend_list after import.
- Confirm that init_process_group and destroy_process_group succeed.
- Test that collective operations accept tensors on the registered device and return completed Work objects.
- Use MultiProcessTestCase from torch.testing._internal.common_distributed for multi-process test execution.
See the OCCL test suite for a reference example.