@badayvedat (Contributor) commented on Dec 17, 2025

What does this PR do?

The _wrapped_flash_attn_3 function unconditionally unpacks both out and lse from the return value:

out, lse, *_ = flash_attn_3_func(...)

However, it was not passing return_attn_probs=True to request the tuple return. Since Dao-AILab/flash-attention@203b9b3, flash_attn_func returns only out by default, causing:

 File "/root/flash-attention/diffusers/.venv/lib/python3.12/site-packages/torch/_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/flash-attention/diffusers/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/root/flash-attention/diffusers/.venv/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 367, in wrapped_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/root/flash-attention/diffusers/src/diffusers/models/attention_dispatch.py", line 643, in _wrapped_flash_attn_3
    out, lse, *_ = flash_attn_3_func(
    ^^^^^^^^^^^^
ValueError: not enough values to unpack (expected at least 2, got 1)

How does this PR fix it?

Adds return_attn_probs=True to the flash_attn_3_func call, consistent with how _flash_attention_3_hub handles it.
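For reference, a minimal sketch of the change; keyword names other than return_attn_probs are illustrative and not the exact signature used in attention_dispatch.py:

# before: newer flash-attention builds return only `out`, so the unpacking fails
# out, lse, *_ = flash_attn_3_func(query, key, value, softmax_scale=scale, causal=is_causal)

# after: explicitly request the (out, lse, ...) tuple return
out, lse, *_ = flash_attn_3_func(
    query,
    key,
    value,
    softmax_scale=scale,
    causal=is_causal,
    return_attn_probs=True,
)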

Reproduction

# requirements.txt
git+https://github.com/huggingface/diffusers@5e48f466b9c0d257f2650e8feec378a0022f2402
torch==2.7.1
transformers
accelerate
--extra-index-url=https://download.pytorch.org/whl/cu128

and bring your own flash-attention build. To reproduce this, I built it from source at Dao-AILab/flash-attention@ac9b5f1.

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16
).to("cuda")

pipe.transformer.set_attention_backend("_flash_3")

# ValueError: not enough values to unpack (expected at least 2, got 1)
pipe("a photo of a cat", num_inference_steps=1)

Alternative

The wrapper seems to exist to support FA3 as a custom op. However, FA3 now has native torch.compile support as of Dao-AILab/flash-attention@c7697bb, which may make _wrapped_flash_attn_3 redundant, though I don't know whether the custom op is the wrapper's only purpose. A rough sketch of a direct call is below.
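If the wrapper is dropped, the backend could call the kernel directly and rely on FA3's native torch.compile support. A sketch under that assumption (the actual registry wiring in attention_dispatch.py differs, and the import alias is assumed):

# Sketch only, not the diffusers implementation.
# Assumes: from flash_attn_interface import flash_attn_func as flash_attn_3_func
def _flash_attention_3_direct(query, key, value, scale=None, is_causal=False):
    result = flash_attn_3_func(
        query, key, value, softmax_scale=scale, causal=is_causal
    )
    # older builds return an (out, lse, ...) tuple, newer ones return only `out`
    return result[0] if isinstance(result, tuple) else result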

Before submitting

Who can review?

@badayvedat changed the title from "fix: flash_attn_3_func return value unpacking in _wrapped_flash_attn_3 w torch compile" to "fix: flash_attn_3_func value unpacking in _wrapped_flash_attn_3 w th compile" on Dec 17, 2025.
@badayvedat marked this pull request as ready for review on December 17, 2025 00:33.
@sayakpaul (Member) commented:

Thanks for your PR. Since torch.compile support has been merged, would you be interested in refactoring and cleaning up the FA3 backend in attention_dispatch.py?

@badayvedat changed the title from "fix: flash_attn_3_func value unpacking in _wrapped_flash_attn_3 w th compile" to "refactor: replace fa3 wrapper with original fa3 in attention backends registry" on Dec 17, 2025.
@badayvedat (Contributor, Author) commented:

Are there any downstream callers of this function that I also need to test?

@sayakpaul (Member) left a comment:

Thanks! Just a single question.

On this diff hunk:

raise


# ===== torch op registrations =====
Do you think we should version-guard this to keep it backwards-compatible?
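One low-effort alternative to a strict version guard would be to tolerate both return shapes at the unpacking site; a sketch only, not what this PR does:

# Sketch: branch on the return type so older builds (tuple return by default)
# and newer builds (single tensor unless return_attn_probs=True) both work.
result = flash_attn_3_func(
    query, key, value, softmax_scale=scale, causal=is_causal, return_attn_probs=True
)
if isinstance(result, tuple):
    out, lse = result[0], result[1]
else:
    out, lse = result, None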

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
