fp8_fa3_sdpa
- torchao.prototype.attention.fp8_fa3.attention.fp8_fa3_sdpa(query: Tensor, key: Tensor, value: Tensor, attn_mask: Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, hadamard: str = 'NONE', *, backend_name: str = 'FA3') → Tensor
FP8 scaled dot-product attention (SDPA) shared by all backends.
The matching FlashAttention implementation (e.g. FA3) must be activated before this function is called; the high-level apply_low_precision_attention API handles this automatically. Input/output layout: [B, H, S, D] (batch, heads, sequence length, head dimension).