PyTorch Symmetric Memory

Created On: Oct 24, 2025 | Last Updated On: Oct 24, 2025

Note

torch.distributed._symmetric_memory is currently in alpha and under active development. Its API may change.

Why Symmetric Memory?

With parallelization techniques evolving rapidly, existing frameworks and libraries often struggle to keep up, and developers increasingly rely on custom implementations that schedule communication and computation directly. In recent years we’ve witnessed a shift from primarily one-dimensional data parallelism to multi-dimensional parallelism. Multi-dimensional parallelism places different latency requirements on different types of communication and therefore requires fine-grained overlap of compute and communication.

To minimize interference with computation, these techniques also use copy engines and network interface cards (NICs) to drive communication. Network transport protocols such as remote direct memory access (RDMA) improve performance by enabling direct, high-speed, low-latency communication between processors and memory. This growing variety calls for communication primitives that are finer-grained than those offered by today’s high-level collective APIs: primitives that let developers implement algorithms tailored to their use cases, such as low-latency collectives, fine-grained compute-communication overlap, or custom fusions.

Furthermore, today’s advanced AI systems connect GPUs with high-bandwidth links (such as NVLink, InfiniBand, or RoCE), making GPU global memory directly accessible to peers. Such connections give programmers a great opportunity to program the system as a single, gigantic GPU with vast accessible memory, instead of as isolated “GPU islands.”

In this document, we will show how you can use PyTorch Symmetric Memory to program modern GPU systems as a “single GPU” and achieve fine-grained remote access.

What does PyTorch Symmetric Memory unlock?

PyTorch Symmetric Memory unlocks three new capabilities:

  • Customized communication patterns: Increased flexibility in kernel writing allows developers to write custom kernels that implement their own computation and communication, tailored directly to the needs of the application. It also becomes straightforward to add support for new data types, along with any special computation they require, even if the standard libraries do not support them yet.

  • In-kernel compute-comm fusion: Device-initiated communication allows developers to write kernels that contain both computation and communication instructions, fusing computation and data movement at the finest possible granularity.

  • Low-latency remote access: Network transport protocols like RDMA improve the performance of symmetric memory in networked environments by enabling direct, high-speed, low-latency communication between processors and memory. RDMA bypasses the overhead of the traditional network stack and avoids CPU involvement. It also offloads data transfer from the GPU’s compute units to the NICs, freeing those units for computational tasks.

Next, we will show you how PyTorch Symmetric Memory (SymmMem) enables new applications with the above capabilities.

A “Hello World” example

The PyTorch SymmMem programming model involves two key elements:

  • creating symmetric tensors

  • creating SymmMem kernels

To create symmetric tensors, one can use the torch.distributed._symmetric_memory package:

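import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

# Assumes dist.init_process_group() has already been called on every rank
rank, group = dist.get_rank(), dist.group.WORLD
t = symm_mem.empty(128, device=torch.device("cuda", rank))
hdl = symm_mem.rendezvous(t, group)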

The symm_mem.empty function creates a tensor backed by a symmetric memory allocation. The rendezvous function performs a rendezvous with the peers in the group and returns a handle to the symmetric memory allocation. The handle provides access to information about the allocation, such as pointers to the symmetric buffers on peer ranks, the multicast pointer (if supported), and the signal pads.

The empty and rendezvous functions must be called in the same order on all ranks in the group.

Then, collectives can be called on these tensors. For example, to perform a one-shot all-reduce:

# Most SymmMem ops are under the torch.ops.symm_mem namespace.
# Ops take the group name (a str) rather than the ProcessGroup object.
out = torch.ops.symm_mem.one_shot_all_reduce(t, "sum", group.group_name)

Note that torch.ops.symm_mem is an “op namespace” rather than a Python module. Therefore, you cannot import it with import torch.ops.symm_mem, nor can you import an op with from torch.ops.symm_mem import one_shot_all_reduce. Instead, call the op directly, as in the example above.
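To make the semantics concrete, the following plain-Python model shows what a one-shot sum all-reduce computes. It is a sketch: reference_one_shot_all_reduce is a hypothetical helper, not part of torch.ops.symm_mem, and each rank’s symmetric buffer is modeled as a list. In a one-shot algorithm every rank reads the full buffer of every peer and reduces locally, so the result is available after a single round of remote reads.

```python
from typing import List


def reference_one_shot_all_reduce(buffers: List[List[float]]) -> List[List[float]]:
    """Single-process model of a one-shot sum all-reduce (hypothetical helper).

    buffers[r] stands in for rank r's symmetric buffer. Each rank reads the
    whole buffer of every peer (including its own) and accumulates locally.
    """
    world_size = len(buffers)
    numel = len(buffers[0])
    outputs = []
    for _rank in range(world_size):
        acc = [0.0] * numel
        # Read every peer's buffer and add it into the local accumulator.
        for peer in range(world_size):
            for i in range(numel):
                acc[i] += buffers[peer][i]
        outputs.append(acc)
    return outputs


bufs = [[1.0, 2.0], [10.0, 20.0], [100.0, 200.0]]
print(reference_one_shot_all_reduce(bufs))
# [[111.0, 222.0], [111.0, 222.0], [111.0, 222.0]]
```

Because each rank reads world_size full buffers, the one-shot approach trades bandwidth for latency, which is why it is generally suited to small messages.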

Write your own kernel

To write your own kernel that communicates through symmetric memory, you need the addresses of the mapped peer buffers and the signal pads used for synchronization. Inside the kernel, you must also synchronize correctly: make sure that peers are ready for communication, and signal to them that this GPU is ready.
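The synchronization pattern described above can be sketched with threads standing in for ranks. This is a simplified model under stated assumptions, not how PyTorch implements it: a lock stands in for GPU atomics, each rank raises a flag in every peer’s signal pad, and then waits until all peers have raised a flag in its own pad. A real device-side barrier also has to reset or toggle the flags so it can be reused, which this model omits. The helper name simulate_signal_pad_barrier is hypothetical.

```python
import threading
import time
from typing import List, Tuple


def simulate_signal_pad_barrier(world_size: int) -> List[Tuple[str, int]]:
    """Thread-based model of a signal-pad barrier (illustrative only)."""
    # signal_pads[r][p]: the slot in rank r's signal pad that peer p writes to.
    signal_pads = [[0] * world_size for _ in range(world_size)]
    lock = threading.Lock()  # stands in for atomic stores/loads on the GPU
    events: List[Tuple[str, int]] = []  # ordered log of ("arrive" | "pass", rank)

    def rank_main(rank: int) -> None:
        with lock:
            events.append(("arrive", rank))
            # Signal: raise this rank's flag in every peer's pad (including its own).
            for peer in range(world_size):
                signal_pads[peer][rank] = 1
        # Wait: spin until every peer has raised its flag in this rank's pad.
        while True:
            with lock:
                if all(signal_pads[rank]):
                    events.append(("pass", rank))
                    return
            time.sleep(0)  # yield so other "ranks" can make progress

    threads = [threading.Thread(target=rank_main, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return events


log = simulate_signal_pad_barrier(4)
first_pass = next(i for i, (kind, _) in enumerate(log) if kind == "pass")
print(first_pass)  # 4: every rank arrived before any rank passed the barrier
```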

PyTorch Symmetric Memory provides CUDA Graph-compatible synchronization primitives that operate on the signal pad accompanying each symmetric memory allocation. Kernels using symmetric memory can be written in either CUDA or Triton. Here is an example that allocates a symmetric tensor and exchanges handles:

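import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

dist.init_process_group()
rank = dist.get_rank()
# Assumes a single node with one GPU per rank, so rank doubles as the device index
torch.cuda.set_device(rank)

# Allocate a tensor
t = symm_mem.empty(4096, device=f"cuda:{rank}")
# Establish symmetric memory and obtain the handle
hdl = symm_mem.rendezvous(t, dist.group.WORLD)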

The buffer pointers, multicast pointer, and signal pads are accessible via:

hdl.buffer_ptrs
hdl.multicast_ptr
hdl.signal_pad_ptrs

Data pointed to by buffer_ptrs can be accessed just like regular local data, and any necessary compute can also be performed in the usual ways. As with local data, you can and should use vectorized accesses to improve efficiency.

Symmetric memory is especially convenient for writing kernels in Triton. Triton already lowered the barrier to writing efficient GPU code; with symmetric memory, communication can now be added to Triton kernels just as easily. The kernel below implements a low-latency all-reduce in Triton.

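import triton
import triton.language as tl

# ptx_utils: synchronization helpers from the kraken repository linked below
# (not part of PyTorch; the import path depends on your setup).

@triton.jit
def one_shot_all_reduce_kernel(
    buf_tuple,  # pointers to the symmetric buffer on every rank
    signal_pad_ptrs,
    output_ptr,
    numel: tl.constexpr,
    rank: tl.constexpr,
    world_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Barrier: wait until every peer is ready before reading remote buffers
    ptx_utils.symm_mem_sync(
        signal_pad_ptrs, None, rank, world_size, hasSubsequenceMemAccess=True
    )

    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE

    # Grid-stride loop: each program handles every num_programs-th block
    while block_start < numel:
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < numel
        # Accumulate in the buffers' dtype (bfloat16 in this example)
        acc = tl.zeros((BLOCK_SIZE,), dtype=tl.bfloat16)

        # Read the same block from every rank's buffer and sum it
        for i in tl.static_range(world_size):
            buffer_rank = buf_tuple[i]
            x = tl.load(buffer_rank + offsets, mask=mask)
            acc += x

        tl.store(output_ptr + offsets, acc, mask=mask)
        block_start += tl.num_programs(axis=0) * BLOCK_SIZE

    # Barrier: no rank returns (and may reuse its buffer) while peers still read it
    ptx_utils.symm_mem_sync(
        signal_pad_ptrs, None, rank, world_size, hasPreviousMemAccess=True
    )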

Synchronizations at the beginning and end of the kernel above guarantee that all processes see consistent data. The bulk of the kernel is recognizable Triton code, and Triton optimizes it behind the scenes, making sure memory accesses are performed efficiently with vectorization and unrolling. As with all Triton kernels, it is easy to modify to add extra computation or change the communication algorithm. Visit https://github.com/meta-pytorch/kraken/blob/main/kraken to see additional utilities and examples of using symmetric memory to implement common patterns in Triton.
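The while-loop in the kernel above is a grid-stride loop: a fixed number of Triton programs walk the whole buffer in strides of num_programs * BLOCK_SIZE, so the launch grid can stay small regardless of numel. The plain-Python sketch below (grid_stride_offsets is a hypothetical helper) lists the offsets each program touches and checks that every element is covered exactly once, including the masked tail.

```python
from typing import List


def grid_stride_offsets(pid: int, num_programs: int, numel: int, block_size: int) -> List[int]:
    """Element offsets visited by program `pid` in the kernel's while-loop."""
    visited = []
    block_start = pid * block_size
    while block_start < numel:
        for off in range(block_start, block_start + block_size):
            if off < numel:  # same role as `mask = offsets < numel`
                visited.append(off)
        block_start += num_programs * block_size
    return visited


# 3 programs, BLOCK_SIZE = 4, numel = 22: program 2 handles 8..11 and the masked tail 20..21.
print(grid_stride_offsets(2, 3, 22, 4))  # [8, 9, 10, 11, 20, 21]
covered = sorted(o for pid in range(3) for o in grid_stride_offsets(pid, 3, 22, 4))
assert covered == list(range(22))  # every element is reduced exactly once
```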

Scale out

Large mixture-of-experts language models distribute their experts across more than 8 GPUs and therefore require multi-node access. RDMA-capable NICs help here. In addition, software libraries such as NVSHMEM or rocSHMEM abstract away the programming differences between intra-node and inter-node access with primitives that are slightly higher level than pointer access, such as put and get.
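To illustrate what “symmetric” means for put and get, here is a toy model in plain Python. SymmetricHeapModel is hypothetical and is not the NVSHMEM API. It assumes every PE (processing element, i.e. rank) owns a heap of the same size and that a collective allocation returns the same offset on every PE, so a rank can address a peer’s copy of an object using its own local offset. The argument order of put mirrors the nvshmem.put(dest, src, nelems, pe) call used in this section.

```python
from typing import List


class SymmetricHeapModel:
    """Toy model of a symmetric heap (illustrative only; not the NVSHMEM API)."""

    def __init__(self, n_pes: int, heap_size: int) -> None:
        self.heaps: List[List[int]] = [[0] * heap_size for _ in range(n_pes)]
        self.heap_size = heap_size
        self.next_offset = 0

    def alloc(self, nelems: int) -> int:
        # Collective allocation: the returned offset is valid on every PE.
        assert self.next_offset + nelems <= self.heap_size, "heap exhausted"
        offset = self.next_offset
        self.next_offset += nelems
        return offset

    def put(self, my_pe: int, dest: int, src: int, nelems: int, pe: int) -> None:
        # One-sided write: copy nelems from my local `src` into PE `pe`'s `dest`.
        self.heaps[pe][dest:dest + nelems] = self.heaps[my_pe][src:src + nelems]

    def get(self, my_pe: int, dest: int, src: int, nelems: int, pe: int) -> None:
        # One-sided read: copy nelems from PE `pe`'s `src` into my local `dest`.
        self.heaps[my_pe][dest:dest + nelems] = self.heaps[pe][src:src + nelems]


heap = SymmetricHeapModel(n_pes=2, heap_size=8)
src, dest = heap.alloc(4), heap.alloc(4)
heap.heaps[0][src:src + 4] = [1, 2, 3, 4]
heap.put(my_pe=0, dest=dest, src=src, nelems=4, pe=1)  # one-sided: PE 1 takes no action
print(heap.heaps[1][dest:dest + 4])  # [1, 2, 3, 4]
```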

PyTorch provides NVSHMEM plugins to augment Triton kernels’ cross-node capabilities. As shown in the code snippet below, one can initiate a cross-node put command within the kernel.

import triton
import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem
from torch.distributed._symmetric_memory._nvshmem_triton import requires_nvshmem

@requires_nvshmem
@triton.jit
def my_put_kernel(
    dest,
    src,
    nelems,
    pe,
):
    nvshmem.put(dest, src, nelems, pe)

The requires_nvshmem decorator is used to indicate that the kernel requires the NVSHMEM device library as an external dependency. When Triton compiles the kernel, the decorator will search your system paths for the NVSHMEM device library. If it is available, Triton will include the necessary device assembly to use the NVSHMEM functions.

API Reference

torch.distributed._symmetric_memory.empty(*size: _int, dtype: _dtype | None = None, device: _device | None = None) → Tensor
torch.distributed._symmetric_memory.empty(size: Sequence[_int], *, dtype: _dtype | None = None, device: _device | None = None) → Tensor

Similar to torch.empty(). The returned tensor can be used by torch.distributed._symmetric_memory.rendezvous() to establish a symmetric memory tensor among participating processes.

Parameters

size (int...) – a sequence of integers defining the shape of the output tensor. Can be a variable number of arguments or a collection like a list or tuple.

Keyword Arguments
  • dtype (torch.dtype, optional) – the desired data type of returned tensor. Default: if None, uses a global default (see torch.set_default_dtype()).

  • device (torch.device, optional) – the desired device of returned tensor. Default: if None, uses the current device for the default tensor type (see torch.set_default_device()). device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.

torch.distributed._symmetric_memory.rendezvous(tensor, group) → _SymmetricMemory

Establish a symmetric memory tensor among participating processes. This is a collective operation.

Parameters
  • tensor (torch.Tensor) – the local tensor used to establish the symmetric memory tensor. It must be allocated via torch.distributed._symmetric_memory.empty(). The shape, dtype, and device type must be identical across all participating processes.

  • group (Union[str, torch.distributed.ProcessGroup]) – The group identifying the participating processes. This can be either a group name or a process group object.

Return type

_SymmetricMemory

torch.distributed._symmetric_memory.is_nvshmem_available() → bool

Check whether NVSHMEM is available in the current build and on the current system.

Return type

bool

torch.distributed._symmetric_memory.set_backend(name)

Set the backend for symmetric memory allocation. This is a global setting that affects all subsequent calls to torch.distributed._symmetric_memory.empty(). Note that the backend cannot be changed once a symmetric memory tensor has been allocated.

Parameters

name (str) – the backend for symmetric memory allocation. Currently, only “NVSHMEM”, “CUDA”, and “NCCL” are supported.

torch.distributed._symmetric_memory.get_backend(device)

Get the backend for symmetric memory allocation for a given device. If not found, return None.

Parameters

device (torch.device or str) – the device for which to get the backend.

Return type

str | None

Op Reference

Note

The following ops are hosted in the torch.ops.symm_mem namespace. You can call them directly via torch.ops.symm_mem.<op_name>.

torch.ops.symm_mem.multimem_all_reduce_(input: Tensor, reduce_op: str, group_name: str) → Tensor

Performs a multimem all-reduce operation on the input tensor. This operation requires hardware support for multimem operations. On NVIDIA GPUs, NVLink SHARP is required.

Parameters
  • input (Tensor) – Input tensor to perform all-reduce on. Must be symmetric.

  • reduce_op (str) – Reduction operation to perform. Currently only “sum” is supported.

  • group_name (str) – Name of the group to perform all-reduce on.

torch.ops.symm_mem.multimem_all_gather_out(input: Tensor, group_name: str, out: Tensor) → Tensor

Performs a multimem all-gather operation on the input tensor. This operation requires hardware support for multimem operations. On NVIDIA GPUs, NVLink SHARP is required.

Parameters
  • input (Tensor) – Input tensor to perform all-gather on.

  • group_name (str) – Name of the group to perform all-gather on.

  • out (Tensor) – Output tensor to store the result of the all-gather operation. Must be symmetric.

torch.ops.symm_mem.one_shot_all_reduce(input: Tensor, reduce_op: str, group_name: str) → Tensor

Performs a one-shot all-reduce operation on the input tensor.

Parameters
  • input (Tensor) – Input tensor to perform all-reduce on. Must be symmetric.

  • reduce_op (str) – Reduction operation to perform. Currently only “sum” is supported.

  • group_name (str) – Name of the group to perform all-reduce on.

torch.ops.symm_mem.one_shot_all_reduce_out(input: Tensor, reduce_op: str, group_name: str, out: Tensor) → Tensor

Performs a one-shot all-reduce operation based on the input tensor and writes the result to the output tensor.

Parameters
  • input (Tensor) – Input tensor to perform all-reduce on. Must be symmetric.

  • reduce_op (str) – Reduction operation to perform. Currently only “sum” is supported.

  • group_name (str) – Name of the group to perform all-reduce on.

  • out (Tensor) – Output tensor to store the result of the all-reduce operation. Can be a regular tensor.

torch.ops.symm_mem.two_shot_all_reduce_(input: Tensor, reduce_op: str, group_name: str) → Tensor

Performs a two-shot all-reduce operation on the input tensor.

Parameters
  • input (Tensor) – Input tensor to perform all-reduce on. Must be symmetric.

  • reduce_op (str) – Reduction operation to perform. Currently only “sum” is supported.

  • group_name (str) – Name of the group to perform all-reduce on.
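The reference above does not spell out how the two-shot variant differs from one-shot. Conventionally, a two-shot all-reduce is a reduce-scatter followed by an all-gather: each rank reduces only its own shard (reading that shard from every peer), then every rank collects all reduced shards. The plain-Python sketch below assumes that conventional scheme and equal-sized shards; reference_two_shot_all_reduce is a hypothetical helper, not the kernel’s implementation.

```python
from typing import List


def reference_two_shot_all_reduce(buffers: List[List[float]]) -> List[List[float]]:
    """Model of a two-shot sum all-reduce: reduce-scatter, then all-gather."""
    world_size = len(buffers)
    numel = len(buffers[0])
    assert numel % world_size == 0, "sketch assumes equal-sized shards"
    shard = numel // world_size

    # Shot 1 (reduce-scatter): rank r reduces shard r by reading it from every peer.
    reduced_shards = []
    for r in range(world_size):
        lo, hi = r * shard, (r + 1) * shard
        reduced_shards.append(
            [sum(buffers[peer][i] for peer in range(world_size)) for i in range(lo, hi)]
        )

    # Shot 2 (all-gather): every rank collects all reduced shards in rank order.
    full = [x for s in reduced_shards for x in s]
    return [list(full) for _ in range(world_size)]


print(reference_two_shot_all_reduce([[1, 2, 3, 4], [10, 20, 30, 40]]))
# [[11, 22, 33, 44], [11, 22, 33, 44]]
```

Under this scheme each rank reads roughly 2 * numel elements instead of world_size * numel, at the cost of a second synchronization round, so two-shot generally suits larger messages than one-shot.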

torch.ops.symm_mem.all_to_all_vdev(input: Tensor, out: Tensor, in_splits: Tensor, out_splits_offsets: Tensor, group_name: str) → None

Performs an all-to-all-v operation using NVSHMEM, with split information provided on device.

Parameters
  • input (Tensor) – Input tensor to perform all-to-all on. Must be symmetric.

  • out (Tensor) – Output tensor to store the result of the all-to-all operation. Must be symmetric.

  • in_splits (Tensor) – Tensor containing the splits of data to send to each peer. Must be symmetric. Must be of size (group_size,). Splits are measured in elements along the first dimension.

  • out_splits_offsets (Tensor) – Tensor containing the splits and offsets of data received from each peer. Must be symmetric. Must be of size (2, group_size). The rows are (in order): output splits and output offsets.

  • group_name (str) – Name of the group to perform all-to-all on.
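The relationship between in_splits and the two rows of out_splits_offsets can be sketched in plain Python. Assuming received chunks are packed contiguously in peer order, rank r receives from each peer p exactly what p’s in_splits says it sends to r, and the output offsets are the exclusive prefix sum of those splits. The op computes this on device. expected_out_splits_offsets is a hypothetical helper that takes every rank’s in_splits at once, which a real rank would not have.

```python
from typing import List, Tuple


def expected_out_splits_offsets(
    all_in_splits: List[List[int]], rank: int
) -> Tuple[List[int], List[int]]:
    """all_in_splits[p][q] = number of rows rank p sends to rank q."""
    world_size = len(all_in_splits)
    # Row 0: how many rows this rank receives from each peer.
    out_splits = [all_in_splits[peer][rank] for peer in range(world_size)]
    # Row 1: where each peer's chunk starts in `out` (exclusive prefix sum).
    out_offsets, running = [], 0
    for s in out_splits:
        out_offsets.append(running)
        running += s
    return out_splits, out_offsets


in_splits = [[2, 3], [4, 1]]  # rank 0 sends 2 rows to itself and 3 to rank 1, etc.
print(expected_out_splits_offsets(in_splits, rank=0))  # ([2, 4], [0, 2])
print(expected_out_splits_offsets(in_splits, rank=1))  # ([3, 1], [0, 3])
```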

torch.ops.symm_mem.all_to_all_vdev_2d(input: Tensor, out: Tensor, in_splits: Tensor, out_splits_offsets: Tensor, group_name: str[, major_align: int = None]) → None

Performs a 2D all-to-all-v operation using NVSHMEM, with split information provided on device. In Mixture of Experts models, this operation can be used to dispatch tokens.

Parameters
  • input (Tensor) – Input tensor to perform all-to-all on. Must be symmetric.

  • out (Tensor) – Output tensor to store the result of the all-to-all operation. Must be symmetric.

  • in_splits (Tensor) – Tensor containing the splits of data to send to each expert. Must be symmetric. Must be of size (group_size * ne,), where ne is the number of experts per rank. Splits are measured in elements along the first dimension.

  • out_splits_offsets (Tensor) – Tensor containing the splits and offsets of data received from each peer. Must be symmetric. Must be of size (2, group_size * ne). The rows are (in order): output splits and output offsets.

  • group_name (str) – Name of the group to perform all-to-all on.

  • major_align (int) – Optional alignment for the major dimension of the output chunk for each expert. If not provided, the alignment is assumed to be 1. Any alignment adjustment will be reflected in the output offsets.

A 2D AllToAllv shuffle is illustrated below for world_size = 2 and ne = 2 (4 experts in total):

Source: |       Rank 0      |       Rank 1      |
        | c0 | c1 | c2 | c3 | d0 | d1 | d2 | d3 |

Dest  : |       Rank 0      |       Rank 1      |
        | c0 | d0 | c1 | d1 | c2 | d2 | c3 | d3 |

where each c_i / d_i is a slice of the input tensor targeting expert i, with length given by the input splits. In other words, the 2D AllToAllv shuffle transposes the data from rank-major order at the input to expert-major order at the output.
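The layout change in the diagram can be reproduced with a small sketch. shuffle_2d is a hypothetical helper that treats each chunk as an opaque label (in the real op, each chunk’s length comes from in_splits), and it assumes that global expert g lives on rank g // ne, consistent with the diagram.

```python
from typing import List


def shuffle_2d(chunks: List[List[str]], ne: int) -> List[List[str]]:
    """Model of the 2D AllToAllv layout change (illustrative only).

    chunks[src][g] is the slice that rank `src` sends to global expert g.
    On each destination rank the output is expert-major: for each local
    expert, the chunks from rank 0, 1, ... in order.
    """
    world_size = len(chunks)
    out = []
    for dst in range(world_size):
        dst_buf = []
        for local_e in range(ne):
            g = dst * ne + local_e  # global expert id hosted on `dst`
            for src in range(world_size):
                dst_buf.append(chunks[src][g])
        out.append(dst_buf)
    return out


source = [["c0", "c1", "c2", "c3"], ["d0", "d1", "d2", "d3"]]
print(shuffle_2d(source, ne=2))
# [['c0', 'd0', 'c1', 'd1'], ['c2', 'd2', 'c3', 'd3']]
```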

If major_align is not 1, the output offsets of c1, c2, and c3 are rounded up to a multiple of this value. For example, if c0 has length 5 and d0 has length 7 (12 in total) and major_align is 16, the output offset of c1 is 16; likewise for c2 and c3. This value has no effect on the offsets along the minor dimension, i.e. d0, d1, d2, and d3. Note: since CUTLASS does not support empty bins, the aligned length is set to major_align when it would otherwise be 0. See pytorch/pytorch#152668.
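The offset arithmetic described above can be checked with a short sketch. It assumes that each local expert’s block of received chunks starts at the next multiple of major_align, while the chunks from different source ranks inside a block are packed back to back. The hypothetical helper below reproduces the example from the text (c0 = 5, d0 = 7, major_align = 16 puts c1 at offset 16); the lengths 3 and 2 for c1 and d1 are made up, and the CUTLASS empty-bin rule is not modeled.

```python
from typing import List


def align_up(x: int, align: int) -> int:
    """Round x up to the next multiple of align."""
    return (x + align - 1) // align * align


def major_aligned_offsets(splits: List[List[int]], major_align: int) -> List[List[int]]:
    """Output offsets on one destination rank (illustrative sketch).

    splits[e][src] = rows that source rank `src` sends to local expert e.
    Returns offsets[e][src], where each expert's block starts aligned
    (major dimension) and chunks inside it are packed (minor dimension).
    """
    offsets = []
    cursor = 0
    for expert_splits in splits:
        pos = align_up(cursor, major_align)  # align the start of the expert's block
        row = []
        for n in expert_splits:
            row.append(pos)
            pos += n  # minor dimension: no alignment between source ranks
        offsets.append(row)
        cursor = pos
    return offsets


# Rank 0 hosts experts 0 and 1: [[c0, d0], [c1, d1]] = [[5, 7], [3, 2]]
print(major_aligned_offsets([[5, 7], [3, 2]], major_align=16))  # [[0, 5], [16, 19]]
```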

torch.ops.symm_mem.all_to_all_vdev_2d_offset(input: Tensor, out: Tensor, in_splits_offsets: Tensor, out_splits_offsets: Tensor, group_name: str) → None

Performs a 2D AllToAllv shuffle, with input split and offset information provided on device. The input offsets need not be an exact prefix sum of the input splits, i.e. padding is allowed between the split chunks. The padding, however, is not transferred to peer ranks.

In Mixture of Experts models, this operation can be used to combine tokens processed by experts on parallel ranks. It can be viewed as the “reverse” of the all_to_all_vdev_2d operation (which shuffles tokens to experts).

Parameters
  • input (Tensor) – Input tensor to perform all-to-all on. Must be symmetric.

  • out (Tensor) – Output tensor to store the result of the all-to-all operation. Must be symmetric.

  • in_splits_offsets (Tensor) – Tensor containing the splits and offsets of data to send to each expert. Must be symmetric. Must be of size (2, group_size * ne), where ne is the number of experts per rank. The rows are (in order): input splits and input offsets. Splits are measured in elements along the first dimension.

  • out_splits_offsets (Tensor) – Tensor containing the splits and offsets of data received from each peer. Must be symmetric. Must be of size (2, group_size * ne). The rows are (in order): output splits and output offsets.

  • group_name (str) – Name of the group to perform all-to-all on.
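The padding rule for all_to_all_vdev_2d_offset can be sketched as follows: only the rows covered by each (split, offset) pair are picked up, so anything between chunks stays behind. pack_without_padding is a hypothetical helper operating on a Python list of rows; it illustrates the rule, not the device implementation.

```python
from typing import List


def pack_without_padding(rows: List[str], splits: List[int], offsets: List[int]) -> List[List[str]]:
    """Return the chunks that would actually be sent (illustrative sketch).

    splits[i] and offsets[i] describe chunk i of the input. Offsets may leave
    gaps (padding) between chunks; only rows[offset:offset + split] is taken.
    """
    chunks = []
    for n, off in zip(splits, offsets):
        # Take exactly `n` rows starting at `off`; rows between chunks are skipped.
        chunks.append(rows[off:off + n])
    return chunks


# Two chunks of 2 rows each; the second starts at offset 4, so rows 2-3 are padding.
rows = ["a0", "a1", "pad", "pad", "b0", "b1"]
print(pack_without_padding(rows, splits=[2, 2], offsets=[0, 4]))
# [['a0', 'a1'], ['b0', 'b1']]
```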