From 4b82acc82725ef416aadd9782db74c21b416232e Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Tue, 16 Sep 2025 18:54:48 +0000 Subject: [PATCH 01/19] bench parallel ngram Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- benchmarks/benchmark_ngram_proposer.py | 104 ++++++++++++++++++++++++- tests/v1/spec_decode/test_ngram.py | 18 +++-- vllm/v1/sample/rejection_sampler.py | 2 +- vllm/v1/spec_decode/ngram_proposer.py | 101 ++++++++++++++++++++++-- vllm/v1/worker/gpu_model_runner.py | 26 +++---- 5 files changed, 220 insertions(+), 31 deletions(-) diff --git a/benchmarks/benchmark_ngram_proposer.py b/benchmarks/benchmark_ngram_proposer.py index 11833fa1b3c8..bdcbcaf3f794 100644 --- a/benchmarks/benchmark_ngram_proposer.py +++ b/benchmarks/benchmark_ngram_proposer.py @@ -1,17 +1,31 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import gc +import time +from unittest import mock import numpy as np from tabulate import tabulate from benchmark_utils import TimeCollector -from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoadConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + SpeculativeConfig, + VllmConfig, +) +from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.worker.gpu_input_batch import InputBatch +from vllm.v1.worker.gpu_model_runner import GPUModelRunner -def main(args): +def benchmark_propose(args): rows = [] for max_ngram in args.max_ngram: collector = TimeCollector(TimeCollector.US) @@ -69,10 +83,83 @@ def main(args): ) +def benchmark_batched_propose(args): + NUM_SPECULATIVE_TOKENS_NGRAM = 10 + PROMPT_LOOKUP_MIN = 5 + PROMPT_LOOKUP_MAX = 15 + MAX_MODEL_LEN = int(1e7) + DEVICE = current_platform.device_type + + model_config = ModelConfig(model="facebook/opt-125m", runner="generate") + + speculative_config = SpeculativeConfig( + target_model_config=model_config, + target_parallel_config=ParallelConfig(), + method="ngram", + num_speculative_tokens=NUM_SPECULATIVE_TOKENS_NGRAM, + prompt_lookup_max=PROMPT_LOOKUP_MAX, + prompt_lookup_min=PROMPT_LOOKUP_MIN, + ) + + vllm_config = VllmConfig( + model_config=model_config, + cache_config=CacheConfig(), + speculative_config=speculative_config, + device_config=DeviceConfig(device=current_platform.device_type), + parallel_config=ParallelConfig(), + load_config=LoadConfig(), + scheduler_config=SchedulerConfig(), + ) + + # monkey patch vllm.v1.worker.gpu_model_runner.get_pp_group + mock_pp_group = mock.MagicMock() + mock_pp_group.world_size = 1 + with mock.patch( + "vllm.v1.worker.gpu_model_runner.get_pp_group", return_value=mock_pp_group + ): + runner = GPUModelRunner(vllm_config, DEVICE) + + # hack max model len + runner.max_model_len = MAX_MODEL_LEN + runner.drafter_ngram.max_model_len = MAX_MODEL_LEN + + dummy_input_batch = InputBatch( + max_num_reqs=args.num_req, + max_model_len=MAX_MODEL_LEN, + max_num_batched_tokens=args.num_req * args.num_token, + device=DEVICE, + pin_memory=False, + vocab_size=256000, + block_sizes=[16], + ) + dummy_input_batch._req_ids = list(str(id) for id in range(args.num_req)) + dummy_input_batch.spec_decode_unsupported_reqs = [] + dummy_input_batch.num_tokens_no_spec = [args.num_token] * args.num_req + dummy_input_batch.token_ids_cpu = np.random.randint( + 0, 20, (args.num_req, 
args.num_token) + ) + + runner.input_batch = dummy_input_batch + + sampled_token_ids = [[0]] * args.num_req + + print("Starting benchmark") + # first run is warmup so ignore it + for _ in range(args.num_iteration): + start = time.time() + # runner.propose_ngram_draft_token_ids(sampled_token_ids) + runner.propose_ngram_draft_token_ids_numba(sampled_token_ids) + end = time.time() + print(f"Iteration time (s): {end - start}") + + def invoke_main() -> None: parser = FlexibleArgumentParser( description="Benchmark the performance of N-gram speculative decode drafting" ) + parser.add_argument( + "--batched", action="store_true", help="consider time to prepare batch" + ) # noqa: E501 parser.add_argument( "--num-iteration", type=int, @@ -105,8 +192,17 @@ def invoke_main() -> None: help="Number of speculative tokens to generate", ) args = parser.parse_args() - main(args) + + if not args.batched: + benchmark_propose(args) + else: + benchmark_batched_propose(args) +""" +# Example command lines: +# time python3 benchmarks/benchmark_ngram_proposer.py +# time python3 benchmarks/benchmark_ngram_proposer.py --batched --num-iteration 4 --num-token 1000000 --num-req 128 +""" # noqa: E501 if __name__ == "__main__": - invoke_main() # pragma: no cover + invoke_main() # pragma: no cover \ No newline at end of file diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index 4193f4041b32..d11be802ac74 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -9,11 +9,13 @@ def test_find_longest_matched_ngram_and_propose_tokens(): tokens = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6]) - assert _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens, - min_ngram=2, - max_ngram=2, - max_model_len=1024, - k=2) is None + result = _find_longest_matched_ngram_and_propose_tokens( + origin_tokens=tokens, + min_ngram=2, + max_ngram=2, + max_model_len=1024, + k=2) + assert result is None or len(result) == 0 tokens = np.array([1, 2, 3, 4, 1, 2, 3]) np.testing.assert_array_equal( @@ -78,13 +80,13 @@ def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: result = ngram_proposer( min_n=2, max_n=2, k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 5])) - assert result is None + assert result is None or len(result) == 0 # No match for 4-gram. result = ngram_proposer( min_n=4, max_n=4, k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3])) - assert result is None + assert result is None or len(result) == 0 # No match for 4-gram but match for 3-gram. result = ngram_proposer( @@ -107,4 +109,4 @@ def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: result = ngram_proposer( min_n=3, max_n=3, k=2).propose(context_token_ids=np.array( [1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3])) - assert np.array_equal(result, np.array([100, 1])) + assert np.array_equal(result, np.array([100, 1])) \ No newline at end of file diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 3d5e59addfcf..8db56ef174e8 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -17,7 +17,7 @@ GREEDY_TEMPERATURE: tl.constexpr = -1 # Maximum number of speculative draft tokens allowed per request in a single # step. This value is chosen to be large enough to handle typical use cases. 
-MAX_SPEC_LEN = 32 +MAX_SPEC_LEN = 128 class RejectionSampler(nn.Module): diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index b92e396d4536..53354df80967 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -1,11 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os from typing import Optional import numpy as np -from numba import jit +from numba import get_num_threads, jit, njit, prange, set_num_threads from vllm.config import VllmConfig +from vllm.logger import init_logger + +logger = init_logger(__name__) class NgramProposer: @@ -22,14 +26,78 @@ def __init__(self, vllm_config: VllmConfig): # Number of tokens follow the match. If there are less than k # tokens follow the match, we will return the maximum amount of # tokens until the end. - self.k = vllm_config.speculative_config.num_speculative_tokens + self.method = vllm_config.speculative_config.method + if self.method == "ngram-eagle": + self.k = vllm_config.speculative_config \ + .num_speculative_tokens_per_method["ngram"] + else: + self.k = vllm_config.speculative_config.num_speculative_tokens # Maximum length of the model. self.max_model_len = vllm_config.model_config.max_model_len + # Pre-allocate buffers for numba batch propose. + max_num_seqs = vllm_config.scheduler_config.max_num_seqs + self.valid_ngram_draft = np.zeros((max_num_seqs, self.k), + dtype=np.int32) + self.valid_ngram_num_drafts = np.zeros((max_num_seqs), dtype=np.int32) + + # Max number of threads for numba parallel processing. + tp_size = vllm_config.parallel_config.tensor_parallel_size + # Divide by 2 to use physical cores + # and not logical cores (hyper-threading). + # Divide by tp_size to ensure each tensor parallel rank + # has some threads since all ranks will run this. + # Ensure at least 1 thread is used. + self.num_numba_thread = max(1, (os.cpu_count() // 2) // tp_size) + # Trigger Numba JIT compilation for N-gram proposer. # This usually takes less than 1 second. self.propose(np.zeros(1024, dtype=np.int32)) + def batch_propose( + self, + num_requests: int, + valid_ngram_requests: list, + num_tokens_no_spec: np.ndarray, + token_ids_cpu: np.ndarray, + ) -> list[list[int]]: + """Batch version of ngram proposer using numba for acceleration. + + Args: + valid_ngram_requests: + Set of indices of requests that need ngram proposals. + num_tokens_no_spec: + Numpy array of shape (batch_size,) representing the number + of tokens without speculative tokens for each request. + token_ids_cpu: + Numpy array of shape (batch_size, max_model_len) + representing the token IDs for each request. + + Returns: + list[list[int]]: + A list where each element is a list of proposed + token IDs for the corresponding request. 
+ """ + original_num_numba_threads = get_num_threads() + set_num_threads(min(self.num_numba_thread, len(valid_ngram_requests))) + + draft_token_ids: list[list[int]] = [] + batch_propose_numba(valid_ngram_requests, num_tokens_no_spec, + token_ids_cpu, self.min_n, self.max_n, + self.max_model_len, self.k, self.valid_ngram_draft, + self.valid_ngram_num_drafts) + + for i in range(num_requests): + if i in valid_ngram_requests and \ + self.valid_ngram_num_drafts[i] > 0: + draft_token_ids.append(self.valid_ngram_draft[ + i, :self.valid_ngram_num_drafts[i]].tolist()) + else: + draft_token_ids.append([]) + + set_num_threads(original_num_numba_threads) + return draft_token_ids + def propose( self, context_token_ids: np.ndarray, @@ -71,6 +139,29 @@ def load_model(self, *args, **kwargs): pass +@njit(parallel=True) +def batch_propose_numba(valid_ngram_requests: list, + num_tokens_no_spec: np.ndarray, + token_ids_cpu: np.ndarray, min_n: int, max_n: int, + max_model_len: int, k: int, + valid_ngram_draft: np.ndarray, + valid_ngram_num_drafts: np.ndarray): + for i in prange(len(valid_ngram_requests)): + idx = valid_ngram_requests[i] + num_tokens = num_tokens_no_spec[idx] + context_token_ids = token_ids_cpu[idx, :num_tokens] + drafter_output = _find_longest_matched_ngram_and_propose_tokens( + origin_tokens=context_token_ids, + min_ngram=min_n, + max_ngram=max_n, + max_model_len=max_model_len, + k=k) + + valid_ngram_num_drafts[i] = drafter_output.shape[0] + if drafter_output is not None: + valid_ngram_draft[i, :drafter_output.shape[0]] = drafter_output + + @jit(nopython=True) def _find_longest_matched_ngram_and_propose_tokens( origin_tokens: np.ndarray, min_ngram: int, max_ngram: int, @@ -84,12 +175,12 @@ def _find_longest_matched_ngram_and_propose_tokens( # Do not generate draft tokens is context is shorter than minimum n-gram total_token = origin_tokens.shape[0] if total_token < min_ngram: - return None + return np.empty((0, ), dtype=origin_tokens.dtype) # Do not generate draft tokens beyond the max model length. k = min(k, max_model_len - total_token) if k <= 0: - return None + return np.empty((0, ), dtype=origin_tokens.dtype) # Flip tokens, and the goal become to find longest ngram # on the rightmost position which matches the prefix with @@ -146,7 +237,7 @@ def _find_longest_matched_ngram_and_propose_tokens( if longest_ngram < min_ngram: # No valid ngram is found - return None + return np.empty((0, ), dtype=origin_tokens.dtype) # Flip the position back, so in origin_tokens, # origin_tokens[total_token-1-position:total_token-1-position+longest_ngram] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2ae748dee43c..27be657e30de 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2403,35 +2403,35 @@ def propose_ngram_draft_token_ids( self, sampled_token_ids: list[list[int]], ) -> list[list[int]]: - # TODO(woosuk): Optimize. - req_ids = self.input_batch.req_ids - draft_token_ids: list[list[int]] = [] + + # find which requests need ngram proposals + valid_ngram_requests = [] for i, sampled_ids in enumerate(sampled_token_ids): num_sampled_ids = len(sampled_ids) if not num_sampled_ids: # Skip speculative decoding. - draft_token_ids.append([]) continue # Skip requests that require sampling parameters that are not # supported with speculative decoding. 
- req_id = req_ids[i] + req_id = self.input_batch.req_ids[i] if req_id in self.input_batch.spec_decode_unsupported_reqs: - draft_token_ids.append([]) continue num_tokens = self.input_batch.num_tokens_no_spec[i] if num_tokens >= self.max_model_len: # Skip requests that have already reached the max model length. - draft_token_ids.append([]) continue - drafter_output = self.drafter.propose( - self.input_batch.token_ids_cpu[i, :num_tokens]) - if drafter_output is None or len(drafter_output) == 0: - draft_token_ids.append([]) - else: - draft_token_ids.append(drafter_output.tolist()) + valid_ngram_requests.append(i) + + draft_token_ids = self.drafter_ngram.batch_propose( + len(sampled_token_ids), + valid_ngram_requests, + self.input_batch.num_tokens_no_spec, + self.input_batch.token_ids_cpu, + ) + return draft_token_ids def update_config(self, overrides: dict[str, Any]) -> None: From 877148b4a7d687ed8006df7ea398606fe7221616 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Tue, 16 Sep 2025 18:57:03 +0000 Subject: [PATCH 02/19] lint Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- benchmarks/benchmark_ngram_proposer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/benchmark_ngram_proposer.py b/benchmarks/benchmark_ngram_proposer.py index bdcbcaf3f794..816c9d98068a 100644 --- a/benchmarks/benchmark_ngram_proposer.py +++ b/benchmarks/benchmark_ngram_proposer.py @@ -205,4 +205,4 @@ def invoke_main() -> None: # time python3 benchmarks/benchmark_ngram_proposer.py --batched --num-iteration 4 --num-token 1000000 --num-req 128 """ # noqa: E501 if __name__ == "__main__": - invoke_main() # pragma: no cover \ No newline at end of file + invoke_main() # pragma: no cover From 72ed508737cf708175370e88d7e8678493503798 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Tue, 16 Sep 2025 18:57:50 +0000 Subject: [PATCH 03/19] lint Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- tests/v1/spec_decode/test_ngram.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index d11be802ac74..b974e44472da 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -109,4 +109,4 @@ def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: result = ngram_proposer( min_n=3, max_n=3, k=2).propose(context_token_ids=np.array( [1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3])) - assert np.array_equal(result, np.array([100, 1])) \ No newline at end of file + assert np.array_equal(result, np.array([100, 1])) From d16950411cc87f594aaa3849a6c3e70033d4504c Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Tue, 16 Sep 2025 19:36:42 +0000 Subject: [PATCH 04/19] fix thread and empty list error numba Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/spec_decode/ngram_proposer.py | 43 ++++++++++++++++++--------- 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 53354df80967..5c4e38e6f1b8 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -47,8 +47,11 @@ def __init__(self, vllm_config: VllmConfig): # and not logical cores (hyper-threading). 
# Divide by tp_size to ensure each tensor parallel rank # has some threads since all ranks will run this. - # Ensure at least 1 thread is used. - self.num_numba_thread = max(1, (os.cpu_count() // 2) // tp_size) + cpu_count = os.cpu_count() + if cpu_count: + self.num_numba_thread_available = (cpu_count // 2) // tp_size + else: + self.num_numba_thread_available = 1 # Trigger Numba JIT compilation for N-gram proposer. # This usually takes less than 1 second. @@ -78,14 +81,25 @@ def batch_propose( A list where each element is a list of proposed token IDs for the corresponding request. """ - original_num_numba_threads = get_num_threads() - set_num_threads(min(self.num_numba_thread, len(valid_ngram_requests))) - draft_token_ids: list[list[int]] = [] - batch_propose_numba(valid_ngram_requests, num_tokens_no_spec, - token_ids_cpu, self.min_n, self.max_n, - self.max_model_len, self.k, self.valid_ngram_draft, - self.valid_ngram_num_drafts) + + # Only run batch propose if there are requests needing ngram proposals. + # avoid calling numba function with empty list which causes error + # ValueError: cannot compute fingerprint of empty list + if len(valid_ngram_requests): + original_num_numba_threads = get_num_threads() + # Ensure we use at least one thread. + set_num_threads( + max( + 1, + min(self.num_numba_thread_available, + len(valid_ngram_requests)))) + batch_propose_numba(valid_ngram_requests, num_tokens_no_spec, + token_ids_cpu, self.min_n, self.max_n, + self.max_model_len, self.k, + self.valid_ngram_draft, + self.valid_ngram_num_drafts) + set_num_threads(original_num_numba_threads) for i in range(num_requests): if i in valid_ngram_requests and \ @@ -95,7 +109,6 @@ def batch_propose( else: draft_token_ids.append([]) - set_num_threads(original_num_numba_threads) return draft_token_ids def propose( @@ -158,14 +171,16 @@ def batch_propose_numba(valid_ngram_requests: list, k=k) valid_ngram_num_drafts[i] = drafter_output.shape[0] - if drafter_output is not None: + if len(drafter_output): valid_ngram_draft[i, :drafter_output.shape[0]] = drafter_output @jit(nopython=True) -def _find_longest_matched_ngram_and_propose_tokens( - origin_tokens: np.ndarray, min_ngram: int, max_ngram: int, - max_model_len: int, k: int) -> Optional[np.ndarray]: +def _find_longest_matched_ngram_and_propose_tokens(origin_tokens: np.ndarray, + min_ngram: int, + max_ngram: int, + max_model_len: int, + k: int) -> np.ndarray: """ Find the longest n-gram which matches the suffix of the given tokens whose length is within [min_ngram, max_ngram] (inclusive). 
From 5776965acd0c5b2ecf041f05a7c02a2b52f990ef Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Tue, 16 Sep 2025 19:38:26 +0000 Subject: [PATCH 05/19] fix test none Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- tests/v1/spec_decode/test_ngram.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index b974e44472da..481222055c0e 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -15,7 +15,7 @@ def test_find_longest_matched_ngram_and_propose_tokens(): max_ngram=2, max_model_len=1024, k=2) - assert result is None or len(result) == 0 + assert len(result) == 0 tokens = np.array([1, 2, 3, 4, 1, 2, 3]) np.testing.assert_array_equal( @@ -80,13 +80,13 @@ def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: result = ngram_proposer( min_n=2, max_n=2, k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 5])) - assert result is None or len(result) == 0 + assert len(result) == 0 # No match for 4-gram. result = ngram_proposer( min_n=4, max_n=4, k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3])) - assert result is None or len(result) == 0 + assert len(result) == 0 # No match for 4-gram but match for 3-gram. result = ngram_proposer( From b8d70c0ec02c12dd2cd359153462aaecec390093 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Tue, 16 Sep 2025 22:00:01 +0000 Subject: [PATCH 06/19] clean and refactor ngram propose Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/spec_decode/ngram_proposer.py | 75 +++++++++++++++------------ vllm/v1/worker/gpu_model_runner.py | 43 +++------------ 2 files changed, 48 insertions(+), 70 deletions(-) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 5c4e38e6f1b8..918adf058cb0 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -55,7 +55,11 @@ def __init__(self, vllm_config: VllmConfig): # Trigger Numba JIT compilation for N-gram proposer. # This usually takes less than 1 second. - self.propose(np.zeros(1024, dtype=np.int32)) + self.propose([[]]*1024, + [""] * 1024, + np.zeros(1024, dtype=np.int32), + np.zeros((1024, self.max_model_len), dtype=np.int32), + set()) def batch_propose( self, @@ -94,11 +98,13 @@ def batch_propose( 1, min(self.num_numba_thread_available, len(valid_ngram_requests)))) + batch_propose_numba(valid_ngram_requests, num_tokens_no_spec, token_ids_cpu, self.min_n, self.max_n, self.max_model_len, self.k, self.valid_ngram_draft, self.valid_ngram_num_drafts) + set_num_threads(original_num_numba_threads) for i in range(num_requests): @@ -113,39 +119,42 @@ def batch_propose( def propose( self, - context_token_ids: np.ndarray, - ) -> Optional[np.ndarray]: - """Proposes the next sequence of tokens based on n-gram pattern - matching in the context. The function finds matches of the last n - tokens in the previous context, and returns k tokens that followed - that match. - - Args: - context_token_ids: Numpy array of token IDs representing the - context sequence. + sampled_token_ids: list[list[int]], + req_ids: list[str], + num_tokens_no_spec: np.ndarray, + token_ids_cpu: np.ndarray, + spec_decode_unsupported_reqs: set, + ) -> list[list[int]]: - Returns: - np.ndarray: The sequence of tokens that followed - the matched n-gram in the context. 
- None: If no matching n-gram pattern is found. - - Example: - If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and - k = 4: - - The last 3 (= max_n) tokens [4,2,3] cannot find a match. - - The last 2 tokens [2,3] will be matched against the previous - 4 tokens [1,2,3,4]. - - Finding a match of [2,3] would return the tokens that - followed that pattern. Here we will return [4,2,3] because - we only have three tokens after the match. - """ - # TODO(woosuk): Optimize this. - return _find_longest_matched_ngram_and_propose_tokens( - origin_tokens=context_token_ids, - min_ngram=self.min_n, - max_ngram=self.max_n, - max_model_len=self.max_model_len, - k=self.k) + # find which requests need ngram proposals + valid_ngram_requests = [] + for i, sampled_ids in enumerate(sampled_token_ids): + num_sampled_ids = len(sampled_ids) + if not num_sampled_ids: + # Skip speculative decoding. + continue + + # Skip requests that require sampling parameters that are not + # supported with speculative decoding. + req_id = req_ids[i] + if req_id in spec_decode_unsupported_reqs: + continue + + num_tokens = num_tokens_no_spec[i] + if num_tokens >= self.max_model_len: + # Skip requests that have already reached the max model length. + continue + + valid_ngram_requests.append(i) + + draft_token_ids = self.batch_propose( + len(sampled_token_ids), + valid_ngram_requests, + num_tokens_no_spec, + token_ids_cpu, + ) + + return draft_token_ids def load_model(self, *args, **kwargs): # No model to load. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 27be657e30de..07dc1db51a6e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2307,8 +2307,12 @@ def propose_draft_token_ids( num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": assert isinstance(self.drafter, NgramProposer) - draft_token_ids = self.propose_ngram_draft_token_ids( - sampled_token_ids) + draft_token_ids = self.drafter.propose( + sampled_token_ids, self.input_batch.req_ids, + self.input_batch.num_tokens_no_spec, + self.input_batch.token_ids_cpu, + self.input_batch.spec_decode_unsupported_reqs + ) elif self.speculative_config.method == "medusa": assert isinstance(self.drafter, MedusaProposer) if sample_hidden_states.shape[0] == len(sampled_token_ids): @@ -2399,41 +2403,6 @@ def propose_draft_token_ids( ) return draft_token_ids - def propose_ngram_draft_token_ids( - self, - sampled_token_ids: list[list[int]], - ) -> list[list[int]]: - - # find which requests need ngram proposals - valid_ngram_requests = [] - for i, sampled_ids in enumerate(sampled_token_ids): - num_sampled_ids = len(sampled_ids) - if not num_sampled_ids: - # Skip speculative decoding. - continue - - # Skip requests that require sampling parameters that are not - # supported with speculative decoding. - req_id = self.input_batch.req_ids[i] - if req_id in self.input_batch.spec_decode_unsupported_reqs: - continue - - num_tokens = self.input_batch.num_tokens_no_spec[i] - if num_tokens >= self.max_model_len: - # Skip requests that have already reached the max model length. 
- continue - - valid_ngram_requests.append(i) - - draft_token_ids = self.drafter_ngram.batch_propose( - len(sampled_token_ids), - valid_ngram_requests, - self.input_batch.num_tokens_no_spec, - self.input_batch.token_ids_cpu, - ) - - return draft_token_ids - def update_config(self, overrides: dict[str, Any]) -> None: allowed_config_names = {"load_config", "model_config"} for config_name, config_overrides in overrides.items(): From 07c68bb769503b6617ca3fb3c12238ba13033330 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Tue, 16 Sep 2025 22:04:10 +0000 Subject: [PATCH 07/19] revert ngram eagle Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/spec_decode/ngram_proposer.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 918adf058cb0..90bad0d2aa6e 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -7,9 +7,6 @@ from numba import get_num_threads, jit, njit, prange, set_num_threads from vllm.config import VllmConfig -from vllm.logger import init_logger - -logger = init_logger(__name__) class NgramProposer: @@ -26,12 +23,7 @@ def __init__(self, vllm_config: VllmConfig): # Number of tokens follow the match. If there are less than k # tokens follow the match, we will return the maximum amount of # tokens until the end. - self.method = vllm_config.speculative_config.method - if self.method == "ngram-eagle": - self.k = vllm_config.speculative_config \ - .num_speculative_tokens_per_method["ngram"] - else: - self.k = vllm_config.speculative_config.num_speculative_tokens + self.k = vllm_config.speculative_config.num_speculative_tokens # Maximum length of the model. self.max_model_len = vllm_config.model_config.max_model_len From 8a79c3de1a748c50ea37b591322b452901904759 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Tue, 16 Sep 2025 23:55:43 +0000 Subject: [PATCH 08/19] add ngram test Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- tests/v1/spec_decode/test_ngram.py | 93 +++++++++++++++++++++++++----- 1 file changed, 80 insertions(+), 13 deletions(-) diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index 481222055c0e..38616efc1985 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -77,36 +77,103 @@ def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: ))) # No match. + token_ids_cpu=np.array([[1, 2, 3, 4, 5]]) result = ngram_proposer( min_n=2, max_n=2, - k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 5])) - assert len(result) == 0 + k=2).propose( + sampled_token_ids=[[0]], + req_ids=["0"], + num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert len(result[0]) == 0 # No match for 4-gram. + token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]]) result = ngram_proposer( min_n=4, max_n=4, - k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3])) - assert len(result) == 0 + k=2).propose( + sampled_token_ids=[[0]], + req_ids=["0"], + num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert len(result[0]) == 0 # No match for 4-gram but match for 3-gram. 
+ token_ids_cpu=np.array([[1, 2, 3, 4, 1, 2, 3]]) result = ngram_proposer( min_n=3, max_n=4, - k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3])) - assert np.array_equal(result, np.array([4, 1])) + k=2).propose( + sampled_token_ids=[[0]], + req_ids=["0"], + num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert np.array_equal(result, np.array([[4, 1]])) # Match for both 4-gram and 3-gram. # In this case, the proposer should return the 4-gram match. + token_ids_cpu=np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]]) result = ngram_proposer(min_n=3, max_n=4, k=2).propose( - context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4])) - assert np.array_equal(result, np.array([1, 2])) # Not [5, 1] + sampled_token_ids=[[0]], + req_ids=["0"], + num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert np.array_equal(result, np.array([[1, 2]])) # Not [5, 1]] # Match for 2-gram and 3-gram, but not 4-gram. + token_ids_cpu=np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]]) result = ngram_proposer(min_n=2, max_n=4, k=2).propose( - context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4])) - assert np.array_equal(result, np.array([1, 2])) # Not [5, 2] + sampled_token_ids=[[0]], + req_ids=["0"], + num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert np.array_equal(result, np.array([[1, 2]])) # Not [5, 2]] # Multiple 3-gram matched, but always pick the first one. + token_ids_cpu=np.array([[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]]) + result = ngram_proposer( + min_n=3, max_n=3, k=2).propose( + sampled_token_ids=[[0]], + req_ids=["0"], + num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert np.array_equal(result, np.array([[100, 1]])) + + # check empty input + token_ids_cpu = np.array([[]]) + result = ngram_proposer( + min_n=2, max_n=2, k=2).propose( + sampled_token_ids=[[0]], + req_ids=["0"], + num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert len(result[0]) == 0 + + # check multibatch input + # first request has 5 tokens and a match + # second request has 3 tokens and no match. 
Padded with -1 for max len 5 + token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]]) result = ngram_proposer( - min_n=3, max_n=3, k=2).propose(context_token_ids=np.array( - [1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3])) - assert np.array_equal(result, np.array([100, 1])) + min_n=2, max_n=2, k=2).propose( + sampled_token_ids=[[0], [1]], + req_ids=["0", "1"], + num_tokens_no_spec=np.array([5, 3]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert len(result[0]) == 2 + assert np.array_equal(result[0], np.array([3, 1])) + assert np.array_equal(result[1], np.array([])) \ No newline at end of file From 048ea7c46b9a922d6962c5b88aecc3d3a6b583b4 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Wed, 17 Sep 2025 00:06:14 +0000 Subject: [PATCH 09/19] update bench Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- benchmarks/benchmark_ngram_proposer.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/benchmarks/benchmark_ngram_proposer.py b/benchmarks/benchmark_ngram_proposer.py index 816c9d98068a..9cd6966ba10f 100644 --- a/benchmarks/benchmark_ngram_proposer.py +++ b/benchmarks/benchmark_ngram_proposer.py @@ -121,7 +121,7 @@ def benchmark_batched_propose(args): # hack max model len runner.max_model_len = MAX_MODEL_LEN - runner.drafter_ngram.max_model_len = MAX_MODEL_LEN + runner.drafter.max_model_len = MAX_MODEL_LEN dummy_input_batch = InputBatch( max_num_reqs=args.num_req, @@ -133,7 +133,7 @@ def benchmark_batched_propose(args): block_sizes=[16], ) dummy_input_batch._req_ids = list(str(id) for id in range(args.num_req)) - dummy_input_batch.spec_decode_unsupported_reqs = [] + dummy_input_batch.spec_decode_unsupported_reqs = () dummy_input_batch.num_tokens_no_spec = [args.num_token] * args.num_req dummy_input_batch.token_ids_cpu = np.random.randint( 0, 20, (args.num_req, args.num_token) @@ -147,8 +147,13 @@ def benchmark_batched_propose(args): # first run is warmup so ignore it for _ in range(args.num_iteration): start = time.time() - # runner.propose_ngram_draft_token_ids(sampled_token_ids) - runner.propose_ngram_draft_token_ids_numba(sampled_token_ids) + runner.drafter.propose( + sampled_token_ids, + dummy_input_batch.req_ids, + dummy_input_batch.num_tokens_no_spec, + dummy_input_batch.token_ids_cpu, + dummy_input_batch.spec_decode_unsupported_reqs + ) end = time.time() print(f"Iteration time (s): {end - start}") From aef7f3c77e2f6496f9f959d406ca4e2ce725b41a Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Wed, 17 Sep 2025 00:08:34 +0000 Subject: [PATCH 10/19] lint Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- benchmarks/benchmark_ngram_proposer.py | 2 +- tests/v1/spec_decode/test_ngram.py | 98 ++++++++++++-------------- vllm/v1/spec_decode/ngram_proposer.py | 13 ++-- vllm/v1/worker/gpu_model_runner.py | 3 +- 4 files changed, 52 insertions(+), 64 deletions(-) diff --git a/benchmarks/benchmark_ngram_proposer.py b/benchmarks/benchmark_ngram_proposer.py index 9cd6966ba10f..d4b83edbd940 100644 --- a/benchmarks/benchmark_ngram_proposer.py +++ b/benchmarks/benchmark_ngram_proposer.py @@ -152,7 +152,7 @@ def benchmark_batched_propose(args): dummy_input_batch.req_ids, dummy_input_batch.num_tokens_no_spec, dummy_input_batch.token_ids_cpu, - dummy_input_batch.spec_decode_unsupported_reqs + dummy_input_batch.spec_decode_unsupported_reqs, ) end = time.time() 
print(f"Iteration time (s): {end - start}") diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index 38616efc1985..b56a705e3997 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -77,47 +77,41 @@ def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: ))) # No match. - token_ids_cpu=np.array([[1, 2, 3, 4, 5]]) - result = ngram_proposer( - min_n=2, max_n=2, - k=2).propose( - sampled_token_ids=[[0]], - req_ids=["0"], - num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), - token_ids_cpu=token_ids_cpu, - spec_decode_unsupported_reqs=(), + token_ids_cpu = np.array([[1, 2, 3, 4, 5]]) + result = ngram_proposer(min_n=2, max_n=2, k=2).propose( + sampled_token_ids=[[0]], + req_ids=["0"], + num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), ) assert len(result[0]) == 0 # No match for 4-gram. token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]]) - result = ngram_proposer( - min_n=4, max_n=4, - k=2).propose( - sampled_token_ids=[[0]], - req_ids=["0"], - num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), - token_ids_cpu=token_ids_cpu, - spec_decode_unsupported_reqs=(), - ) + result = ngram_proposer(min_n=4, max_n=4, k=2).propose( + sampled_token_ids=[[0]], + req_ids=["0"], + num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) assert len(result[0]) == 0 # No match for 4-gram but match for 3-gram. - token_ids_cpu=np.array([[1, 2, 3, 4, 1, 2, 3]]) - result = ngram_proposer( - min_n=3, max_n=4, - k=2).propose( - sampled_token_ids=[[0]], - req_ids=["0"], - num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), - token_ids_cpu=token_ids_cpu, - spec_decode_unsupported_reqs=(), - ) + token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]]) + result = ngram_proposer(min_n=3, max_n=4, k=2).propose( + sampled_token_ids=[[0]], + req_ids=["0"], + num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) assert np.array_equal(result, np.array([[4, 1]])) # Match for both 4-gram and 3-gram. # In this case, the proposer should return the 4-gram match. - token_ids_cpu=np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]]) + token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]]) result = ngram_proposer(min_n=3, max_n=4, k=2).propose( sampled_token_ids=[[0]], req_ids=["0"], @@ -128,7 +122,7 @@ def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: assert np.array_equal(result, np.array([[1, 2]])) # Not [5, 1]] # Match for 2-gram and 3-gram, but not 4-gram. - token_ids_cpu=np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]]) + token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]]) result = ngram_proposer(min_n=2, max_n=4, k=2).propose( sampled_token_ids=[[0]], req_ids=["0"], @@ -139,26 +133,25 @@ def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: assert np.array_equal(result, np.array([[1, 2]])) # Not [5, 2]] # Multiple 3-gram matched, but always pick the first one. 
- token_ids_cpu=np.array([[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]]) - result = ngram_proposer( - min_n=3, max_n=3, k=2).propose( - sampled_token_ids=[[0]], - req_ids=["0"], - num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), - token_ids_cpu=token_ids_cpu, - spec_decode_unsupported_reqs=(), + token_ids_cpu = np.array( + [[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]]) + result = ngram_proposer(min_n=3, max_n=3, k=2).propose( + sampled_token_ids=[[0]], + req_ids=["0"], + num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), ) assert np.array_equal(result, np.array([[100, 1]])) # check empty input token_ids_cpu = np.array([[]]) - result = ngram_proposer( - min_n=2, max_n=2, k=2).propose( - sampled_token_ids=[[0]], - req_ids=["0"], - num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), - token_ids_cpu=token_ids_cpu, - spec_decode_unsupported_reqs=(), + result = ngram_proposer(min_n=2, max_n=2, k=2).propose( + sampled_token_ids=[[0]], + req_ids=["0"], + num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), ) assert len(result[0]) == 0 @@ -166,14 +159,13 @@ def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # first request has 5 tokens and a match # second request has 3 tokens and no match. Padded with -1 for max len 5 token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]]) - result = ngram_proposer( - min_n=2, max_n=2, k=2).propose( - sampled_token_ids=[[0], [1]], - req_ids=["0", "1"], - num_tokens_no_spec=np.array([5, 3]), - token_ids_cpu=token_ids_cpu, - spec_decode_unsupported_reqs=(), + result = ngram_proposer(min_n=2, max_n=2, k=2).propose( + sampled_token_ids=[[0], [1]], + req_ids=["0", "1"], + num_tokens_no_spec=np.array([5, 3]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), ) assert len(result[0]) == 2 assert np.array_equal(result[0], np.array([3, 1])) - assert np.array_equal(result[1], np.array([])) \ No newline at end of file + assert np.array_equal(result[1], np.array([])) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 90bad0d2aa6e..de664ffca265 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from typing import Optional import numpy as np from numba import get_num_threads, jit, njit, prange, set_num_threads @@ -47,9 +46,7 @@ def __init__(self, vllm_config: VllmConfig): # Trigger Numba JIT compilation for N-gram proposer. # This usually takes less than 1 second. - self.propose([[]]*1024, - [""] * 1024, - np.zeros(1024, dtype=np.int32), + self.propose([[]] * 1024, [""] * 1024, np.zeros(1024, dtype=np.int32), np.zeros((1024, self.max_model_len), dtype=np.int32), set()) @@ -90,13 +87,13 @@ def batch_propose( 1, min(self.num_numba_thread_available, len(valid_ngram_requests)))) - + batch_propose_numba(valid_ngram_requests, num_tokens_no_spec, token_ids_cpu, self.min_n, self.max_n, self.max_model_len, self.k, self.valid_ngram_draft, self.valid_ngram_num_drafts) - + set_num_threads(original_num_numba_threads) for i in range(num_requests): @@ -136,7 +133,7 @@ def propose( if num_tokens >= self.max_model_len: # Skip requests that have already reached the max model length. 
continue - + valid_ngram_requests.append(i) draft_token_ids = self.batch_propose( @@ -145,7 +142,7 @@ def propose( num_tokens_no_spec, token_ids_cpu, ) - + return draft_token_ids def load_model(self, *args, **kwargs): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 07dc1db51a6e..988968f05113 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2311,8 +2311,7 @@ def propose_draft_token_ids( sampled_token_ids, self.input_batch.req_ids, self.input_batch.num_tokens_no_spec, self.input_batch.token_ids_cpu, - self.input_batch.spec_decode_unsupported_reqs - ) + self.input_batch.spec_decode_unsupported_reqs) elif self.speculative_config.method == "medusa": assert isinstance(self.drafter, MedusaProposer) if sample_hidden_states.shape[0] == len(sampled_token_ids): From 9d30e93ce8c8517222fac0e9a4178d06885ccb1c Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Wed, 17 Sep 2025 11:34:26 -0400 Subject: [PATCH 11/19] Update vllm/v1/spec_decode/ngram_proposer.py Co-authored-by: Nick Hill Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/spec_decode/ngram_proposer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index de664ffca265..b21e5a7d6868 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -79,7 +79,7 @@ def batch_propose( # Only run batch propose if there are requests needing ngram proposals. # avoid calling numba function with empty list which causes error # ValueError: cannot compute fingerprint of empty list - if len(valid_ngram_requests): + if num_ngram_requests := len(valid_ngram_requests): original_num_numba_threads = get_num_threads() # Ensure we use at least one thread. set_num_threads( From 0cd56f2fb154c1bf04908314284a78bcc5a492f2 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Wed, 17 Sep 2025 15:49:37 +0000 Subject: [PATCH 12/19] lint Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/spec_decode/ngram_proposer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index b21e5a7d6868..86f2d0d1fe41 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -83,10 +83,8 @@ def batch_propose( original_num_numba_threads = get_num_threads() # Ensure we use at least one thread. 
set_num_threads( - max( - 1, - min(self.num_numba_thread_available, - len(valid_ngram_requests)))) + max(1, min(self.num_numba_thread_available, + num_ngram_requests))) batch_propose_numba(valid_ngram_requests, num_tokens_no_spec, token_ids_cpu, self.min_n, self.max_n, From b492dd28dff65050ff7fc19c95741b8d9fbf0718 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Thu, 18 Sep 2025 15:47:21 +0000 Subject: [PATCH 13/19] add min workload before multi threading Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/spec_decode/ngram_proposer.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 86f2d0d1fe41..7f09e6a0fc96 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -82,16 +82,23 @@ def batch_propose( if num_ngram_requests := len(valid_ngram_requests): original_num_numba_threads = get_num_threads() # Ensure we use at least one thread. - set_num_threads( - max(1, min(self.num_numba_thread_available, - num_ngram_requests))) + # If total tokens is small, using multiple threads + # may slow down due to overhead. + total_tokens = np.sum(num_tokens_no_spec) + if total_tokens >= 8192: + set_num_threads( + max(1, min(self.num_numba_thread_available, + num_ngram_requests))) + else: + set_num_threads(1) batch_propose_numba(valid_ngram_requests, num_tokens_no_spec, token_ids_cpu, self.min_n, self.max_n, self.max_model_len, self.k, self.valid_ngram_draft, self.valid_ngram_num_drafts) - + + # Restore original number of threads. set_num_threads(original_num_numba_threads) for i in range(num_requests): From 139ec7a7d46e7db0390f73287dad5538d83cfd0c Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Thu, 18 Sep 2025 16:36:08 +0000 Subject: [PATCH 14/19] lint Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/spec_decode/ngram_proposer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 7f09e6a0fc96..dd9ec1f4aaa1 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -86,9 +86,10 @@ def batch_propose( # may slow down due to overhead. total_tokens = np.sum(num_tokens_no_spec) if total_tokens >= 8192: - set_num_threads( - max(1, min(self.num_numba_thread_available, - num_ngram_requests))) + final_num_threads = max( + 1, min(self.num_numba_thread_available, + num_ngram_requests)) + set_num_threads(final_num_threads) else: set_num_threads(1) @@ -97,7 +98,7 @@ def batch_propose( self.max_model_len, self.k, self.valid_ngram_draft, self.valid_ngram_num_drafts) - + # Restore original number of threads. 
set_num_threads(original_num_numba_threads) From 602a959a647a1a06f274948acb9042d838cd56e5 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Thu, 18 Sep 2025 19:12:48 +0000 Subject: [PATCH 15/19] restrict thread Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/spec_decode/ngram_proposer.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index dd9ec1f4aaa1..4ea827179b9e 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -34,13 +34,17 @@ def __init__(self, vllm_config: VllmConfig): # Max number of threads for numba parallel processing. tp_size = vllm_config.parallel_config.tensor_parallel_size - # Divide by 2 to use physical cores - # and not logical cores (hyper-threading). - # Divide by tp_size to ensure each tensor parallel rank - # has some threads since all ranks will run this. cpu_count = os.cpu_count() if cpu_count: - self.num_numba_thread_available = (cpu_count // 2) // tp_size + # Divide by 2 to use physical cores + # and not logical cores (hyper-threading). + # Cap the number of threads to 8 to avoid using too many threads + # since other components like frontend (incl tokenization) + # and Structured Outputs also use multiple threads. + self.num_numba_thread_available = min(8, (cpu_count // 2)) + # Divide by tp_size to ensure each tensor parallel rank + # has some threads since all ranks will run this. + self.num_numba_thread_available //= tp_size else: self.num_numba_thread_available = 1 From de5ee4f05d65405b69be6e3fb71668431f829161 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Mon, 22 Sep 2025 16:02:52 +0000 Subject: [PATCH 16/19] disable improvemnt Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- tests/v1/spec_decode/test_ngram.py | 35 ++++++++++++++++++++------- vllm/v1/spec_decode/ngram_proposer.py | 4 ++- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index b56a705e3997..534bc793bf0f 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -64,7 +64,7 @@ def test_find_longest_matched_ngram_and_propose_tokens(): def test_ngram_proposer(): - def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: + def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # Dummy model config. Just to set max_model_len. model_config = ModelConfig(model="facebook/opt-125m") return NgramProposer( @@ -78,7 +78,7 @@ def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # No match. token_ids_cpu = np.array([[1, 2, 3, 4, 5]]) - result = ngram_proposer(min_n=2, max_n=2, k=2).propose( + result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( sampled_token_ids=[[0]], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), @@ -89,7 +89,7 @@ def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # No match for 4-gram. 
token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]]) - result = ngram_proposer(min_n=4, max_n=4, k=2).propose( + result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose( sampled_token_ids=[[0]], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), @@ -100,7 +100,7 @@ def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # No match for 4-gram but match for 3-gram. token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]]) - result = ngram_proposer(min_n=3, max_n=4, k=2).propose( + result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose( sampled_token_ids=[[0]], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), @@ -112,7 +112,7 @@ def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # Match for both 4-gram and 3-gram. # In this case, the proposer should return the 4-gram match. token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]]) - result = ngram_proposer(min_n=3, max_n=4, k=2).propose( + result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose( sampled_token_ids=[[0]], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), @@ -123,7 +123,7 @@ def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # Match for 2-gram and 3-gram, but not 4-gram. token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]]) - result = ngram_proposer(min_n=2, max_n=4, k=2).propose( + result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose( sampled_token_ids=[[0]], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), @@ -135,7 +135,7 @@ def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # Multiple 3-gram matched, but always pick the first one. token_ids_cpu = np.array( [[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]]) - result = ngram_proposer(min_n=3, max_n=3, k=2).propose( + result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose( sampled_token_ids=[[0]], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), @@ -146,7 +146,7 @@ def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # check empty input token_ids_cpu = np.array([[]]) - result = ngram_proposer(min_n=2, max_n=2, k=2).propose( + result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( sampled_token_ids=[[0]], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), @@ -159,7 +159,7 @@ def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # first request has 5 tokens and a match # second request has 3 tokens and no match. 
Padded with -1 for max len 5 token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]]) - result = ngram_proposer(min_n=2, max_n=2, k=2).propose( + result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( sampled_token_ids=[[0], [1]], req_ids=["0", "1"], num_tokens_no_spec=np.array([5, 3]), @@ -169,3 +169,20 @@ def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: assert len(result[0]) == 2 assert np.array_equal(result[0], np.array([3, 1])) assert np.array_equal(result[1], np.array([])) + + # test 0 thread available: can happen if TP size > CPU count + ngram_proposer = get_ngram_proposer(min_n=2, max_n=2, k=2) + ngram_proposer.num_numba_thread_available = 0 + # using multibatch test + token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]]) + result = ngram_proposer.propose( + sampled_token_ids=[[0], [1]], + req_ids=["0", "1"], + num_tokens_no_spec=np.array([5, 3]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert len(result[0]) == 2 + assert np.array_equal(result[0], np.array([3, 1])) + assert np.array_equal(result[1], np.array([])) + diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 4ea827179b9e..b210373b8bf1 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -41,7 +41,9 @@ def __init__(self, vllm_config: VllmConfig): # Cap the number of threads to 8 to avoid using too many threads # since other components like frontend (incl tokenization) # and Structured Outputs also use multiple threads. - self.num_numba_thread_available = min(8, (cpu_count // 2)) + # TODO(ekagra-ranjan): bump up the cap from 1 to 8 + # when TP parallization for ngram is implemented. + self.num_numba_thread_available = min(1, (cpu_count // 2)) # Divide by tp_size to ensure each tensor parallel rank # has some threads since all ranks will run this. 
self.num_numba_thread_available //= tp_size From f7a2996de5bcd38423c7ebe85990c0c3774b3540 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Mon, 22 Sep 2025 17:06:25 +0000 Subject: [PATCH 17/19] check 0 thread unittest Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- tests/v1/spec_decode/test_ngram.py | 16 ++++++++++++---- vllm/v1/spec_decode/ngram_proposer.py | 3 ++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index 534bc793bf0f..762bc8d58fd0 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -170,19 +170,27 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: assert np.array_equal(result[0], np.array([3, 1])) assert np.array_equal(result[1], np.array([])) - # test 0 thread available: can happen if TP size > CPU count + # test if 0 threads available: can happen if TP size > CPU count ngram_proposer = get_ngram_proposer(min_n=2, max_n=2, k=2) ngram_proposer.num_numba_thread_available = 0 + # set max_model_len to 2 * threshold to ensure multithread is used + num_tokens_threshold = ngram_proposer.num_tokens_threshold + ngram_proposer.max_model_len = 2 * num_tokens_threshold # using multibatch test - token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]]) + middle_integer = num_tokens_threshold // 2 + input_1 = [_ for _ in range(num_tokens_threshold)] + input_1 += [middle_integer, middle_integer+1] + input_2 = [-1] * len(input_1) + input_2[:3] = [4, 5, 6] + token_ids_cpu = np.array([input_1, input_2]) result = ngram_proposer.propose( sampled_token_ids=[[0], [1]], req_ids=["0", "1"], - num_tokens_no_spec=np.array([5, 3]), + num_tokens_no_spec=np.array([len(input_1), 3]), token_ids_cpu=token_ids_cpu, spec_decode_unsupported_reqs=(), ) assert len(result[0]) == 2 - assert np.array_equal(result[0], np.array([3, 1])) + assert np.array_equal(result[0], np.array([middle_integer+2, middle_integer+3])) assert np.array_equal(result[1], np.array([])) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index b210373b8bf1..8f71a0c25b96 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -33,6 +33,7 @@ def __init__(self, vllm_config: VllmConfig): self.valid_ngram_num_drafts = np.zeros((max_num_seqs), dtype=np.int32) # Max number of threads for numba parallel processing. + self.num_tokens_threshold = 8192 tp_size = vllm_config.parallel_config.tensor_parallel_size cpu_count = os.cpu_count() if cpu_count: @@ -91,7 +92,7 @@ def batch_propose( # If total tokens is small, using multiple threads # may slow down due to overhead. 
total_tokens = np.sum(num_tokens_no_spec) - if total_tokens >= 8192: + if total_tokens >= self.num_tokens_threshold: final_num_threads = max( 1, min(self.num_numba_thread_available, num_ngram_requests)) From 35fb3447fc972c0d28f8b8d90c4c5bae92c92df2 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Mon, 22 Sep 2025 17:07:53 +0000 Subject: [PATCH 18/19] comment Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/spec_decode/ngram_proposer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 8f71a0c25b96..511f3ded26a5 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -32,10 +32,12 @@ def __init__(self, vllm_config: VllmConfig): dtype=np.int32) self.valid_ngram_num_drafts = np.zeros((max_num_seqs), dtype=np.int32) - # Max number of threads for numba parallel processing. + # Threshold of total number of tokens in the batch to enable + # multi-threading in numba batch propose. self.num_tokens_threshold = 8192 tp_size = vllm_config.parallel_config.tensor_parallel_size cpu_count = os.cpu_count() + # Max number of threads for numba parallel processing. if cpu_count: # Divide by 2 to use physical cores # and not logical cores (hyper-threading). From d9134bdbc475d27be264b46bec92b217469c5a6a Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Mon, 22 Sep 2025 18:12:09 +0000 Subject: [PATCH 19/19] lint Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- tests/v1/spec_decode/test_ngram.py | 8 ++++---- vllm/v1/spec_decode/ngram_proposer.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index 762bc8d58fd0..344d19c60db7 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -178,8 +178,8 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: ngram_proposer.max_model_len = 2 * num_tokens_threshold # using multibatch test middle_integer = num_tokens_threshold // 2 - input_1 = [_ for _ in range(num_tokens_threshold)] - input_1 += [middle_integer, middle_integer+1] + input_1 = [_ for _ in range(num_tokens_threshold)] + input_1 += [middle_integer, middle_integer + 1] input_2 = [-1] * len(input_1) input_2[:3] = [4, 5, 6] token_ids_cpu = np.array([input_1, input_2]) @@ -191,6 +191,6 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: spec_decode_unsupported_reqs=(), ) assert len(result[0]) == 2 - assert np.array_equal(result[0], np.array([middle_integer+2, middle_integer+3])) + assert np.array_equal(result[0], + np.array([middle_integer + 2, middle_integer + 3])) assert np.array_equal(result[1], np.array([])) - diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 511f3ded26a5..fd8e0a6fd1d2 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -45,7 +45,7 @@ def __init__(self, vllm_config: VllmConfig): # since other components like frontend (incl tokenization) # and Structured Outputs also use multiple threads. # TODO(ekagra-ranjan): bump up the cap from 1 to 8 - # when TP parallization for ngram is implemented. + # when TP parallelization for ngram is implemented. 
self.num_numba_thread_available = min(1, (cpu_count // 2)) # Divide by tp_size to ensure each tensor parallel rank # has some threads since all ranks will run this.
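
For reference, the matching rule that this series moves into the numba-parallel batch_propose_numba path can be sketched in plain NumPy as below. This is a simplified reference, not the vLLM implementation: it omits the max_model_len cap and the Z-algorithm-style scan used in _find_longest_matched_ngram_and_propose_tokens, and the names propose_ngram_reference / batch_propose_reference are illustrative only. It reproduces the behavior exercised by the tests above (longest n-gram wins; among equal-length matches the earliest occurrence wins; empty array when nothing matches).

import numpy as np


def propose_ngram_reference(context: np.ndarray, min_n: int, max_n: int,
                            k: int) -> np.ndarray:
    """Find the longest suffix of length in [min_n, max_n] that also occurs
    earlier in `context`, and return up to k tokens that followed the
    earliest such occurrence."""
    total = context.shape[0]
    # Try the longest n-gram first so it takes priority over shorter ones.
    for n in range(max_n, min_n - 1, -1):
        if total < n + 1:
            continue
        suffix = context[total - n:]
        # Scan left-to-right so the first (earliest) match wins.
        for start in range(0, total - n):
            if np.array_equal(context[start:start + n], suffix):
                # Return at most k tokens; fewer if the match is near the end.
                return context[start + n:start + n + k]
    return np.empty((0,), dtype=context.dtype)


def batch_propose_reference(token_ids: np.ndarray, num_tokens: np.ndarray,
                            min_n: int, max_n: int, k: int) -> list[list[int]]:
    """Loop over a padded (num_reqs, max_len) token matrix; this per-request
    loop is the part the patch parallelizes with numba's prange."""
    drafts: list[list[int]] = []
    for i in range(token_ids.shape[0]):
        ctx = token_ids[i, :num_tokens[i]]
        drafts.append(propose_ngram_reference(ctx, min_n, max_n, k).tolist())
    return drafts


if __name__ == "__main__":
    # Request 0 has a repeated 3-gram [1, 2, 3]; request 1 (padded with -1)
    # has no match.
    tokens = np.array([
        [1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 0, 0],
        [4, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
    ])
    lengths = np.array([11, 3])
    print(batch_propose_reference(tokens, lengths, min_n=3, max_n=3, k=2))
    # -> [[100, 1], []]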