torch.nn.functional.scaled_mm

torch.nn.functional.scaled_mm(mat_a, mat_b, scale_a, scale_recipe_a, scale_b, scale_recipe_b, swizzle_a=None, swizzle_b=None, bias=None, output_dtype=torch.bfloat16, contraction_dim=(), use_fast_accum=False)

Applies a scaled matrix multiply, mm(mat_a, mat_b), where the scaling of mat_a and mat_b is described by scale_recipe_a and scale_recipe_b, respectively.
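Conceptually, for simple per-tensor scales the operation is equivalent to dequantizing both operands and performing an ordinary matrix multiply. The sketch below is only a reference for that case (the fused kernel does not materialize the dequantized operands), and the helper name is made up for illustration:

    import torch

    def scaled_mm_reference(mat_a, mat_b, scale_a, scale_b, bias=None,
                            output_dtype=torch.bfloat16):
        # Dequantize: apply the decoding scales to the low-precision operands.
        a = mat_a.to(torch.float32) * scale_a
        b = mat_b.to(torch.float32) * scale_b
        # Ordinary matmul in higher precision, then optional bias and final cast.
        out = a @ b
        if bias is not None:
            out = out + bias
        return out.to(output_dtype)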

Parameters
  • mat_a (Tensor) – first matrix to be multiplied

  • mat_b (Tensor) – second matrix to be multiplied

  • scale_a (torch.Tensor | list[torch.Tensor]) – Tensor containing decoding scaling factors for mat_a

  • scale_recipe_a (torch.nn.functional._ScalingType | list[torch.nn.functional._ScalingType]) – Enum describing how mat_a has been scaled

  • scale_b (torch.Tensor | list[torch.Tensor]) – Tensor containing decoding scaling factors for mat_b

  • scale_recipe_b (torch.nn.functional._ScalingType | list[torch.nn.functional._ScalingType]) – Enum describing how mat_b has been scaled

  • swizzle_a (Optional[Union[_SwizzleType, list[torch.nn.functional._SwizzleType]]]) – Enum describing the swizzling pattern (if any) of scale_a

  • swizzle_b (Optional[Union[_SwizzleType, list[torch.nn.functional._SwizzleType]]]) – Enum describing the swizzling pattern (if any) of scale_b

  • bias (Optional[Tensor]) – optional bias term to be added to the output

  • output_dtype (Optional[dtype]) – dtype used for the output tensor

  • contraction_dim (list[int] | tuple[int]) – describes which dimensions of mat_a and mat_b are the K (contraction) dimensions of the matmul.

  • use_fast_accum (bool) – enable/disable tensor-core fast accumulation (Hopper GPUs only)

Return type

Tensor
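
Example of a per-tensor FP8 call. This is a hedged sketch: the enum spelling ScalingType.TensorWise and the memory-layout handling of mat_b are assumptions that may differ across PyTorch builds; consult the _ScalingType enum referenced above on your version.

    import torch
    import torch.nn.functional as F

    # High-precision inputs on GPU.
    a_hp = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)
    b_hp = torch.randn(64, 256, device="cuda", dtype=torch.bfloat16)

    # Per-tensor decoding scales and FP8 (e4m3) quantization.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale_a = (a_hp.abs().max() / fp8_max).to(torch.float32)
    scale_b = (b_hp.abs().max() / fp8_max).to(torch.float32)
    mat_a = (a_hp / scale_a).to(torch.float8_e4m3fn)
    # Column-major second operand; FP8 matmuls commonly require this layout (assumption).
    mat_b = (b_hp / scale_b).to(torch.float8_e4m3fn).t().contiguous().t()

    # ScalingType.TensorWise is an assumed member name for the tensor-wise recipe.
    out = F.scaled_mm(mat_a, mat_b,
                      scale_a, F.ScalingType.TensorWise,
                      scale_b, F.ScalingType.TensorWise,
                      output_dtype=torch.bfloat16)  # -> (128, 256) bfloat16 tensor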