Skip to content

Commit 32ac8d7

Browse files
LiuXiaoxuanPKUDefTruth
authored andcommitted
[V1] [Spec Decode] Fix ngram tests (vllm-project#14878)
Signed-off-by: DefTruth <[email protected]>
1 parent 1c59947 commit 32ac8d7

File tree

1 file changed

+29
-24
lines changed

1 file changed

+29
-24
lines changed

tests/v1/spec_decode/test_ngram.py

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,37 @@
11
# SPDX-License-Identifier: Apache-2.0
2-
import pytest
32

4-
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
5-
from vllm.v1.utils import ConstantList
3+
import numpy as np
64

5+
from vllm.v1.spec_decode.ngram_proposer import (_find_subarray_kmp,
6+
_kmp_lps_array)
77

8-
@pytest.fixture
9-
def proposer():
10-
return NgramProposer()
118

9+
def test_kmp_lps_array():
10+
np.testing.assert_array_equal(_kmp_lps_array(np.array([])), np.array([]))
11+
np.testing.assert_array_equal(_kmp_lps_array(np.array([1])), np.array([0]))
12+
np.testing.assert_array_equal(_kmp_lps_array(np.array([1, 1, 1])),
13+
np.array([0, 1, 2]))
14+
np.testing.assert_array_equal(_kmp_lps_array(np.array([1, 2, 3, 4])),
15+
np.array([0, 0, 0, 0]))
16+
np.testing.assert_array_equal(_kmp_lps_array(np.array([1, 2, 1, 2, 3])),
17+
np.array([0, 0, 1, 2, 0]))
1218

13-
def test_kmp_lps_array(proposer):
14-
assert proposer._kmp_lps_array([]) == []
15-
assert proposer._kmp_lps_array([1]) == [0]
16-
assert proposer._kmp_lps_array([1, 1, 1]) == [0, 1, 2]
17-
assert proposer._kmp_lps_array([1, 2, 3, 4]) == [0, 0, 0, 0]
18-
assert proposer._kmp_lps_array([1, 2, 1, 2, 3]) == [0, 0, 1, 2, 0]
1919

20-
21-
def test_find_subarray_kmp(proposer):
22-
X = ConstantList([1, 2, 3, 4, 1, 2, 3, 5, 6])
23-
assert proposer._find_subarray_kmp(X, 2, 2) is None
24-
X = ConstantList([1, 2, 3, 4, 1, 2, 3])
25-
assert proposer._find_subarray_kmp(X, 2, 3) == [4, 1, 2]
26-
assert proposer._find_subarray_kmp(X, 2, 2) == [4, 1]
27-
assert proposer._find_subarray_kmp(X, 1, 3) == [4, 1, 2]
28-
assert proposer._find_subarray_kmp(X, 1, 2) == [4, 1]
29-
X = ConstantList([1, 3, 6, 2, 3, 4, 1, 2, 3])
30-
assert proposer._find_subarray_kmp(X, 2, 3) == [4, 1, 2]
20+
def test_find_subarray_kmp():
21+
X = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6])
22+
assert _find_subarray_kmp(X, 2, 2) is None
23+
X = np.array([1, 2, 3, 4, 1, 2, 3])
24+
np.testing.assert_array_equal(_find_subarray_kmp(X, 2, 3),
25+
np.array([4, 1, 2]))
26+
np.testing.assert_array_equal(_find_subarray_kmp(X, 2, 2), np.array([4,
27+
1]))
28+
np.testing.assert_array_equal(_find_subarray_kmp(X, 1, 3),
29+
np.array([4, 1, 2]))
30+
np.testing.assert_array_equal(_find_subarray_kmp(X, 1, 2), np.array([4,
31+
1]))
32+
X = np.array([1, 3, 6, 2, 3, 4, 1, 2, 3])
33+
np.testing.assert_array_equal(_find_subarray_kmp(X, 2, 3),
34+
np.array([4, 1, 2]))
3135
# Return on the first match
32-
assert proposer._find_subarray_kmp(X, 1, 3) == [6, 2, 3]
36+
np.testing.assert_array_equal(_find_subarray_kmp(X, 1, 3),
37+
np.array([6, 2, 3]))

0 commit comments

Comments
 (0)