Automatic Mixed Precision#

Background#

Automatic Mixed Precision (AMP) enables the use of both single precision (32-bit) and half precision (16-bit) floating point types during training or inference.

Key components include:

  • Autocast: Automatically casts operations to a lower-precision data type (e.g., float16 or bfloat16) to improve performance while maintaining accuracy.

  • Gradient Scaling: Dynamically scales gradients during backpropagation to prevent underflow when training with mixed precision. A combined usage sketch follows this list.
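
The two pieces are typically used together in the training loop. Below is a minimal, illustrative sketch built on the public torch.autocast and torch.amp.GradScaler APIs; the model, data, and optimizer are placeholders, and a CUDA device is assumed purely for demonstration.

```python
import torch

model = torch.nn.Linear(64, 8).cuda()        # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()
scaler = torch.amp.GradScaler("cuda")        # gradient scaling for float16 training

for _ in range(10):                          # placeholder data loop
    inputs = torch.randn(32, 64, device="cuda")
    targets = torch.randn(32, 8, device="cuda")

    optimizer.zero_grad()
    # Autocast: ops inside this region run in float16 where it is safe to do so.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = loss_fn(model(inputs), targets)

    # Gradient scaling: scale the loss, backpropagate, then unscale and step.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```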

Design#

Casting Strategy#

Type conversion rules are defined by the CastPolicy enum. Each value specifies the conversion requirements for a group of operators, ensuring consistent handling of operations that prioritize either precision or performance.

| Policy | Explanation |
| --- | --- |
| lower_precision_fp | Cast all inputs to lower_precision_fp before running the op. |
| fp32 | Cast all inputs to at::kFloat before running the op. |
| fp32_set_opt_dtype | Run the op in at::kFloat, while respecting a user-specified output dtype if one is provided. |
| fp32_append_dtype | Append at::kFloat to the args and redispatch to the type-aware overload. |
| promote | Promote all inputs to the “widest” dtype before execution. |
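
To make the policies concrete, the sketch below exercises three of them through the built-in CUDA autocast rules, which follow the same table: mm falls under lower_precision_fp, asin under fp32, and addcmul under promote. The dtypes in the comments assume a CUDA device is available.

```python
import torch

with torch.autocast(device_type="cuda", dtype=torch.float16):
    a = torch.randn(8, 8, device="cuda")   # float32 inputs
    b = torch.randn(8, 8, device="cuda")

    # lower_precision_fp: inputs are cast down, so the result is float16.
    print(torch.mm(a, b).dtype)            # torch.float16

    # fp32: numerically sensitive ops are run in float32.
    print(torch.asin(a).dtype)             # torch.float32

    # promote: mixed float16/float32 inputs are promoted to the widest dtype.
    print(torch.addcmul(b, a.half(), a.half()).dtype)  # torch.float32
```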

Operator Lists#

PyTorch defines a general list of operators for each of the casting strategies mentioned above, as a reference for developers of new accelerators.

| Policy | Operators List |
| --- | --- |
| lower_precision_fp | List Link |
| fp32 | List Link |
| fp32_set_opt_dtype | List Link |
| fp32_append_dtype | List Link |
| promote | List Link |

Implementation#

Python Integration#

Implement the get_amp_supported_dtype method to return the data types supported by the new accelerator in the AMP context.

```python
def get_amp_supported_dtype():
    return [torch.float16, torch.bfloat16]
```
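
For context, here is a rough sketch of how this hook is consumed, assuming the accelerator is integrated through the PrivateUse1 mechanism and renamed to the hypothetical name openreg. A real backend registers a full device module (plus the C++ kernels from the next section); the throwaway module below only illustrates where get_amp_supported_dtype plugs in.

```python
import types
import torch

# Hypothetical backend name; real backends choose their own.
torch.utils.rename_privateuse1_backend("openreg")

# Minimal stand-in for the accelerator's device module.
openreg = types.ModuleType("torch.openreg")

def get_amp_supported_dtype():
    return [torch.float16, torch.bfloat16]

openreg.get_amp_supported_dtype = get_amp_supported_dtype
torch._register_device_module("openreg", openreg)

# torch.autocast looks up get_amp_supported_dtype() on the device module to
# validate the requested dtype for the custom device.
with torch.autocast(device_type="openreg", dtype=torch.float16):
    ...
```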

C++ Integration#

This section shows how AMP registers autocast kernels for the AutocastPrivateUse1 dispatch key.

  • Register a fallback that makes unhandled ops fall through to their normal implementations.

  • Register specific ATen kernels under AutocastPrivateUse1 using the KERNEL_PRIVATEUSEONE helper macro, which maps an op to the desired casting behavior (one of the at::autocast::CastPolicy values described above).

```cpp
TORCH_LIBRARY_IMPL(_, AutocastPrivateUse1, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}
```
```cpp
TORCH_LIBRARY_IMPL(aten, AutocastPrivateUse1, m) {
  // lower_precision_fp
  KERNEL_PRIVATEUSEONE(mm, lower_precision_fp)

  // fp32
  KERNEL_PRIVATEUSEONE(asin, fp32)

  // Ops that are unsafe under autocast bypass the CastPolicy machinery and
  // are routed to a dedicated kernel (here, binary_cross_entropy_banned).
  m.impl(
      TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
      TORCH_FN((&binary_cross_entropy_banned)));
}
```
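
With these registrations in place, ops on the new backend should behave like their built-in autocast counterparts. The following is a hedged smoke-test sketch reusing the hypothetical openreg name from above; creating tensors on the device depends on the rest of the backend integration.

```python
import torch

with torch.autocast(device_type="openreg", dtype=torch.float16):
    a = torch.randn(8, 8, device="openreg")
    b = torch.randn(8, 8, device="openreg")

    # Dispatched through the lower_precision_fp kernel registered above.
    assert torch.mm(a, b).dtype == torch.float16

    # Dispatched through the fp32 kernel registered above.
    assert torch.asin(a).dtype == torch.float32

    # binary_cross_entropy is routed to binary_cross_entropy_banned, which is
    # expected to reject the call under autocast (the CUDA backend does the
    # same, pointing users to binary_cross_entropy_with_logits).
    try:
        torch.nn.functional.binary_cross_entropy(
            torch.sigmoid(a), torch.rand(8, 8, device="openreg")
        )
    except RuntimeError as err:
        print(err)
```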