
torch.nn.attention.register_flash_attention_impl

torch.nn.attention.register_flash_attention_impl(impl, *, register_fn)

Register the callable that activates a flash attention implementation.

Note

This function is intended for SDPA backend providers to register their implementations. End users should use activate_flash_attention_impl() to activate a registered implementation.

Parameters:
  • impl (str | Literal['FA4']) – Implementation identifier (e.g., "FA4").

  • register_fn (Callable[[...], FlashAttentionHandle | None]) – Callable that performs the actual dispatcher registration. This function will be invoked by activate_flash_attention_impl() and should register custom kernels with the PyTorch dispatcher. It may optionally return a handle implementing FlashAttentionHandle to keep any necessary state alive.

Example

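>>> from torch.nn.attention import register_flash_attention_impl
>>> def my_impl_register(module_path: str = "my_flash_impl"):
...     # Register custom kernels with the PyTorch dispatcher here.
...     # Optionally return a FlashAttentionHandle to keep state alive.
...     return None
>>> register_flash_attention_impl(
...     "MyImpl", register_fn=my_impl_register
... )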