diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 310c44457c27..82f7e1f8294e 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -616,78 +616,6 @@ def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
         raise
 
 
-# ===== torch op registrations =====
-# Registrations are required for fullgraph tracing compatibility
-# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
-# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
-@_custom_op("_diffusers_flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
-def _wrapped_flash_attn_3(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    softmax_scale: Optional[float] = None,
-    causal: bool = False,
-    qv: Optional[torch.Tensor] = None,
-    q_descale: Optional[torch.Tensor] = None,
-    k_descale: Optional[torch.Tensor] = None,
-    v_descale: Optional[torch.Tensor] = None,
-    attention_chunk: int = 0,
-    softcap: float = 0.0,
-    num_splits: int = 1,
-    pack_gqa: Optional[bool] = None,
-    deterministic: bool = False,
-    sm_margin: int = 0,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    # Hardcoded for now because pytorch does not support tuple/int type hints
-    window_size = (-1, -1)
-    out, lse, *_ = flash_attn_3_func(
-        q=q,
-        k=k,
-        v=v,
-        softmax_scale=softmax_scale,
-        causal=causal,
-        qv=qv,
-        q_descale=q_descale,
-        k_descale=k_descale,
-        v_descale=v_descale,
-        window_size=window_size,
-        attention_chunk=attention_chunk,
-        softcap=softcap,
-        num_splits=num_splits,
-        pack_gqa=pack_gqa,
-        deterministic=deterministic,
-        sm_margin=sm_margin,
-    )
-    lse = lse.permute(0, 2, 1)
-    return out, lse
-
-
-@_register_fake("_diffusers_flash_attn_3::_flash_attn_forward")
-def _(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    softmax_scale: Optional[float] = None,
-    causal: bool = False,
-    qv: Optional[torch.Tensor] = None,
-    q_descale: Optional[torch.Tensor] = None,
-    k_descale: Optional[torch.Tensor] = None,
-    v_descale: Optional[torch.Tensor] = None,
-    attention_chunk: int = 0,
-    softcap: float = 0.0,
-    num_splits: int = 1,
-    pack_gqa: Optional[bool] = None,
-    deterministic: bool = False,
-    sm_margin: int = 0,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    window_size = (-1, -1)  # noqa: F841
-    # A lot of the parameters here are not yet used in any way within diffusers.
-    # We can safely ignore for now and keep the fake op shape propagation simple.
-    batch_size, seq_len, num_heads, head_dim = q.shape
-    lse_shape = (batch_size, seq_len, num_heads)
-    return torch.empty_like(q), q.new_empty(lse_shape)
-
-
 # ===== Helper functions to use attention backends with templated CP autograd functions =====
 
 
@@ -1617,14 +1545,19 @@ def _flash_attention_3(
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
-    out, lse = _wrapped_flash_attn_3(
+    out = flash_attn_3_func(
         q=query,
         k=key,
         v=value,
         softmax_scale=scale,
         causal=is_causal,
+        return_attn_probs=return_lse,
     )
-    return (out, lse) if return_lse else out
+    if return_lse:
+        # out is (output, lse) tuple when return_attn_probs=True
+        # lse needs to be permuted from (batch, heads, seq) to (batch, seq, heads)
+        return out[0], out[1].permute(0, 2, 1)
+    return out
 
 
 @_AttentionBackendRegistry.register(
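
For reviewers, here is a minimal, self-contained sketch (not part of the patch) of the return convention `_flash_attention_3` follows after this change: with `return_lse=True` it returns `(out, lse)` with the LSE permuted from `(batch, heads, seq)` to `(batch, seq, heads)`, otherwise just `out`. Plain einsum attention stands in for `flash_attn_3_func`; the `_toy_flash_attention_3` name, tensor shapes, and the reference math below are illustrative assumptions, not diffusers or flash-attention API.

```python
import torch


def _toy_flash_attention_3(query, key, value, scale=None, is_causal=False, return_lse=False):
    # query/key/value: (batch, seq_len, num_heads, head_dim), as in the FA3 code path.
    scale = scale if scale is not None else query.shape[-1] ** -0.5

    # Reference attention in plain PyTorch, standing in for flash_attn_3_func.
    scores = torch.einsum("bqhd,bkhd->bhqk", query, key) * scale
    if is_causal:
        q_len, k_len = scores.shape[-2], scores.shape[-1]
        mask = torch.ones(q_len, k_len, dtype=torch.bool, device=scores.device).tril()
        scores = scores.masked_fill(~mask, float("-inf"))
    out = torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), value)

    if not return_lse:
        return out
    # FA3 reports the log-sum-exp as (batch, num_heads, seq_len); the patch permutes it
    # to (batch, seq_len, num_heads) before returning it alongside the attention output.
    lse = torch.logsumexp(scores, dim=-1)  # (batch, heads, seq)
    return out, lse.permute(0, 2, 1)       # (batch, seq, heads)


if __name__ == "__main__":
    q, k, v = (torch.randn(2, 16, 8, 64) for _ in range(3))
    out, lse = _toy_flash_attention_3(q, k, v, return_lse=True)
    print(out.shape, lse.shape)  # torch.Size([2, 16, 8, 64]) torch.Size([2, 16, 8])
```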