Automatic Mixed Precision#
Background#
Automatic Mixed Precision (AMP) enables the use of both single precision (32-bit) and half precision (16-bit) floating point types during training or inference.
Key components include:
- Autocast: Automatically casts operations to lower-precision data types (e.g., `float16` or `bfloat16`) to improve performance while maintaining accuracy.
- Gradient Scaling: Dynamically scales gradients during backpropagation to prevent underflow when training with mixed precision. Both components appear in the sketch after this list.
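The following is a minimal training-loop sketch of both components, using the CUDA backend as a stand-in for an accelerator that already has autocast kernels registered; the model, optimizer, and data are placeholders.

```python
import torch

# Placeholder model, optimizer, and synthetic data.
model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.amp.GradScaler("cuda")

inputs = torch.randn(8, 16, device="cuda")
targets = torch.randn(8, 4, device="cuda")

optimizer.zero_grad()
# Autocast selects a per-op dtype according to the casting policies below.
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = torch.nn.functional.mse_loss(model(inputs), targets)

scaler.scale(loss).backward()  # backward on the scaled loss to avoid underflow
scaler.step(optimizer)         # unscales gradients, skips the step on inf/nan
scaler.update()                # adjusts the scale factor for the next iteration
```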
Design#
Casting Strategy#
The `CastPolicy` enum defines the type-conversion rules. Each enum value represents a set of type-conversion requirements for a group of operators, ensuring consistent handling of operations that prioritize either precision or performance.
| Policy | Explanation |
|---|---|
| `lower_precision_fp` | Cast all inputs to the lower-precision floating-point type (e.g. `at::kHalf` or `at::kBFloat16`) before execution. |
| `fp32` | Cast all inputs to `at::kFloat` before execution. |
| `fp32_set_opt_dtype` | Execution in `at::kFloat`; if the op's optional `dtype` argument is unset, set it to `at::kFloat`, otherwise respect the caller-supplied dtype. |
| `fp32_append_dtype` | Append `at::kFloat` to the args and redispatch to the type-aware overload. |
| `promote` | Promote all inputs to the “widest” dtype before execution. |
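As a concrete illustration of the first two policies, the sketch below uses the CUDA backend (which already ships autocast kernels) as a stand-in for a new accelerator: `torch.mm` follows the `lower_precision_fp` policy, while `torch.asin` follows the `fp32` policy.

```python
import torch

a = torch.randn(8, 8, device="cuda")
b = torch.randn(8, 8, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    # lower_precision_fp: float32 inputs are cast down to float16
    print(torch.mm(a, b).dtype)        # torch.float16
    # fp32: float16 inputs are cast up to float32
    print(torch.asin(a.half()).dtype)  # torch.float32
```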
Operator Lists#
PyTorch defines a general list of operators for each of the casting strategies mentioned above, which serves as a reference for developers of new accelerators.
Implementation#
Python Integration#
Implement the `get_amp_supported_dtype` method to return the data types supported by the new accelerator in the AMP context.
```python
import torch

def get_amp_supported_dtype():
    # Data types that the accelerator supports under autocast.
    return [torch.float16, torch.bfloat16]
```
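As a rough sketch of how this hook is exercised from the user side: for custom backends, `torch.autocast` consults `get_amp_supported_dtype()` and warns if the requested dtype is not in the returned list. The backend name `my_accel` below is a placeholder and assumes the backend has already been registered (e.g. via `torch.utils.rename_privateuse1_backend` and `torch._register_device_module`).

```python
import torch

# Assumes a PrivateUse1 backend registered under the placeholder name "my_accel".
with torch.autocast(device_type="my_accel", dtype=torch.bfloat16):
    ...  # ops here are dispatched through the backend's autocast kernels
```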
C++ Integration#
This section shows how AMP registers autocast kernels for the `AutocastPrivateUse1` dispatch key:

1. Register a fallback that makes unhandled ops fall through to their normal implementations.
2. Register specific aten kernels under `AutocastPrivateUse1` using the `KERNEL_PRIVATEUSEONE` helper macro, which maps an op to the implementation for the desired precision (specified via the `at::autocast::CastPolicy` enum).
```cpp
TORCH_LIBRARY_IMPL(_, AutocastPrivateUse1, m) {
  // Ops without an explicit autocast kernel fall through to their
  // regular implementations.
  m.fallback(torch::CppFunction::makeFallthrough());
}
```
```cpp
TORCH_LIBRARY_IMPL(aten, AutocastPrivateUse1, m) {
  // lower_precision_fp: run mm with inputs cast to the lower-precision dtype.
  KERNEL_PRIVATEUSEONE(mm, lower_precision_fp)

  // fp32: run asin with inputs cast to float32.
  KERNEL_PRIVATEUSEONE(asin, fp32)

  // Route binary_cross_entropy to a kernel that raises an error under
  // autocast instead of running it in reduced precision.
  m.impl(
      TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
      TORCH_FN((&binary_cross_entropy_banned)));
}
```
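The last registration mirrors the built-in CUDA backend: `binary_cross_entropy` is numerically unsafe in reduced precision, so its autocast kernel (named `binary_cross_entropy_banned` above) raises an error pointing users to `binary_cross_entropy_with_logits`. A minimal sketch of the resulting user-facing behavior, using CUDA as the reference backend:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, device="cuda")
target = torch.rand(4, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    try:
        # Banned under autocast: raises and suggests the *_with_logits variant.
        F.binary_cross_entropy(torch.sigmoid(logits), target)
    except RuntimeError as err:
        print(err)
    # Safe under autocast: takes raw logits and runs in float32 internally.
    loss = F.binary_cross_entropy_with_logits(logits, target)
```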