[WIP] [Speculative Decoding] Use MQA kernel for target model verification #5691
Conversation
Some preliminary benchmark results on the reduction in scoring time.
This PR is based on #6052; otherwise, it is hard to add CUDA graph support.
The MQAScorer cannot handle the case when …
Closed as moved to #8839.
Hello, I wonder whether the MQA in "MQA kernel" is related to the MQA in "MHA, GQA, MQA"? MQA confuses me a lot. Thanks.
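The two usages appear related only loosely. In "MHA, GQA, MQA", MQA (multi-query attention) means all query heads share a single K/V head, which shrinks the KV cache. The "MQA kernel" in this PR seems to refer instead to an attention call where each sequence carries multiple query tokens (the draft tokens) against one shared KV cache. A minimal numpy sketch of the head-sharing sense, with illustrative shapes (none taken from a real model), for contrast:

```python
import numpy as np

# Illustrative shapes only: 4 time steps, 8-dim heads, 4 query heads.
T, d_head, n_heads = 4, 8, 4
rng = np.random.default_rng(0)

q = rng.normal(size=(n_heads, T, d_head))  # per-head queries (same in both variants)

# MHA: each of the n_heads query heads has its own K head.
k_mha = rng.normal(size=(n_heads, T, d_head))
scores_mha = q @ k_mha.transpose(0, 2, 1)  # (n_heads, T, T)

# MQA: all query heads share ONE K head, so the KV cache shrinks by a
# factor of n_heads; GQA sits in between, sharing K/V within small groups.
k_mqa = rng.normal(size=(1, T, d_head))
scores_mqa = q @ np.broadcast_to(k_mqa, k_mha.shape).transpose(0, 2, 1)

assert scores_mha.shape == scores_mqa.shape == (n_heads, T, T)
```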
This PR is the first attempt to use an MQA kernel for target model verification, so that we can remove the overhead of batch expansion. Currently it uses flash attention's `flash_attn_varlen_func` for verification. We can use FlashInfer as the next step. The tricky part is adding CUDA graph support.

Some difficulties:
- `_num_computed_tokens`, a field used by chunked prefill, needs more attention here.

TODO:
RFC for this PR
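The batch-expansion overhead mentioned above can be illustrated with a small sketch. This is plain numpy, not the actual vLLM scorer or `flash_attn_varlen_func` (which fuses this into one varlen kernel); all names and shapes here are illustrative assumptions. Batch expansion duplicates a sequence once per draft token so every query has length 1, while the MQA-style path packs all k draft queries into one causal attention call over the same KV cache; both produce identical outputs:

```python
import numpy as np

def attn(q, k, v):
    """Causal attention; q: (Tq, d), k/v: (Tk, d), queries aligned to the end of k."""
    Tq, Tk, d = q.shape[0], k.shape[0], q.shape[1]
    scores = q @ k.T / np.sqrt(d)
    # Draft query i may attend to all cached tokens plus drafts 0..i.
    mask = np.triu(np.ones((Tq, Tq), dtype=bool), k=1)
    scores[:, Tk - Tq:][mask] = -np.inf
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ v

rng = np.random.default_rng(0)
d, prompt_len, k_draft = 8, 5, 3
kv_k = rng.normal(size=(prompt_len + k_draft, d))  # cached + draft keys
kv_v = rng.normal(size=(prompt_len + k_draft, d))
q = rng.normal(size=(k_draft, d))                  # queries for the k draft tokens

# Batch expansion: one length-1 query per draft token, each with its own prefix.
expanded = np.stack([
    attn(q[i:i + 1], kv_k[:prompt_len + i + 1], kv_v[:prompt_len + i + 1])[0]
    for i in range(k_draft)
])

# MQA-style scoring: all k draft queries in a single causal call.
packed = attn(q, kv_k, kv_v)

assert np.allclose(expanded, packed)
```

The expanded path launches k small attention calls (or inflates the batch by a factor of k); the packed path scores the same k tokens in one call, which is what makes removing batch expansion attractive.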