@LiuXiaoxuanPKU
Collaborator

@LiuXiaoxuanPKU LiuXiaoxuanPKU commented Jun 19, 2024

This PR is a first attempt at using an MQA kernel for target model verification, which removes the overhead of batch expansion. It currently uses flash attention's flash_attn_varlen_func for verification; flashinfer could be used as a next step. The tricky part is adding CUDA graph support.
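To make the difference concrete, here is a minimal sketch (not the code in this PR) of how the two scoring layouts differ in what they hand to the attention kernel. It assumes that batch expansion turns a sequence with k proposal tokens into k+1 single-query sequences, while the MQA-style layout keeps one query segment of length k+1 per sequence, described by cumulative sequence lengths of the kind flash_attn_varlen_func takes as cu_seqlens_q and cu_seqlens_k. The helper names and the k+1 convention are assumptions for illustration.

```python
from itertools import accumulate


def cu_seqlens(lengths):
    """Cumulative offsets [0, l0, l0+l1, ...] for a packed variable-length batch."""
    return [0] + list(accumulate(lengths))


def batch_expansion_layout(context_lens, num_spec_tokens):
    # Assumed model of batch expansion: each request becomes k+1 sequences,
    # each with a single query token and a progressively longer KV context.
    q_lens, kv_lens = [], []
    for ctx in context_lens:
        for i in range(num_spec_tokens + 1):
            q_lens.append(1)
            kv_lens.append(ctx + i + 1)
    return q_lens, kv_lens


def mqa_layout(context_lens, num_spec_tokens):
    # Assumed MQA-style layout: one sequence per request with k+1 query tokens;
    # causal masking inside the segment plays the role of the expansion.
    q_lens = [num_spec_tokens + 1] * len(context_lens)
    kv_lens = [ctx + num_spec_tokens + 1 for ctx in context_lens]
    return q_lens, kv_lens


context_lens = [10, 20]  # hypothetical tokens already in the KV cache
k = 5                    # num_speculative_token

exp_q, exp_kv = batch_expansion_layout(context_lens, k)
mqa_q, mqa_kv = mqa_layout(context_lens, k)

print("sequences seen by the kernel:", len(exp_q), "vs", len(mqa_q))
print("KV entries addressed:", sum(exp_kv), "vs", sum(mqa_kv))
print("cu_seqlens_q:", cu_seqlens(mqa_q))
print("cu_seqlens_k:", cu_seqlens(mqa_kv))
```

Both layouts score the same number of query tokens; under these assumptions the MQA-style layout simply presents them as fewer, longer sequences, so the context of each request is addressed once rather than k+1 times.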

Some difficulties:

  1. This PR modifies _num_computed_tokens, a field that chunked prefill also uses, so this change needs extra care.
  2. The current implementation does not handle speculative and non-speculative requests within the same batch. It assumes that all requests in a batch perform speculative decoding.

TODO:

  • Pass simple TP=1 tests.
  • Add cuda graph support.

RFC for this PR

@LiuXiaoxuanPKU LiuXiaoxuanPKU marked this pull request as draft June 19, 2024 19:04
@LiuXiaoxuanPKU
Collaborator Author

LiuXiaoxuanPKU commented Jul 9, 2024

Some preliminary benchmark results on the reduction in scoring time.
All numbers are scoring times in ms for a llama-7B model on a single A100, measured with CUDA graph support.

| Batch size | num_speculative_token | MQAScorer (ms) | BatchExpansionTop1Scorer (ms) |
|---|---|---|---|
| 4 | 5 | 13.7 | 15.4 |
| 8 | 5 | 14.8 | 17.7 |
| 16 | 5 | 18.4 | 23.9 |
| 32 | 5 | 25.5 | 36.4 |
| 64 | 5 | 35.1 | 54.5 |
| 128 | 5 | 70.9 | 96.2 |
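For a quicker read of the numbers, the relative speedup per batch size can be computed directly from the table. This is only arithmetic on the reported timings; the variable names are illustrative.

```python
# (batch size, MQAScorer ms, BatchExpansionTop1Scorer ms), copied from the table above
results = [
    (4, 13.7, 15.4),
    (8, 14.8, 17.7),
    (16, 18.4, 23.9),
    (32, 25.5, 36.4),
    (64, 35.1, 54.5),
    (128, 70.9, 96.2),
]

# Speedup = batch expansion scoring time / MQA scoring time
speedups = {bs: batch_exp_ms / mqa_ms for bs, mqa_ms, batch_exp_ms in results}

for bs, speedup in speedups.items():
    print(f"batch_size={bs:>3}: {speedup:.2f}x")
```

On these measurements the gain grows from roughly 1.12x at batch size 4 to roughly 1.55x at batch size 64, and is about 1.36x at batch size 128.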

@LiuXiaoxuanPKU
Collaborator Author

This PR is based on #6052; without it, adding CUDA graph support would be hard.

@jjjjohnson

jjjjohnson commented Jul 18, 2024

The MQAScorer cannot handle the case where proposals.proposal_lens contains a 0. This happens when NGramWorker fails to match any token.
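One possible direction, sketched here under assumptions and not taken from the PR: treat a zero-length proposal as an ordinary decode step with a single query token, and give every other sequence its proposal length plus one. The function below is a hypothetical helper; only the proposal_lens name comes from the scorer.

```python
def plan_query_lens(proposal_lens):
    """Split a batch by proposal length and derive per-sequence query lengths.

    Hypothetical helper: a sequence with n proposed tokens gets n + 1 query
    tokens, so n == 0 degenerates to a normal single-token decode step.
    """
    spec_indices = [i for i, n in enumerate(proposal_lens) if n > 0]
    non_spec_indices = [i for i, n in enumerate(proposal_lens) if n == 0]
    query_lens = [n + 1 for n in proposal_lens]
    return spec_indices, non_spec_indices, query_lens


# Example: the n-gram lookup matched nothing for sequences 1 and 3
proposal_lens = [5, 0, 5, 0]
spec_idx, non_spec_idx, query_lens = plan_query_lens(proposal_lens)

print("speculative:", spec_idx)
print("non-speculative:", non_spec_idx)
print("query lengths:", query_lens)
```

Whether a mixed batch like this can go through the same kernel call and CUDA graph path as a uniform one is exactly the open question; the sketch only shows the bookkeeping.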

@LiuXiaoxuanPKU
Collaborator Author

Closed; moved to #8839.

@JaviS-Rei

Hello, is the MQA in 'MQA kernel' related to the MQA in 'MHA, GQA, MQA'? The term confuses me a lot. Thanks.

