107 changes: 104 additions & 3 deletions benchmarks/benchmark_ngram_proposer.py
@@ -1,17 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import time
from unittest import mock

import numpy as np
from tabulate import tabulate

from benchmark_utils import TimeCollector
from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
from vllm.config import (
CacheConfig,
DeviceConfig,
LoadConfig,
ModelConfig,
ParallelConfig,
SchedulerConfig,
SpeculativeConfig,
VllmConfig,
)
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner


def main(args):
def benchmark_propose(args):
rows = []
for max_ngram in args.max_ngram:
collector = TimeCollector(TimeCollector.US)
@@ -69,10 +83,88 @@ def main(args):
)


def benchmark_batched_propose(args):
NUM_SPECULATIVE_TOKENS_NGRAM = 10
PROMPT_LOOKUP_MIN = 5
PROMPT_LOOKUP_MAX = 15
MAX_MODEL_LEN = int(1e7)
DEVICE = current_platform.device_type

model_config = ModelConfig(model="facebook/opt-125m", runner="generate")

speculative_config = SpeculativeConfig(
target_model_config=model_config,
target_parallel_config=ParallelConfig(),
method="ngram",
num_speculative_tokens=NUM_SPECULATIVE_TOKENS_NGRAM,
prompt_lookup_max=PROMPT_LOOKUP_MAX,
prompt_lookup_min=PROMPT_LOOKUP_MIN,
)

vllm_config = VllmConfig(
model_config=model_config,
cache_config=CacheConfig(),
speculative_config=speculative_config,
device_config=DeviceConfig(device=current_platform.device_type),
parallel_config=ParallelConfig(),
load_config=LoadConfig(),
scheduler_config=SchedulerConfig(),
)

# Monkey-patch vllm.v1.worker.gpu_model_runner.get_pp_group so the runner sees a single-rank pipeline-parallel group
mock_pp_group = mock.MagicMock()
mock_pp_group.world_size = 1
with mock.patch(
"vllm.v1.worker.gpu_model_runner.get_pp_group", return_value=mock_pp_group
):
runner = GPUModelRunner(vllm_config, DEVICE)

# Hack: override max_model_len so the long synthetic prompts fit
runner.max_model_len = MAX_MODEL_LEN
runner.drafter.max_model_len = MAX_MODEL_LEN

dummy_input_batch = InputBatch(
max_num_reqs=args.num_req,
max_model_len=MAX_MODEL_LEN,
max_num_batched_tokens=args.num_req * args.num_token,
device=DEVICE,
pin_memory=False,
vocab_size=256000,
block_sizes=[16],
)
dummy_input_batch._req_ids = list(str(id) for id in range(args.num_req))
dummy_input_batch.spec_decode_unsupported_reqs = ()
dummy_input_batch.num_tokens_no_spec = [args.num_token] * args.num_req
dummy_input_batch.token_ids_cpu = np.random.randint(
0, 20, (args.num_req, args.num_token)
)

runner.input_batch = dummy_input_batch

sampled_token_ids = [[0]] * args.num_req

print("Starting benchmark")
# The first iteration is a warmup; ignore its timing
for _ in range(args.num_iteration):
start = time.time()
runner.drafter.propose(
sampled_token_ids,
dummy_input_batch.req_ids,
dummy_input_batch.num_tokens_no_spec,
dummy_input_batch.token_ids_cpu,
dummy_input_batch.spec_decode_unsupported_reqs,
)
end = time.time()
print(f"Iteration time (s): {end - start}")


def invoke_main() -> None:
parser = FlexibleArgumentParser(
description="Benchmark the performance of N-gram speculative decode drafting"
)
parser.add_argument(
"--batched", action="store_true", help="consider time to prepare batch"
) # noqa: E501
parser.add_argument(
"--num-iteration",
type=int,
@@ -105,8 +197,17 @@ def invoke_main() -> None:
help="Number of speculative tokens to generate",
)
args = parser.parse_args()
main(args)

if not args.batched:
benchmark_propose(args)
else:
benchmark_batched_propose(args)


"""
# Example command lines:
# time python3 benchmarks/benchmark_ngram_proposer.py
# time python3 benchmarks/benchmark_ngram_proposer.py --batched --num-iteration 4 --num-token 1000000 --num-req 128
""" # noqa: E501
if __name__ == "__main__":
invoke_main() # pragma: no cover
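
For context, here is a minimal sketch of calling the new batched NgramProposer.propose() interface directly, without the GPUModelRunner setup used in benchmark_batched_propose. The argument names mirror this diff; the SpeculativeConfig/VllmConfig wiring is an assumption based on the test file below, not a verbatim vLLM API guarantee.

# Sketch only: a direct call to the batched propose() interface.
# The config construction below is assumed from the updated tests in this PR
# and may need adjustment for other vLLM versions.
import numpy as np

from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
from vllm.v1.spec_decode.ngram_proposer import NgramProposer

model_config = ModelConfig(model="facebook/opt-125m")
proposer = NgramProposer(
    vllm_config=VllmConfig(
        model_config=model_config,
        speculative_config=SpeculativeConfig(
            method="ngram",
            num_speculative_tokens=2,
            prompt_lookup_min=2,
            prompt_lookup_max=4,
        ),
    )
)

token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
drafts = proposer.propose(
    sampled_token_ids=[[0]],
    req_ids=["0"],
    num_tokens_no_spec=np.array([token_ids_cpu.shape[1]]),
    token_ids_cpu=token_ids_cpu,
    spec_decode_unsupported_reqs=(),
)
# drafts[0] holds the draft token ids for request "0" (here: [4, 1]).
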
142 changes: 114 additions & 28 deletions tests/v1/spec_decode/test_ngram.py
@@ -9,11 +9,13 @@

def test_find_longest_matched_ngram_and_propose_tokens():
tokens = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6])
assert _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
min_ngram=2,
max_ngram=2,
max_model_len=1024,
k=2) is None
result = _find_longest_matched_ngram_and_propose_tokens(
origin_tokens=tokens,
min_ngram=2,
max_ngram=2,
max_model_len=1024,
k=2)
assert len(result) == 0

tokens = np.array([1, 2, 3, 4, 1, 2, 3])
np.testing.assert_array_equal(
@@ -62,7 +64,7 @@ def test_find_longest_matched_ngram_and_propose_tokens():

def test_ngram_proposer():

def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
# Dummy model config. Just to set max_model_len.
model_config = ModelConfig(model="facebook/opt-125m")
return NgramProposer(
@@ -75,36 +77,120 @@ def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
)))

# No match.
result = ngram_proposer(
min_n=2, max_n=2,
k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 5]))
assert result is None
token_ids_cpu = np.array([[1, 2, 3, 4, 5]])
result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert len(result[0]) == 0

# No match for 4-gram.
result = ngram_proposer(
min_n=4, max_n=4,
k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
assert result is None
token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose(
sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert len(result[0]) == 0

# No match for 4-gram but match for 3-gram.
result = ngram_proposer(
min_n=3, max_n=4,
k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
assert np.array_equal(result, np.array([4, 1]))
token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert np.array_equal(result, np.array([[4, 1]]))

# Match for both 4-gram and 3-gram.
# In this case, the proposer should return the 4-gram match.
result = ngram_proposer(min_n=3, max_n=4, k=2).propose(
context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]))
assert np.array_equal(result, np.array([1, 2])) # Not [5, 1]
token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]])
result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert np.array_equal(result, np.array([[1, 2]])) # Not [5, 1]

# Match for 2-gram and 3-gram, but not 4-gram.
result = ngram_proposer(min_n=2, max_n=4, k=2).propose(
context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]))
assert np.array_equal(result, np.array([1, 2])) # Not [5, 2]
token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]])
result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose(
sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert np.array_equal(result, np.array([[1, 2]])) # Not [5, 2]

# Multiple 3-gram matched, but always pick the first one.
result = ngram_proposer(
min_n=3, max_n=3, k=2).propose(context_token_ids=np.array(
[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]))
assert np.array_equal(result, np.array([100, 1]))
token_ids_cpu = np.array(
[[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]])
result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose(
sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert np.array_equal(result, np.array([[100, 1]]))

# check empty input
token_ids_cpu = np.array([[]])
result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert len(result[0]) == 0

# check multibatch input
# first request has 5 tokens and a match
# second request has 3 tokens and no match. Padded with -1 for max len 5
token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]])
result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
sampled_token_ids=[[0], [1]],
req_ids=["0", "1"],
num_tokens_no_spec=np.array([5, 3]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert len(result[0]) == 2
assert np.array_equal(result[0], np.array([3, 1]))
assert np.array_equal(result[1], np.array([]))

# test if 0 threads available: can happen if TP size > CPU count
ngram_proposer = get_ngram_proposer(min_n=2, max_n=2, k=2)
ngram_proposer.num_numba_thread_available = 0
# set max_model_len to 2 * threshold to ensure multithread is used
num_tokens_threshold = ngram_proposer.num_tokens_threshold
ngram_proposer.max_model_len = 2 * num_tokens_threshold
# using multibatch test
middle_integer = num_tokens_threshold // 2
input_1 = [_ for _ in range(num_tokens_threshold)]
input_1 += [middle_integer, middle_integer + 1]
input_2 = [-1] * len(input_1)
input_2[:3] = [4, 5, 6]
token_ids_cpu = np.array([input_1, input_2])
result = ngram_proposer.propose(
sampled_token_ids=[[0], [1]],
req_ids=["0", "1"],
num_tokens_no_spec=np.array([len(input_1), 3]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert len(result[0]) == 2
assert np.array_equal(result[0],
np.array([middle_integer + 2, middle_integer + 3]))
assert np.array_equal(result[1], np.array([]))
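
For readers following the assertions above, this is a simplified pure-Python sketch of the prompt-lookup rule the tests encode: prefer the longest suffix n-gram (max_n down to min_n), take its first (leftmost) earlier occurrence, and propose the k tokens that followed it. It illustrates the expected behavior only; it is not vLLM's _find_longest_matched_ngram_and_propose_tokens implementation and it ignores details such as max_model_len clamping and batching.

import numpy as np

def ngram_lookup_sketch(tokens: np.ndarray, min_n: int, max_n: int, k: int) -> np.ndarray:
    """Propose up to k tokens by n-gram prompt lookup (simplified sketch)."""
    total = len(tokens)
    # Prefer the longest matching n-gram.
    for n in range(max_n, min_n - 1, -1):
        if total <= n:
            continue
        suffix = tokens[total - n:]
        # Scan left to right so the first (earliest) occurrence wins.
        for start in range(total - n):
            if np.array_equal(tokens[start:start + n], suffix):
                return tokens[start + n:start + n + k]
    return np.array([], dtype=tokens.dtype)

# Mirrors the "multiple 3-gram matches" case above: the first match wins.
print(ngram_lookup_sketch(
    np.array([1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]),
    min_n=3, max_n=3, k=2))  # -> [100   1]
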
2 changes: 1 addition & 1 deletion vllm/v1/sample/rejection_sampler.py
@@ -17,7 +17,7 @@
GREEDY_TEMPERATURE: tl.constexpr = -1
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 32
MAX_SPEC_LEN = 128


class RejectionSampler(nn.Module):