Skip to content

Conversation

@WoosukKwon
Copy link
Collaborator

@WoosukKwon WoosukKwon commented Mar 17, 2025

This PR optimizes the rejection sampler in #13933 with custom Triton kernels.

By using the Triton kernels, the PR brings the following benefits:

  1. Now we use the flattened shape [num_tokens, vocab_size] for the logits tensors, instead of [batch_size, max_spec_len, vocab_size]. This reduces the GPU memory usage a lot.
  2. Zero synchronization between CPU and GPU.
  3. Remove inefficient data movement (i.e., a bunch of cat, gather, etc.)
  4. (Arguably) easier-to-read code

Performance benchmark: Llama 3.1 8B, ShareGPT, 1xH100, temperature 0.1

SD config: --speculative-model "[ngram]" --ngram_prompt_lookup_min 5 --ngram-prompt-lookup-max 5 --num_speculative_tokens 3

Screenshot 2025-03-17 at 11 28 14 AM
Throughput (reqs/s)
main (w/o SD) 51.49
main (w/ SD) 54.41
This PR (w/ SD) 64.16

25% throughput increase compared to main w/o SD, and 18% increase compared to main w/ SD.

Accuracy benchmark: GSM8K, Llama 3.1 8B Instruct, 5 shots

Temperature Exact match
w/o SD 0.0 75.7
1.0 50.9
w/ SD 0.0 75.9
1.0 51.8

Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Copy link
Collaborator

@LiuXiaoxuanPKU LiuXiaoxuanPKU left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Finished the rejection_sampler.py, will continue other files tonight

Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Copy link
Collaborator

@LiuXiaoxuanPKU LiuXiaoxuanPKU left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@WoosukKwon WoosukKwon merged commit 99abb8b into main Mar 18, 2025
29 of 32 checks passed
@WoosukKwon WoosukKwon deleted the v1-opt-rej branch March 18, 2025 21:31
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
GREEDY_TEMPERATURE: tl.constexpr = -1
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 32
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @WoosukKwon , is there any limitation MAX_SPEC_LEN should be 32? Can it be larger? Thanks.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mmyxym There's no blocker to make it 64. Everything should work if you just change the number. I just thought 32 would be enough for all practical use cases.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed speculative-decoding v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants