tests/v1/e2e/test_spec_decode.py — 28 changes: 21 additions & 7 deletions
@@ -202,9 +202,9 @@ def test_speculators_model_integration(
 
 
 @pytest.mark.parametrize(
-    ["model_setup", "mm_enabled"],
+    ["model_setup", "mm_enabled", "chunked_prefill_enabled"],
     [
-        (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
+        (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False, False),
         pytest.param(
             (
                 "eagle3",
@@ -213,19 +213,22 @@ def test_speculators_model_integration(
                 1,
             ),
             False,
+            False,
             marks=pytest.mark.skip(
                 reason="Skipping due to its head_dim not being a multiple of 32"
             ),
         ),
-        (
+        pytest.param(
             (
                 "eagle",
                 "meta-llama/Llama-3.1-8B-Instruct",
                 "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
                 1,
             ),
             False,
-        ),
+            True,
+            marks=large_gpu_mark(min_gb=40),
+        ),  # works on 4x H100
         (
             (
                 "eagle3",
@@ -234,6 +237,7 @@ def test_speculators_model_integration(
                 1,
             ),
             False,
+            False,
         ),
         pytest.param(
             (
@@ -243,6 +247,7 @@ def test_speculators_model_integration(
                 4,
             ),
             False,
+            False,
             marks=large_gpu_mark(min_gb=80),
         ),  # works on 4x H100
         pytest.param(
@@ -253,6 +258,7 @@ def test_speculators_model_integration(
                 4,
             ),
             True,
+            True,
             marks=large_gpu_mark(min_gb=80),
         ),  # works on 4x H100
         (
@@ -263,6 +269,7 @@ def test_speculators_model_integration(
                 1,
             ),
             False,
+            False,
         ),
     ],
     ids=[
@@ -281,6 +288,7 @@ def test_eagle_correctness(
     sampling_config: SamplingParams,
     model_setup: tuple[str, str, str, int],
     mm_enabled: bool,
+    chunked_prefill_enabled: bool,
     attn_backend: str,
 ):
     if attn_backend == "TREE_ATTN":
@@ -317,9 +325,13 @@ def test_eagle_correctness(
         m.setenv("VLLM_ROCM_USE_AITER", "1")
 
     method, model_name, spec_model_name, tp_size = model_setup
+    max_model_len = 2048
+    max_num_batched_tokens = max_model_len
+    if chunked_prefill_enabled:
+        max_num_batched_tokens = 128
 
     ref_llm = LLM(
-        model=model_name, max_model_len=2048, tensor_parallel_size=tp_size
+        model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size
     )
     ref_outputs = ref_llm.chat(test_prompts, sampling_config)
     del ref_llm
@@ -334,9 +346,11 @@ def test_eagle_correctness(
             "method": method,
             "model": spec_model_name,
             "num_speculative_tokens": 3,
-            "max_model_len": 2048,
+            "max_model_len": max_model_len,
         },
-        max_model_len=2048,
+        max_model_len=max_model_len,
+        max_num_batched_tokens=max_num_batched_tokens,
+        enable_chunked_prefill=chunked_prefill_enabled,
     )
     spec_outputs = spec_llm.chat(test_prompts, sampling_config)
     matches = 0
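Review note: the new `chunked_prefill_enabled` axis caps `max_num_batched_tokens` at 128, well below `max_model_len` (2048), so any prompt longer than 128 tokens must be prefilled across several scheduler steps while EAGLE speculation is active. Below is a minimal sketch of the configuration this test exercises, assuming vLLM's public `LLM`/`SamplingParams` entrypoints; the target/draft model pair is illustrative, copied from the parametrization above, not a recommendation:

```python
# Minimal sketch of the chunked-prefill + EAGLE configuration under test.
# Assumptions: vLLM's public LLM/SamplingParams API; model names are
# illustrative, taken from the test's parametrization.
from vllm import LLM, SamplingParams

chunked_prefill_enabled = True
max_model_len = 2048
# With chunked prefill on, the per-step token budget (128) is far below
# max_model_len, forcing long prompts to be split into multiple prefill chunks.
max_num_batched_tokens = 128 if chunked_prefill_enabled else max_model_len

spec_llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    tensor_parallel_size=1,
    speculative_config={
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
        "num_speculative_tokens": 3,
        "max_model_len": max_model_len,
    },
    max_model_len=max_model_len,
    max_num_batched_tokens=max_num_batched_tokens,
    enable_chunked_prefill=chunked_prefill_enabled,
)

# A prompt longer than 128 tokens would span multiple prefill chunks here.
prompts = [[{"role": "user", "content": "Explain speculative decoding briefly."}]]
outputs = spec_llm.chat(prompts, SamplingParams(temperature=0, max_tokens=64))
print(outputs[0].outputs[0].text)
```

The test then compares these speculative outputs against the reference `ref_llm` outputs generated without a draft model, counting exact matches.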