diff --git a/benchmarks/benchmark_ngram_proposer.py b/benchmarks/benchmark_ngram_proposer.py index 11833fa1b3c8..d4b83edbd940 100644 --- a/benchmarks/benchmark_ngram_proposer.py +++ b/benchmarks/benchmark_ngram_proposer.py @@ -1,17 +1,31 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import gc +import time +from unittest import mock import numpy as np from tabulate import tabulate from benchmark_utils import TimeCollector -from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoadConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + SpeculativeConfig, + VllmConfig, +) +from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.worker.gpu_input_batch import InputBatch +from vllm.v1.worker.gpu_model_runner import GPUModelRunner -def main(args): +def benchmark_propose(args): rows = [] for max_ngram in args.max_ngram: collector = TimeCollector(TimeCollector.US) @@ -69,10 +83,88 @@ def main(args): ) +def benchmark_batched_propose(args): + NUM_SPECULATIVE_TOKENS_NGRAM = 10 + PROMPT_LOOKUP_MIN = 5 + PROMPT_LOOKUP_MAX = 15 + MAX_MODEL_LEN = int(1e7) + DEVICE = current_platform.device_type + + model_config = ModelConfig(model="facebook/opt-125m", runner="generate") + + speculative_config = SpeculativeConfig( + target_model_config=model_config, + target_parallel_config=ParallelConfig(), + method="ngram", + num_speculative_tokens=NUM_SPECULATIVE_TOKENS_NGRAM, + prompt_lookup_max=PROMPT_LOOKUP_MAX, + prompt_lookup_min=PROMPT_LOOKUP_MIN, + ) + + vllm_config = VllmConfig( + model_config=model_config, + cache_config=CacheConfig(), + speculative_config=speculative_config, + device_config=DeviceConfig(device=current_platform.device_type), + parallel_config=ParallelConfig(), + load_config=LoadConfig(), + scheduler_config=SchedulerConfig(), + ) + + # monkey patch vllm.v1.worker.gpu_model_runner.get_pp_group + mock_pp_group = mock.MagicMock() + mock_pp_group.world_size = 1 + with mock.patch( + "vllm.v1.worker.gpu_model_runner.get_pp_group", return_value=mock_pp_group + ): + runner = GPUModelRunner(vllm_config, DEVICE) + + # hack max model len + runner.max_model_len = MAX_MODEL_LEN + runner.drafter.max_model_len = MAX_MODEL_LEN + + dummy_input_batch = InputBatch( + max_num_reqs=args.num_req, + max_model_len=MAX_MODEL_LEN, + max_num_batched_tokens=args.num_req * args.num_token, + device=DEVICE, + pin_memory=False, + vocab_size=256000, + block_sizes=[16], + ) + dummy_input_batch._req_ids = list(str(id) for id in range(args.num_req)) + dummy_input_batch.spec_decode_unsupported_reqs = () + dummy_input_batch.num_tokens_no_spec = [args.num_token] * args.num_req + dummy_input_batch.token_ids_cpu = np.random.randint( + 0, 20, (args.num_req, args.num_token) + ) + + runner.input_batch = dummy_input_batch + + sampled_token_ids = [[0]] * args.num_req + + print("Starting benchmark") + # first run is warmup so ignore it + for _ in range(args.num_iteration): + start = time.time() + runner.drafter.propose( + sampled_token_ids, + dummy_input_batch.req_ids, + dummy_input_batch.num_tokens_no_spec, + dummy_input_batch.token_ids_cpu, + dummy_input_batch.spec_decode_unsupported_reqs, + ) + end = time.time() + print(f"Iteration time (s): {end - start}") + + def invoke_main() -> None: parser = FlexibleArgumentParser( description="Benchmark the performance of N-gram speculative decode 
drafting" ) + parser.add_argument( + "--batched", action="store_true", help="consider time to prepare batch" + ) # noqa: E501 parser.add_argument( "--num-iteration", type=int, @@ -105,8 +197,17 @@ def invoke_main() -> None: help="Number of speculative tokens to generate", ) args = parser.parse_args() - main(args) + + if not args.batched: + benchmark_propose(args) + else: + benchmark_batched_propose(args) +""" +# Example command lines: +# time python3 benchmarks/benchmark_ngram_proposer.py +# time python3 benchmarks/benchmark_ngram_proposer.py --batched --num-iteration 4 --num-token 1000000 --num-req 128 +""" # noqa: E501 if __name__ == "__main__": invoke_main() # pragma: no cover diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index 4193f4041b32..344d19c60db7 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -9,11 +9,13 @@ def test_find_longest_matched_ngram_and_propose_tokens(): tokens = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6]) - assert _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens, - min_ngram=2, - max_ngram=2, - max_model_len=1024, - k=2) is None + result = _find_longest_matched_ngram_and_propose_tokens( + origin_tokens=tokens, + min_ngram=2, + max_ngram=2, + max_model_len=1024, + k=2) + assert len(result) == 0 tokens = np.array([1, 2, 3, 4, 1, 2, 3]) np.testing.assert_array_equal( @@ -62,7 +64,7 @@ def test_find_longest_matched_ngram_and_propose_tokens(): def test_ngram_proposer(): - def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: + def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # Dummy model config. Just to set max_model_len. model_config = ModelConfig(model="facebook/opt-125m") return NgramProposer( @@ -75,36 +77,120 @@ def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: ))) # No match. - result = ngram_proposer( - min_n=2, max_n=2, - k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 5])) - assert result is None + token_ids_cpu = np.array([[1, 2, 3, 4, 5]]) + result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( + sampled_token_ids=[[0]], + req_ids=["0"], + num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert len(result[0]) == 0 # No match for 4-gram. - result = ngram_proposer( - min_n=4, max_n=4, - k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3])) - assert result is None + token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]]) + result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose( + sampled_token_ids=[[0]], + req_ids=["0"], + num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert len(result[0]) == 0 # No match for 4-gram but match for 3-gram. - result = ngram_proposer( - min_n=3, max_n=4, - k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3])) - assert np.array_equal(result, np.array([4, 1])) + token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]]) + result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose( + sampled_token_ids=[[0]], + req_ids=["0"], + num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert np.array_equal(result, np.array([[4, 1]])) # Match for both 4-gram and 3-gram. # In this case, the proposer should return the 4-gram match. 
-    result = ngram_proposer(min_n=3, max_n=4, k=2).propose(
-        context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]))
-    assert np.array_equal(result, np.array([1, 2]))  # Not [5, 1]
+    token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]])
+    result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
+        sampled_token_ids=[[0]],
+        req_ids=["0"],
+        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
+        token_ids_cpu=token_ids_cpu,
+        spec_decode_unsupported_reqs=(),
+    )
+    assert np.array_equal(result, np.array([[1, 2]]))  # Not [[5, 1]]
 
     # Match for 2-gram and 3-gram, but not 4-gram.
-    result = ngram_proposer(min_n=2, max_n=4, k=2).propose(
-        context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]))
-    assert np.array_equal(result, np.array([1, 2]))  # Not [5, 2]
+    token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]])
+    result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose(
+        sampled_token_ids=[[0]],
+        req_ids=["0"],
+        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
+        token_ids_cpu=token_ids_cpu,
+        spec_decode_unsupported_reqs=(),
+    )
+    assert np.array_equal(result, np.array([[1, 2]]))  # Not [[5, 2]]
 
     # Multiple 3-gram matched, but always pick the first one.
-    result = ngram_proposer(
-        min_n=3, max_n=3, k=2).propose(context_token_ids=np.array(
-            [1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]))
-    assert np.array_equal(result, np.array([100, 1]))
+    token_ids_cpu = np.array(
+        [[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]])
+    result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose(
+        sampled_token_ids=[[0]],
+        req_ids=["0"],
+        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
+        token_ids_cpu=token_ids_cpu,
+        spec_decode_unsupported_reqs=(),
+    )
+    assert np.array_equal(result, np.array([[100, 1]]))
+
+    # check empty input
+    token_ids_cpu = np.array([[]])
+    result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
+        sampled_token_ids=[[0]],
+        req_ids=["0"],
+        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
+        token_ids_cpu=token_ids_cpu,
+        spec_decode_unsupported_reqs=(),
+    )
+    assert len(result[0]) == 0
+
+    # check multibatch input
+    # first request has 5 tokens and a match
+    # second request has 3 tokens and no match. 
Padded with -1 for max len 5 + token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]]) + result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( + sampled_token_ids=[[0], [1]], + req_ids=["0", "1"], + num_tokens_no_spec=np.array([5, 3]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert len(result[0]) == 2 + assert np.array_equal(result[0], np.array([3, 1])) + assert np.array_equal(result[1], np.array([])) + + # test if 0 threads available: can happen if TP size > CPU count + ngram_proposer = get_ngram_proposer(min_n=2, max_n=2, k=2) + ngram_proposer.num_numba_thread_available = 0 + # set max_model_len to 2 * threshold to ensure multithread is used + num_tokens_threshold = ngram_proposer.num_tokens_threshold + ngram_proposer.max_model_len = 2 * num_tokens_threshold + # using multibatch test + middle_integer = num_tokens_threshold // 2 + input_1 = [_ for _ in range(num_tokens_threshold)] + input_1 += [middle_integer, middle_integer + 1] + input_2 = [-1] * len(input_1) + input_2[:3] = [4, 5, 6] + token_ids_cpu = np.array([input_1, input_2]) + result = ngram_proposer.propose( + sampled_token_ids=[[0], [1]], + req_ids=["0", "1"], + num_tokens_no_spec=np.array([len(input_1), 3]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert len(result[0]) == 2 + assert np.array_equal(result[0], + np.array([middle_integer + 2, middle_integer + 3])) + assert np.array_equal(result[1], np.array([])) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index ced5c7a97038..8f0b38ecb34d 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -17,7 +17,7 @@ GREEDY_TEMPERATURE: tl.constexpr = -1 # Maximum number of speculative draft tokens allowed per request in a single # step. This value is chosen to be large enough to handle typical use cases. -MAX_SPEC_LEN = 32 +MAX_SPEC_LEN = 128 class RejectionSampler(nn.Module): diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index b92e396d4536..fd8e0a6fd1d2 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +import os import numpy as np -from numba import jit +from numba import get_num_threads, jit, njit, prange, set_num_threads from vllm.config import VllmConfig @@ -26,55 +26,174 @@ def __init__(self, vllm_config: VllmConfig): # Maximum length of the model. self.max_model_len = vllm_config.model_config.max_model_len + # Pre-allocate buffers for numba batch propose. + max_num_seqs = vllm_config.scheduler_config.max_num_seqs + self.valid_ngram_draft = np.zeros((max_num_seqs, self.k), + dtype=np.int32) + self.valid_ngram_num_drafts = np.zeros((max_num_seqs), dtype=np.int32) + + # Threshold of total number of tokens in the batch to enable + # multi-threading in numba batch propose. + self.num_tokens_threshold = 8192 + tp_size = vllm_config.parallel_config.tensor_parallel_size + cpu_count = os.cpu_count() + # Max number of threads for numba parallel processing. + if cpu_count: + # Divide by 2 to use physical cores + # and not logical cores (hyper-threading). + # Cap the number of threads to 8 to avoid using too many threads + # since other components like frontend (incl tokenization) + # and Structured Outputs also use multiple threads. 
+ # TODO(ekagra-ranjan): bump up the cap from 1 to 8 + # when TP parallelization for ngram is implemented. + self.num_numba_thread_available = min(1, (cpu_count // 2)) + # Divide by tp_size to ensure each tensor parallel rank + # has some threads since all ranks will run this. + self.num_numba_thread_available //= tp_size + else: + self.num_numba_thread_available = 1 + # Trigger Numba JIT compilation for N-gram proposer. # This usually takes less than 1 second. - self.propose(np.zeros(1024, dtype=np.int32)) + self.propose([[]] * 1024, [""] * 1024, np.zeros(1024, dtype=np.int32), + np.zeros((1024, self.max_model_len), dtype=np.int32), + set()) - def propose( + def batch_propose( self, - context_token_ids: np.ndarray, - ) -> Optional[np.ndarray]: - """Proposes the next sequence of tokens based on n-gram pattern - matching in the context. The function finds matches of the last n - tokens in the previous context, and returns k tokens that followed - that match. + num_requests: int, + valid_ngram_requests: list, + num_tokens_no_spec: np.ndarray, + token_ids_cpu: np.ndarray, + ) -> list[list[int]]: + """Batch version of ngram proposer using numba for acceleration. Args: - context_token_ids: Numpy array of token IDs representing the - context sequence. + valid_ngram_requests: + Set of indices of requests that need ngram proposals. + num_tokens_no_spec: + Numpy array of shape (batch_size,) representing the number + of tokens without speculative tokens for each request. + token_ids_cpu: + Numpy array of shape (batch_size, max_model_len) + representing the token IDs for each request. Returns: - np.ndarray: The sequence of tokens that followed - the matched n-gram in the context. - None: If no matching n-gram pattern is found. - - Example: - If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and - k = 4: - - The last 3 (= max_n) tokens [4,2,3] cannot find a match. - - The last 2 tokens [2,3] will be matched against the previous - 4 tokens [1,2,3,4]. - - Finding a match of [2,3] would return the tokens that - followed that pattern. Here we will return [4,2,3] because - we only have three tokens after the match. + list[list[int]]: + A list where each element is a list of proposed + token IDs for the corresponding request. """ - # TODO(woosuk): Optimize this. - return _find_longest_matched_ngram_and_propose_tokens( - origin_tokens=context_token_ids, - min_ngram=self.min_n, - max_ngram=self.max_n, - max_model_len=self.max_model_len, - k=self.k) + draft_token_ids: list[list[int]] = [] + + # Only run batch propose if there are requests needing ngram proposals. + # avoid calling numba function with empty list which causes error + # ValueError: cannot compute fingerprint of empty list + if num_ngram_requests := len(valid_ngram_requests): + original_num_numba_threads = get_num_threads() + # Ensure we use at least one thread. + # If total tokens is small, using multiple threads + # may slow down due to overhead. + total_tokens = np.sum(num_tokens_no_spec) + if total_tokens >= self.num_tokens_threshold: + final_num_threads = max( + 1, min(self.num_numba_thread_available, + num_ngram_requests)) + set_num_threads(final_num_threads) + else: + set_num_threads(1) + + batch_propose_numba(valid_ngram_requests, num_tokens_no_spec, + token_ids_cpu, self.min_n, self.max_n, + self.max_model_len, self.k, + self.valid_ngram_draft, + self.valid_ngram_num_drafts) + + # Restore original number of threads. 
+ set_num_threads(original_num_numba_threads) + + for i in range(num_requests): + if i in valid_ngram_requests and \ + self.valid_ngram_num_drafts[i] > 0: + draft_token_ids.append(self.valid_ngram_draft[ + i, :self.valid_ngram_num_drafts[i]].tolist()) + else: + draft_token_ids.append([]) + + return draft_token_ids + + def propose( + self, + sampled_token_ids: list[list[int]], + req_ids: list[str], + num_tokens_no_spec: np.ndarray, + token_ids_cpu: np.ndarray, + spec_decode_unsupported_reqs: set, + ) -> list[list[int]]: + + # find which requests need ngram proposals + valid_ngram_requests = [] + for i, sampled_ids in enumerate(sampled_token_ids): + num_sampled_ids = len(sampled_ids) + if not num_sampled_ids: + # Skip speculative decoding. + continue + + # Skip requests that require sampling parameters that are not + # supported with speculative decoding. + req_id = req_ids[i] + if req_id in spec_decode_unsupported_reqs: + continue + + num_tokens = num_tokens_no_spec[i] + if num_tokens >= self.max_model_len: + # Skip requests that have already reached the max model length. + continue + + valid_ngram_requests.append(i) + + draft_token_ids = self.batch_propose( + len(sampled_token_ids), + valid_ngram_requests, + num_tokens_no_spec, + token_ids_cpu, + ) + + return draft_token_ids def load_model(self, *args, **kwargs): # No model to load. pass +@njit(parallel=True) +def batch_propose_numba(valid_ngram_requests: list, + num_tokens_no_spec: np.ndarray, + token_ids_cpu: np.ndarray, min_n: int, max_n: int, + max_model_len: int, k: int, + valid_ngram_draft: np.ndarray, + valid_ngram_num_drafts: np.ndarray): + for i in prange(len(valid_ngram_requests)): + idx = valid_ngram_requests[i] + num_tokens = num_tokens_no_spec[idx] + context_token_ids = token_ids_cpu[idx, :num_tokens] + drafter_output = _find_longest_matched_ngram_and_propose_tokens( + origin_tokens=context_token_ids, + min_ngram=min_n, + max_ngram=max_n, + max_model_len=max_model_len, + k=k) + + valid_ngram_num_drafts[i] = drafter_output.shape[0] + if len(drafter_output): + valid_ngram_draft[i, :drafter_output.shape[0]] = drafter_output + + @jit(nopython=True) -def _find_longest_matched_ngram_and_propose_tokens( - origin_tokens: np.ndarray, min_ngram: int, max_ngram: int, - max_model_len: int, k: int) -> Optional[np.ndarray]: +def _find_longest_matched_ngram_and_propose_tokens(origin_tokens: np.ndarray, + min_ngram: int, + max_ngram: int, + max_model_len: int, + k: int) -> np.ndarray: """ Find the longest n-gram which matches the suffix of the given tokens whose length is within [min_ngram, max_ngram] (inclusive). @@ -84,12 +203,12 @@ def _find_longest_matched_ngram_and_propose_tokens( # Do not generate draft tokens is context is shorter than minimum n-gram total_token = origin_tokens.shape[0] if total_token < min_ngram: - return None + return np.empty((0, ), dtype=origin_tokens.dtype) # Do not generate draft tokens beyond the max model length. 
k = min(k, max_model_len - total_token) if k <= 0: - return None + return np.empty((0, ), dtype=origin_tokens.dtype) # Flip tokens, and the goal become to find longest ngram # on the rightmost position which matches the prefix with @@ -146,7 +265,7 @@ def _find_longest_matched_ngram_and_propose_tokens( if longest_ngram < min_ngram: # No valid ngram is found - return None + return np.empty((0, ), dtype=origin_tokens.dtype) # Flip the position back, so in origin_tokens, # origin_tokens[total_token-1-position:total_token-1-position+longest_ngram] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b0cd0f413307..caeb4003c6e8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2368,8 +2368,11 @@ def propose_draft_token_ids( if self.speculative_config.method == "ngram": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, NgramProposer) - draft_token_ids = self.propose_ngram_draft_token_ids( - sampled_token_ids) + draft_token_ids = self.drafter.propose( + sampled_token_ids, self.input_batch.req_ids, + self.input_batch.num_tokens_no_spec, + self.input_batch.token_ids_cpu, + self.input_batch.spec_decode_unsupported_reqs) elif self.speculative_config.method == "medusa": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, MedusaProposer) @@ -2476,41 +2479,6 @@ def propose_draft_token_ids( ) return draft_token_ids - def propose_ngram_draft_token_ids( - self, - sampled_token_ids: list[list[int]], - ) -> list[list[int]]: - # TODO(woosuk): Optimize. - req_ids = self.input_batch.req_ids - draft_token_ids: list[list[int]] = [] - for i, sampled_ids in enumerate(sampled_token_ids): - num_sampled_ids = len(sampled_ids) - if not num_sampled_ids: - # Skip speculative decoding. - draft_token_ids.append([]) - continue - - # Skip requests that require sampling parameters that are not - # supported with speculative decoding. - req_id = req_ids[i] - if req_id in self.input_batch.spec_decode_unsupported_reqs: - draft_token_ids.append([]) - continue - - num_tokens = self.input_batch.num_tokens_no_spec[i] - if num_tokens >= self.max_model_len: - # Skip requests that have already reached the max model length. - draft_token_ids.append([]) - continue - - drafter_output = self.drafter.propose( - self.input_batch.token_ids_cpu[i, :num_tokens]) - if drafter_output is None or len(drafter_output) == 0: - draft_token_ids.append([]) - else: - draft_token_ids.append(drafter_output.tolist()) - return draft_token_ids - def update_config(self, overrides: dict[str, Any]) -> None: allowed_config_names = {"load_config", "model_config"} for config_name, config_overrides in overrides.items():
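
For readers unfamiliar with the new call shape, the sketch below drives the batched NgramProposer.propose interface directly, mirroring the unit tests above. The SpeculativeConfig keyword arguments are assumed to match the existing get_ngram_proposer test helper; treat this as an illustration of the new signature, not a supported entry point.

    # Minimal sketch (assumed config keywords, mirroring the test helper).
    import numpy as np

    from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
    from vllm.v1.spec_decode.ngram_proposer import NgramProposer

    proposer = NgramProposer(vllm_config=VllmConfig(
        model_config=ModelConfig(model="facebook/opt-125m"),
        speculative_config=SpeculativeConfig(
            method="ngram",
            num_speculative_tokens=2,  # k
            prompt_lookup_min=2,       # min_n
            prompt_lookup_max=3,       # max_n
        ),
    ))

    # One row per request, padded to a common width; num_tokens_no_spec
    # holds the valid prefix length of each row.
    token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3],
                              [7, 8, 9, -1, -1, -1, -1]])
    drafts = proposer.propose(
        sampled_token_ids=[[0], [0]],  # an empty inner list skips a request
        req_ids=["0", "1"],
        num_tokens_no_spec=np.array([7, 3]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    # drafts == [[4, 1], []]: request 0 matches the 3-gram [1, 2, 3] and
    # proposes the two tokens that followed it; request 1 has no match.

In the runner, these arguments come straight from InputBatch, as shown in the gpu_model_runner.py hunk above.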
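
The thread capping in batch_propose relies on numba's runtime thread controls rather than a thread pool of its own. The standalone sketch below shows the same get_num_threads/set_num_threads pattern around a prange kernel; the kernel here is a toy stand-in, not the actual proposer.

    # Standalone sketch of the numba threading pattern used above.
    import numpy as np
    from numba import get_num_threads, njit, prange, set_num_threads

    @njit(parallel=True)
    def row_sums(data, out):
        # Rows are independent, so prange distributes them across threads.
        for i in prange(data.shape[0]):
            out[i] = data[i].sum()

    data = np.random.randint(0, 100, (8, 2048)).astype(np.int64)
    out = np.zeros(8, dtype=np.int64)

    saved = get_num_threads()
    try:
        # For small workloads the threading overhead dominates, so fall
        # back to one thread -- the role num_tokens_threshold plays above.
        set_num_threads(1 if data.size < 8192 else min(saved, data.shape[0]))
        row_sums(data, out)
    finally:
        # set_num_threads changes process-wide state; always restore it.
        set_num_threads(saved)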