[Attention] Fix FlashMLA metadata builder arguments for q_len > 1 #27368
Conversation
Signed-off-by: Matthew Bonanni <[email protected]>
Code Review
This pull request correctly fixes a bug in the FlashMLA metadata builder for decode scenarios with a query length greater than one. The change properly calculates num_q_tokens_per_head_k and passes it to get_mla_metadata, which resolves the performance degradation and crashes noted in the description. The provided benchmarks clearly demonstrate the significant speedup achieved by this fix. The implementation is correct and well-targeted. Overall, this is an excellent and important bug fix.
mgoin
left a comment
Is there an eval we can run to validate this? I assume we could do DeepSeek with MTP enabled.
LucasWilkinson
left a comment
LGTM; thanks for tracking this down!
nit: can you add a small note that we use the max, but all the query lens should be the same?
Signed-off-by: Matthew Bonanni <[email protected]>
@mgoin will do!
LucasWilkinson
left a comment
LGTM (assuming evals pass; don't merge till then, but I don't see any reason they won't)
@mgoin @LucasWilkinson confirmed evals look good.
Purpose
As of #26541, FlashMLA now supports `q_len > 1` in the decode pipeline. The `get_mla_metadata` call was not updated, however, leading to poor performance (and potentially crashes) in these cases. This PR is a simple bug fix achieving a substantial speedup, especially at small batch sizes.

Note: uses the benchmarks in #26835 (not yet merged)
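To illustrate the shape of the fix, here is a minimal sketch of the metadata argument in question. The function name and parameters are assumptions for illustration, not the exact vLLM code: the idea is that the per-KV-head query-token count passed to `get_mla_metadata` must scale with the decode query length, whereas before it was implicitly computed as if `q_len == 1`.

```python
# Hypothetical sketch of the corrected metadata argument; names are
# illustrative, not the exact vLLM/FlashMLA identifiers.

def num_q_tokens_per_head_k(max_query_len: int,
                            num_heads_q: int,
                            num_heads_k: int) -> int:
    """Number of query tokens mapped onto each KV head.

    Before the fix this was effectively computed with max_query_len == 1,
    which mis-sizes the FlashMLA scheduling metadata when speculative
    decoding (e.g. MTP) produces q_len > 1. All query lens in a decode
    batch are assumed equal, so using the max is representative.
    """
    return max_query_len * num_heads_q // num_heads_k

# Example: 128 query heads, 1 latent KV head (MLA), 2 decode tokens per
# request (one verified token plus one speculative token).
print(num_q_tokens_per_head_k(2, 128, 1))  # 256
```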
cc @LucasWilkinson
Test Plan
`python benchmarks/attention_benchmarks/benchmark.py --config benchmarks/attention_benchmarks/configs/flashmla_bugfix_demo.yaml`

Test Result