torch.nn.attention.register_flash_attention_impl

torch.nn.attention.register_flash_attention_impl(impl, *, register_fn)
Register the callable that activates a flash attention implementation.
Note

This function is intended for SDPA backend providers to register their implementations. End users should use activate_flash_attention_impl() to activate a registered implementation.

Parameters:
- impl (str | Literal['FA4']) – Implementation identifier (e.g., "FA4").
- register_fn (Callable[[...], FlashAttentionHandle | None]) – Callable that performs the actual dispatcher registration. This function will be invoked by activate_flash_attention_impl() and should register custom kernels with the PyTorch dispatcher. It may optionally return a handle implementing FlashAttentionHandle to keep any necessary state alive.
Example

>>> def my_impl_register(module_path: str = "my_flash_impl"):
...     # Register custom kernels with torch dispatcher
...     pass
>>> register_flash_attention_impl(
...     "MyImpl", register_fn=my_impl_register
... )