torch.nn.functional.grouped_mm
- torch.nn.functional.grouped_mm(mat_a, mat_b, *, offs=None, bias=None, out_dtype=None)
Computes a grouped matrix multiply that shares weight shapes across experts but allows jagged token counts per expert, which is common in Mixture-of-Experts (MoE) layers. Both mat_a and mat_b must be 2D or 3D tensors that already satisfy the physical layout restrictions of grouped GEMM kernels (e.g., row-major mat_a and column-major mat_b for FP8 inputs). Inputs are currently expected to be torch.bfloat16 values on CUDA devices. See the example at the end of this entry.
- Parameters:
  - mat_a (Tensor) – Left operand. When 2D, its leading dimension is sliced into groups according to offs. When 3D, its first dimension enumerates the groups directly and offs must be None.
  - mat_b (Tensor) – Right operand. When both operands are 2D (e.g., MoE weight-gradient updates), the trailing dimension of mat_a and the leading dimension of mat_b are partitioned according to the same offs tensor. For the common forward pass (out = input @ weight.T), mat_b is 3D with shape (num_groups, N, K).
  - offs (Tensor | None) – Optional 1D tensor of monotonically increasing int32 offsets that delimit the jagged dimension of any 2D operand. offs[i] marks the end of group i, and offs[-1] must be strictly less than the total length of that operand's sliced dimension; elements beyond offs[-1] are ignored.
  - bias (Tensor | None) – Optional tensor that is added to the grouped outputs. Bias is not jagged and must be broadcastable to the result shape of each group.
  - out_dtype (dtype | None) – Optional dtype that controls the accumulation/output dtype. Passing torch.float32 accumulates BF16 inputs in FP32; the grouped GEMM API remains non-differentiable.
- Returns:
  A tensor containing the concatenated results of each per-group GEMM, with shape inferred from the operands and offs.
- Return type:
  Tensor
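The sketch below illustrates the forward-pass usage described above. The shapes, offsets, the transpose-based column-major layout trick, and the loose tolerances of the per-group reference check are illustrative assumptions rather than guarantees of this API; grouped kernels may order BF16 accumulation differently from the dense matmul used as a reference.

Example:

    >>> import torch
    >>> # Illustrative MoE forward pass: jagged 2D activations times 3D
    >>> # per-expert weights (hypothetical shapes and offsets).
    >>> num_groups, K, N, total_tokens = 4, 64, 128, 96
    >>> x = torch.randn(total_tokens, K, device="cuda", dtype=torch.bfloat16)
    >>> w = torch.randn(num_groups, N, K, device="cuda", dtype=torch.bfloat16)
    >>> # offs[i] marks the end of group i along the jagged token dimension;
    >>> # rows past offs[-1] (here 88 < 96) are ignored.
    >>> offs = torch.tensor([16, 40, 64, 88], device="cuda", dtype=torch.int32)
    >>> # out = x @ w.T per group; transposing the last two dims of w yields
    >>> # the column-major mat_b layout grouped GEMM kernels expect.
    >>> out = torch.nn.functional.grouped_mm(x, w.transpose(-2, -1), offs=offs)
    >>> # Reference semantics as an explicit per-group loop over the offsets.
    >>> start = 0
    >>> for i in range(num_groups):
    ...     end = int(offs[i])
    ...     expected = x[start:end] @ w[i].T
    ...     torch.testing.assert_close(out[start:end], expected, rtol=2e-2, atol=2e-2)
    ...     start = end

Passing out_dtype=torch.float32 to the same call would instead return FP32 results accumulated from the BF16 inputs.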