-
-
Notifications
You must be signed in to change notification settings - Fork 11.6k
[V1][Spec Decode] Optimize Rejection Sampler with Triton Kernels #14930
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
LiuXiaoxuanPKU
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Finished the rejection_sampler.py, will continue other files tonight
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
LiuXiaoxuanPKU
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
…m-project#14930) Signed-off-by: Woosuk Kwon <[email protected]> Signed-off-by: Louis Ulmer <[email protected]>
…m-project#14930) Signed-off-by: Woosuk Kwon <[email protected]>
…m-project#14930) Signed-off-by: Woosuk Kwon <[email protected]> Signed-off-by: Mu Huai <[email protected]>
| GREEDY_TEMPERATURE: tl.constexpr = -1 | ||
| # Maximum number of speculative draft tokens allowed per request in a single | ||
| # step. This value is chosen to be large enough to handle typical use cases. | ||
| MAX_SPEC_LEN = 32 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @WoosukKwon , is there any limitation MAX_SPEC_LEN should be 32? Can it be larger? Thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mmyxym There's no blocker to make it 64. Everything should work if you just change the number. I just thought 32 would be enough for all practical use cases.
This PR optimizes the rejection sampler in #13933 with custom Triton kernels.
By using the Triton kernels, the PR brings the following benefits:
[num_tokens, vocab_size]for the logits tensors, instead of[batch_size, max_spec_len, vocab_size]. This reduces the GPU memory usage a lot.cat,gather, etc.)Performance benchmark: Llama 3.1 8B, ShareGPT, 1xH100, temperature 0.1
SD config:
--speculative-model "[ngram]" --ngram_prompt_lookup_min 5 --ngram-prompt-lookup-max 5 --num_speculative_tokens 325% throughput increase compared to main w/o SD, and 18% increase compared to main w/ SD.
Accuracy benchmark: GSM8K, Llama 3.1 8B Instruct, 5 shots