torch.nn.functional.scaled_grouped_mm

torch.nn.functional.scaled_grouped_mm(mat_a, mat_b, scale_a, scale_recipe_a, scale_b, scale_recipe_b, swizzle_a=None, swizzle_b=None, bias=None, offs=None, output_dtype=torch.bfloat16, contraction_dim=(), use_fast_accum=False)

Applies a grouped scaled matrix multiply, grouped_mm(mat_a, mat_b), where the scaling of mat_a and mat_b is described by scale_recipe_a and scale_recipe_b respectively.

Parameters:
  • scale_a (Tensor | list[Tensor]) – Tensor containing decoding scaling factors for mat_a

  • scale_recipe_a (_ScalingType | list[_ScalingType]) – Enum describing how mat_a has been scaled

  • scale_b (Tensor | list[Tensor]) – Tensor containing decoding scaling factors for mat_b

  • scale_recipe_b (_ScalingType | list[_ScalingType]) – Enum describing how mat_b has been scaled

  • swizzle_a (_SwizzleType | list[_SwizzleType] | None) – Enum describing the swizzling pattern (if any) of scale_a

  • swizzle_b (_SwizzleType | list[_SwizzleType] | None) – Enum describing the swizzling pattern (if any) of scale_b

  • bias (Tensor | None) – optional bias term to be added to the output

  • offs (Tensor | None) – optional offsets into the source tensors, denoting where each group starts

  • output_dtype (dtype | None) – dtype used for the output tensor

  • contraction_dim (list[int] | tuple[int]) – describes which dimensions are the contraction (K) dimensions in the matmul

  • use_fast_accum (bool) – enable/disable tensor-core fast accumulation (Hopper GPUs only)

Return type:

Tensor
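To illustrate the grouped semantics, here is a minimal plain-PyTorch reference sketch, assuming per-tensor (scalar) scales per group, a 2-D mat_a whose rows are partitioned by offs, and a 3-D mat_b holding one matrix per group. The helper name grouped_scaled_mm_reference and these shape conventions are illustrative assumptions, not the full contract of the fused kernel (which also supports scaling recipes, swizzled scale layouts, bias, and more):

```python
import torch

def grouped_scaled_mm_reference(mat_a, mat_b, scale_a, scale_b, offs):
    """Reference semantics: for each group g, the rows of mat_a in
    [start, end) -- boundaries taken from offs -- are multiplied by the
    g-th matrix in mat_b, after applying each operand's decoding scale."""
    outputs = []
    start = 0
    for g, end in enumerate(offs.tolist()):
        a = mat_a[start:end].float() * scale_a[g]  # dequantize group g of A
        b = mat_b[g].float() * scale_b[g]          # dequantize B matrix g
        outputs.append(a @ b)
        start = end
    # concatenate group results and cast to the (default) output dtype
    return torch.cat(outputs, dim=0).to(torch.bfloat16)

# Two groups: rows 0..2 and 2..5 of mat_a, each paired with its own B matrix.
mat_a = torch.randn(5, 4)
mat_b = torch.randn(2, 4, 3)
scale_a = torch.tensor([1.0, 0.5])   # one decoding scale per group for A
scale_b = torch.tensor([2.0, 1.0])   # one decoding scale per group for B
offs = torch.tensor([2, 5])          # group end offsets into mat_a's rows

out = grouped_scaled_mm_reference(mat_a, mat_b, scale_a, scale_b, offs)
print(out.shape)  # torch.Size([5, 3])
```

In practice the fused kernel performs this in one launch over low-precision inputs, which is why the scales and their recipes are passed alongside the operands rather than pre-applied.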