[FA] Generalize fa config checks and fix flags
#43121
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
run-slow: pixtral

This comment contains models: ["models/pixtral"]
CI Results: ✅ No failing test specific to this PR 🎉!
Cyrilvallez
left a comment
Very needed indeed, now that we have many flavors!
Even if it's probably alright to keep a small check like this, wdyt about introducing a small helper `is_flash_implementation` (probably in `import_utils.py` alongside `is_tracing` etc.) or something? Would probably make it clearer/more maintainable if we keep adding flavors!
Sounds to me like it's quite important to have a unique entry point for these cases, as otherwise we risk very surprising bugs/issues in just a few models because they simply miss some checks.
Yes, let me change the check to a separate fn; will probably go over more models then (which already do the proper check).
Moved it to `src/transformers/utils/generic.py`.
Cyrilvallez
left a comment
Very nice, thanks a lot! It will truly make it easier!
```python
if is_flash_attention_requested(self.config) and attention_similarity is not None:
    # Target guided masks are represented as float masks and are incompatible with Flash Attention
    # Fallback to SDPA for this call only so the rest of the model can still benefit from FA
    attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
    logger.warning_once(
        "Falling back to SDPA for target-guided attention because "
        "Flash Attention does not support additive bias masks."
    )
```
Was not there before, but trusting you on this one!
Discussed this with @yonigozlan internally. This feature seems to be rarely used (it should be None in most cases), and these models often have several different attention mechanisms; we want to fall back here so that FA is at least partially supported (similar to what is done in SAM3).
src/transformers/utils/generic.py
Outdated
```
Checks whether some flavor of flash attention is requested or not.
Priority order first goes for any explicitly passed value `requested_attention_implementation` and
then checks the config's saved value `config._attn_implementation`.
```
Maybe let's just raise if both are provided, no? As it does not make sense to give both.
Yeah, makes sense. Changed the logic and description to enforce passing just one, not both.
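For reference, a minimal sketch of what such a helper could look like under the "raise if both are given" behavior agreed on above. This is not the implementation merged in the PR; the exact signature, error message, and the crude substring check are assumptions for illustration only.

```python
from typing import Optional


def is_flash_attention_requested(
    config=None, requested_attention_implementation: Optional[str] = None
) -> bool:
    """Illustrative sketch: is some flavor of flash attention requested?"""
    if config is not None and requested_attention_implementation is not None:
        # Mirrors the review decision: passing both does not make sense.
        raise ValueError(
            "Pass either `config` or `requested_attention_implementation`, not both."
        )
    implementation = requested_attention_implementation
    if implementation is None and config is not None:
        implementation = getattr(config, "_attn_implementation", None)
    # Crude substring check, purely for illustration; the real helper may instead
    # enumerate the known flavors (flash_attention_2, flash_attention_3, kernels-based FA, ...).
    return implementation is not None and "flash" in implementation
```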
[For maintainers] Suggested jobs to run (before merge): run-slow: afmoe, altclip, auto, autoformer, bamba, bark, bloom, clipseg, codegen, deepseek_v2, deepseek_v3, edgetam, edgetam_video, ernie4_5_vl_moe, falcon, falcon_h1
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43121&sha=6c5129
As per the title, we often have something along the lines of `config._attn_implementation == "flash_attention_2"`. This is not ideal and quite faulty, as other FA flavors (e.g. kernels and FA3) get ignored. This PR simply changes this to the generalized check everywhere.

Additionally, fixed a few wrong flags that still used `_supports_flash_attn_2` (including within tests), relevant commit a231a2b.

Note that sam2 video, sam3 tracker video, and edgetam video do not have any FA tests; they only use a mask with target-guided attention, which is not used by default, so we add a fallback similar to sam3 so that FA is still used for the other attention blocks.

Note that I changed pixtral slightly more, since it was doing unnecessary movements at each attention layer. Not sure why the mask was handled properly but not the position ids.