Skip to content

Commit 3af88df

Browse files
committed
Add mlp speculator test
1 parent 26124b3 commit 3af88df

File tree

1 file changed

+44
-1
lines changed

1 file changed

+44
-1
lines changed

tests/spec_decode/e2e/test_mlp_correctness.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import pytest
2323

24-
from .conftest import run_greedy_equality_correctness_test
24+
from .conftest import run_greedy_equality_correctness_test, run_equality_correctness_test
2525

2626
# main model
2727
MAIN_MODEL = "JackFram/llama-160m"
@@ -77,6 +77,49 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
7777
force_output_len=True)
7878

7979

80+
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": SPEC_MODEL,
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("temperature", [0.1, 1.0])
# NOTE(review): `seed` is not a parameter of the test function itself; it is
# presumably consumed by the generator fixtures in conftest -- confirm there.
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_seeded_correctness(baseline_llm_generator, test_llm_generator,
                                    batch_size: int, output_len: int,
                                    temperature: float):
    """Verify seeded sampling matches between baseline and speculator runs.

    Runs the shared equality check with ``seeded=True`` at a non-zero
    temperature, comparing the baseline generator against the generator
    configured with the MLP speculative model.

    Args:
        baseline_llm_generator: Fixture for the LLM without the speculator
            (parametrized via ``baseline_llm_kwargs``).
        test_llm_generator: Fixture for the LLM parametrized with
            ``speculative_model`` via ``test_llm_kwargs``.
        batch_size: Number of prompts per batch.
        output_len: Forwarded as ``max_output_len``; ``force_output_len`` is
            set so runs are compared at this length.
        temperature: Sampling temperature forwarded to the equality check.
    """
    # The second generator must be the speculative one; passing the baseline
    # twice would only compare the baseline against itself and never exercise
    # the MLP speculator.
    run_equality_correctness_test(baseline_llm_generator,
                                  test_llm_generator,
                                  batch_size,
                                  max_output_len=output_len,
                                  temperature=temperature,
                                  seeded=True,
                                  force_output_len=True)
121+
122+
80123
@pytest.mark.parametrize(
81124
"common_llm_kwargs",
82125
[{

0 commit comments

Comments
 (0)