Fix GPT-2 Flash Attention 2 generation with left-padding #41966
Conversation
Force-pushed from 2877ab9 to 4f20a1f.

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@vasqu, ready for review :)
Makes sense, fa should know that it's indeed causal! cc @vasqu here as well!
However, I think we can even completely remove the shape check: sdpa checks it internally to make sure of the alignment, and fa should already align the mask correctly, if I remember correctly.
vasqu left a comment
I'm with Cyril here, let's simplify the forward: essentially remove the whole is_causal logic during the forward and only set it in the init based on whether we are cross attn or not.
All the is_causal logic is correctly handled by the integrations; we just need to make sure we define it correctly on init.
Makes sense. Thanks @vasqu, @Cyrilvallez. Pushing the changes now.
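For illustration, a minimal sketch of the pattern being discussed: decide causality once in `__init__` (self-attention is causal, cross-attention is not) and let the attention backend combine that flag with the padding mask, instead of recomputing `is_causal` from the mask shape in `forward`. The class and function names here are hypothetical, not the actual GPT-2 code in transformers.

```python
import torch
from torch import nn


class GPT2LikeAttention(nn.Module):
    """Illustrative attention block: `is_causal` is fixed once, at init time."""

    def __init__(self, embed_dim: int, num_heads: int, is_cross_attention: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)
        # Decoder self-attention is causal; cross-attention is not.
        # The forward pass never flips this based on the attention mask shape.
        self.is_causal = not is_cross_attention

    def forward(self, hidden_states, key_value_states=None, attention_mask=None, attn_fn=None):
        # For cross-attention, keys/values come from the encoder states.
        kv = key_value_states if key_value_states is not None else hidden_states
        b, q_len, _ = hidden_states.shape
        kv_len = kv.shape[1]
        q = self.q_proj(hidden_states).view(b, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(kv).view(b, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(kv).view(b, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
        # The attention backend (eager, SDPA, FA2, ...) receives both the padding
        # mask and the static `is_causal` flag and combines them itself.
        attn_fn = attn_fn if attn_fn is not None else eager_attention
        out = attn_fn(q, k, v, attention_mask, is_causal=self.is_causal)
        return self.o_proj(out.transpose(1, 2).reshape(b, q_len, -1))


def eager_attention(q, k, v, attention_mask, is_causal):
    """Reference backend: applies the causal mask and the padding mask independently."""
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    if is_causal:
        t = q.shape[-2]
        causal = torch.tril(torch.ones(t, t, dtype=torch.bool, device=q.device))
        scores = scores.masked_fill(~causal, torch.finfo(scores.dtype).min)
    if attention_mask is not None:  # additive mask: 0 to keep, large negative to drop
        scores = scores + attention_mask
    return torch.softmax(scores, dim=-1) @ v
```

With this split, a padding (left-padding) mask never turns causal masking off; each backend simply applies both constraints.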
Force-pushed from 015bba0 to 2edb434.

[For maintainers] Suggested jobs to run (before merge): run-slow: decision_transformer, gpt2
run-slow: decision_transformer, gpt2

This comment contains models: ["models/decision_transformer", "models/gpt2"]
CI Results: ✅ No failing test specific to this PR 🎉!
vasqu left a comment
Thx a lot, looking good now!
@Abdennacer-Badaoui can you run the FA tests locally and confirm that they work? Seems like our CI doesn't have FA installed anymore, so just wanna double check.
@vasqu, I ran the failing FA tests for which I opened the PR; they're all passing.
Thx a lot ❤️
What does this PR do?
Fixes Flash Attention 2 generation with left-padding in GPT-2 by ensuring is_causal=True is set even when padding masks are provided. Flash Attention 2 handles causal masking and padding independently, but the original code incorrectly disabled causal masking whenever an attention mask was present.
Fixes:
test_flash_attn_2_generate_padding_left
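A rough reproduction sketch of the scenario this fixes, not taken from the PR itself: batched generation with prompts of different lengths, left-padding, and the Flash Attention 2 backend. It assumes a CUDA GPU with flash-attn installed; the prompts and generation settings are just examples.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to("cuda")

# Batched prompts of different lengths -> left-padding plus an attention_mask.
inputs = tokenizer(
    ["Hello, my dog is", "The quick brown fox jumps over the"],
    padding=True,
    return_tensors="pt",
).to("cuda")

# Before the fix, the presence of the padding mask disabled causal masking on
# the FA2 path; with is_causal kept True, left-padded generation behaves like
# the other attention implementations.
outputs = model.generate(**inputs, max_new_tokens=20, pad_token_id=tokenizer.pad_token_id)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```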