
Conversation

@ekagra-ranjan
Contributor

@ekagra-ranjan ekagra-ranjan commented Sep 16, 2025

Purpose

Ngram overhead increases with batch size and sequence length. It is CPU overhead that sits on the critical path of inference, so it becomes more significant as sequence length or batch size grows. This PR parallelizes the CPU compute along the batch dimension, reducing the overhead by up to 8x. The threads are capped at 8 since other components, such as the frontend (tokenization and request handling) and structured output, also need threads.
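For illustration, here is a minimal sketch of the batch-parallel idea (not the PR's actual code; propose_for_request, propose_batch, and the cap value are assumptions taken from the description above, and the real speedup relies on the per-request work releasing the GIL, e.g. via compiled kernels):

import os
from concurrent.futures import ThreadPoolExecutor

def propose_for_request(token_ids):
    # Placeholder for the per-request ngram lookup: match the last-n tokens
    # against the request's prompt/output history and return draft tokens.
    return []

def propose_batch(all_token_ids, max_threads=8):
    # Cap the pool at 8 threads so the frontend (tokenization, request
    # handling) and structured output still have CPU headroom, and never
    # spawn more threads than there are requests.
    num_threads = max(1, min(max_threads, len(all_token_ids), os.cpu_count() or 1))
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        return list(pool.map(propose_for_request, all_token_ids))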

This PR also cleans up the propose function of the ngram proposer. Unlike other SD methods such as EAGLE and Medusa, the ngram proposer in vLLM keeps a few draft-related functions in the model runner. This PR refactors them so that the interface is closer to the other SD methods and properly abstracted.

Benchmark (overhead in ms)

cmd: time python3 benchmarks/benchmark_ngram_proposer.py --batched --num-iteration 4 --num-token 1000000 --num-req 128

Test

pytest -sv tests/v1/spec_decode/test_ngram.py
pytest -sv tests/v1/e2e/test_spec_decode.py::test_ngram_correctness

AL (acceptance length)
time VLLM_USE_V1=1 python3 examples/offline_inference/spec_decode.py --method ngram --model-dir meta-llama/Llama-3.1-8B-Instruct --prompt_lookup_min 2 --prompt_lookup_max 5 --num_spec_tokens 5 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --print-output

Output w/o PR

# normal
total_num_output_tokens: 17005
num_drafts: 2524
num_draft_tokens: 12589
num_accepted_tokens: 2592
mean acceptance length: 2.03
--------------------------------------------------
acceptance at token 0: 0.45
acceptance at token 1: 0.26
acceptance at token 2: 0.15
acceptance at token 3: 0.10
acceptance at token 4: 0.07

Output with PR

num_draft_tokens: 12753
num_accepted_tokens: 2648  
mean acceptance length: 2.04      
--------------------------------------------------   
acceptance at token 0: 0.44
acceptance at token 1: 0.26    
acceptance at token 2: 0.16    
acceptance at token 3: 0.11    
acceptance at token 4: 0.07 

Signed-off-by: Ekagra Ranjan <[email protected]>
@mergify mergify bot added performance Performance-related issues speculative-decoding v1 labels Sep 16, 2025
Signed-off-by: Ekagra Ranjan <[email protected]>
Signed-off-by: Ekagra Ranjan <[email protected]>
@ekagra-ranjan ekagra-ranjan force-pushed the er-ngram-batch-parallel-2 branch from 789a8e7 to 880bb72 Compare September 16, 2025 19:38
@ekagra-ranjan ekagra-ranjan force-pushed the er-ngram-batch-parallel-2 branch from ad5049a to b8d70c0 Compare September 16, 2025 22:01
Signed-off-by: Ekagra Ranjan <[email protected]>
Signed-off-by: Ekagra Ranjan <[email protected]>
Signed-off-by: Ekagra Ranjan <[email protected]>
Signed-off-by: Ekagra Ranjan <[email protected]>
@ekagra-ranjan ekagra-ranjan force-pushed the er-ngram-batch-parallel-2 branch from 621975d to aef7f3c Compare September 17, 2025 00:09
@ekagra-ranjan ekagra-ranjan changed the title bench parallel ngram [Spec Decode] Add Batch Parallel Ngram Sep 17, 2025
@ekagra-ranjan ekagra-ranjan marked this pull request as ready for review September 17, 2025 00:23
@ekagra-ranjan ekagra-ranjan changed the title [Spec Decode] Add Batch Parallel Ngram [Spec Decode] Add Batch Parallel Ngram. Upto 30x lower overhead. Sep 17, 2025
# has some threads since all ranks will run this.
cpu_count = os.cpu_count()
if cpu_count:
    self.num_numba_thread_available = (cpu_count // 2) // tp_size
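For context, a rough sketch of how a per-rank numba thread budget like the one above could be applied (illustrative only; find_matches, run_batched, and the kernel body are assumptions, not vLLM code):

import os
from numba import njit, prange, set_num_threads

@njit(parallel=True)
def find_matches(token_matrix, out):
    # One row per request; rows are processed in parallel by numba threads.
    for i in prange(token_matrix.shape[0]):
        out[i] = token_matrix[i, -1]  # stand-in for the real ngram match

def run_batched(token_matrix, out, tp_size=1):
    cpu_count = os.cpu_count() or 1
    # Reserve roughly half the cores and split them across TP ranks.
    budget = max(1, (cpu_count // 2) // tp_size)
    set_num_threads(max(1, min(budget, token_matrix.shape[0])))
    find_matches(token_matrix, out)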
Member

We may want to be more conservative here. There may be more worker processes for other reasons than TP (PP, DP, ...). Also, there is a front-end process which may be using more than one thread for tokenization.

Contributor Author

Good point!

  1. Do vLLM components assume hyperthreading or only use physical cores? If they use hyperthreading, then the current ngram estimate reserves threads only for physical cores, which leaves a lot of room for other processes.
  2. How do we estimate the number of threads for the frontend? Should we subtract, say, 5 from cpu_count before dividing the rest among TP ranks?
  3. If it's DP, then the batch size per device would be smaller, so the threads reserved for ngram would also go down, since the final thread count is min(self.num_numba_thread_available, len(valid_ngram_requests)). So we don't need to split it on DP, unlike TP?
  4. If it's PP, then only the last rank will be doing this operation, so we don't need to split it on PP, unlike TP?

Collaborator

  • Here's a reference of parallelization code I wrote for structured outputs. This is on the frontend, so not quite the same, but similar logic might apply. One trick I used was to only dispatch the work to threads when the batch size is high. If only one or two reqs need ngram, you might be able to save some overhead and just run it directly.
  • Why does this need to happen on each TP rank? Could it instead be dispatched only by the driver worker, for example? If it can be done this way, is it then possible to partition the work over the TP workers directly, instead of having to spin up new threads?

Contributor Author

@ekagra-ranjan ekagra-ranjan Sep 17, 2025

Thanks for sharing the reference! I will have a look at it.

Why does this need to happen on each TP rank?

It doesn't need to, but the current implementation does not split the ngram drafting across TP ranks; the same CPU op is run on every rank by default. This PR only adds batch parallelization. TP parallelism is something I was considering, but I stopped since the overhead of communication and synchronization could outweigh the benefit. The optimal approach would be to do TP parallelism only when the workload is above some threshold, which would need additional benchmarking to determine. I am leaving it for later, when it becomes a bottleneck.

Contributor Author

Here's a reference of parallelization code I wrote for structured outputs.

Curious how you came up with the number 8?

This is on the frontend, so not quite the same, but similar logic might apply.

The frontend and the backend would be sharing the same thread pool, right? Would being on the backend make the maximum thread count greater than, less than, or equal to 8?

If only one or two reqs need ngram, you might be able to save some overhead and just run it directly.

Good idea. I will check the minimum number of requests needed to see a gain from threading and update the code accordingly.
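A sketch of that threshold idea, reusing the hypothetical propose_for_request and pool from the earlier sketch (the cutoff value is a placeholder to be tuned by measurement):

MIN_REQS_FOR_THREADING = 4  # hypothetical cutoff, to be determined empirically

def propose_all(requests, pool):
    if len(requests) < MIN_REQS_FOR_THREADING:
        # Small batch: run serially and skip the thread-dispatch overhead.
        return [propose_for_request(r) for r in requests]
    return list(pool.map(propose_for_request, requests))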

Contributor Author

@ekagra-ranjan ekagra-ranjan Sep 18, 2025

@benchislett - I have updated the thread count to follow what you did earlier. Please have a look. I am still curious about the questions above, so I'm looking forward to hearing your thoughts.

Collaborator

Curious how you came up with the number 8?

I just measured the perf for this use case and found that performance stagnated after 4 threads and was never any better with more than 8 threads; presumably the sharing and synchronization overhead was too large beyond that. I even used batched messages to reduce the overhead when issuing tasks to the threads.

The frontend and the backend would be sharing the same thread pool, right? Would being on backend make the max thread greater than, lower than or equal to 8?

I just mean that here, each TP worker spawns its own thread pool. This is not the case for structured outputs, since that pool is owned by the scheduler process.

Collaborator

@ekagra-ranjan my main issue here is that this seems like an inefficient usage of system resources.

As I understand, the justification for running the (serial) drafting on each TP worker is that it will take about the same amount of time to run it on one worker compared to running it on all workers. Extending this logic, I can see how you might add a couple parallel threads per worker to accelerate. But at this stage, it becomes a better use of system resources to have only the driver worker do the ngram drafting, and have that worker spin up tp_size more threads. That way, the batch is effectively split up by a greater factor while using the same amount of CPU power.

Does this match your understanding of the system, or am I misunderstanding how you are orchestrating the threads?

Contributor Author

@ekagra-ranjan ekagra-ranjan Sep 19, 2025

I understand what you are suggesting, and I agree that TP-level parallelization will increase the speedup. There are two axes of parallelization over the batch: threads and TP.

Let's say TP is 4. Before this PR, each TP worker was doing the same CPU work of sequentially processing the batch on its own thread, so the total number of threads used in the system for ngram drafting was 4.

This PR only implements the thread-level parallelization. It sets the total maximum number of threads across the system (in this case 8) to be allocated for ngram drafting and divides it among the TP workers, so each worker gets 2 threads. The total number of threads used is still 8, but each worker now gets to use 2 threads, so the speedup in this case is 2x.

Once TP parallelization is implemented, the batch will additionally be split across TP workers, and in this example the speedup will increase to 8x.

The current setup still wastes CPU cycles by running the same work on every TP rank instead of splitting it across them; this PR does not change that. I can do the TP-level parallelization in another PR. Alternatively, I can reduce the cap from 8 to 4 now and bring it back up to 8 in the follow-up TP PR.

Wdyt?
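To make the arithmetic concrete (illustrative numbers only, matching the example above):

MAX_NGRAM_THREADS = 8                              # system-wide cap
tp_size = 4
threads_per_worker = MAX_NGRAM_THREADS // tp_size  # -> 2
# Today: every TP worker runs the full batch with 2 threads -> ~2x speedup.
# With the batch also sharded across the 4 TP workers, the same 8 threads
# would each see a quarter of the work -> ~8x speedup.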

Contributor Author

@benchislett - I have disabled the threading in this PR and will enable it in the next PR together with TP parallelization. The goal of this PR now is to land the code changes needed before fully unlocking ngram parallelization with TP. Please review/approve.

ekagra-ranjan and others added 2 commits September 17, 2025 11:34
Co-authored-by: Nick Hill <[email protected]>
Signed-off-by: Ekagra Ranjan <[email protected]>
Signed-off-by: Ekagra Ranjan <[email protected]>
Signed-off-by: Ekagra Ranjan <[email protected]>
Signed-off-by: Ekagra Ranjan <[email protected]>
@ekagra-ranjan ekagra-ranjan changed the title [Spec Decode] Add Batch Parallel Ngram. Upto 30x lower overhead. [Spec Decode] Add Batch Parallel Ngram. Upto 8 lower overhead. Sep 18, 2025
@ekagra-ranjan ekagra-ranjan changed the title [Spec Decode] Add Batch Parallel Ngram. Upto 8 lower overhead. [Spec Decode] Add Batch Parallel Ngram. Upto 8x lower overhead. Sep 18, 2025
@WoosukKwon WoosukKwon added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 23, 2025
Collaborator

@benchislett benchislett left a comment

LGTM. I still have concerns about how we handle TP parallelism, but this is an okay baseline. Is it expected to provide any speedup, now that parallelism is disabled?

@ekagra-ranjan
Contributor Author

I still have concerns about how we handle TP parallelism

This PR doesn't handle it; it will come in the next PR.

Is it expected to provide any speedup, now that parallelism is disabled?

No, it won't. I will re-enable it in the subsequent PR, which will also add TP parallelization. My understanding is that you are concerned about spinning up extra threads as TP grows, so I will enable threading once TP is handled in the next PR.

@ywang96 ywang96 merged commit e71b8e2 into vllm-project:main Sep 25, 2025
51 checks passed
@taohui
Contributor

taohui commented Sep 26, 2025

The docstring in ngram_proposer.py has a formatting issue: the num_tokens_no_spec argument description is mis-indented, causing griffe to warn about a missing name: description pair.
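For reference, a sketch of the indentation griffe expects for a Google-style name: description pair (illustrative docstring only, not the actual one from ngram_proposer.py):

def propose(num_tokens_no_spec):
    """Illustrative example only.

    Args:
        num_tokens_no_spec: Number of tokens per request excluding
            speculative tokens; continuation lines are indented one level
            deeper than the argument name so the pair parses correctly.
    """
    return num_tokens_no_spec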

yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
Signed-off-by: Ekagra Ranjan <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Signed-off-by: yewentao256 <[email protected]>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025

Labels

performance Performance-related issues ready ONLY add when PR is ready to merge/full CI is needed speculative-decoding v1
