
Commit ffe5fcc

jmkuebler authored and rtourgeman committed
[Spec decode] automatically disable mm for text-only draft models (vllm-project#25667)
Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
1 parent 522c554 commit ffe5fcc

2 files changed

Lines changed: 83 additions & 67 deletions


tests/v1/e2e/test_spec_decode.py

Lines changed: 69 additions & 67 deletions
@@ -8,7 +8,7 @@
 import pytest
 import torch
 
-from tests.utils import get_attn_backend_list_based_on_platform
+from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
 from vllm import LLM, SamplingParams
 from vllm.assets.base import VLLM_S3_BUCKET_URL
 from vllm.assets.image import VLM_IMAGES_DIR
@@ -88,69 +88,66 @@ def test_ngram_correctness(
     Compare the outputs of an original LLM and a speculative LLM
     should be the same when using ngram speculative decoding.
     '''
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        test_prompts = get_test_prompts(mm_enabled=False)
-
-        ref_llm = LLM(model=model_name, max_model_len=1024)
-        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
-        del ref_llm
-        torch.cuda.empty_cache()
-        cleanup_dist_env_and_memory()
-
-        spec_llm = LLM(
-            model=model_name,
-            speculative_config={
-                "method": "ngram",
-                "prompt_lookup_max": 5,
-                "prompt_lookup_min": 3,
-                "num_speculative_tokens": 3,
-            },
-            max_model_len=1024,
-        )
-        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
-        matches = 0
-        misses = 0
-        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
-            if ref_output.outputs[0].text == spec_output.outputs[0].text:
-                matches += 1
-            else:
-                misses += 1
-                print(f"ref_output: {ref_output.outputs[0].text}")
-                print(f"spec_output: {spec_output.outputs[0].text}")
-
-        # Heuristic: expect at least 66% of the prompts to match exactly
-        # Upon failure, inspect the outputs to check for inaccuracy.
-        assert matches >= int(0.66 * len(ref_outputs))
-        del spec_llm
-        torch.cuda.empty_cache()
-        cleanup_dist_env_and_memory()
-
-
-@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
-    (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
-    (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
-      "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
-    (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
-      "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
-    pytest.param(
-        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-        False,
-        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-    pytest.param(
-        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-        True,
-        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-    (("eagle", "eagle618/deepseek-v3-random",
-      "eagle618/eagle-deepseek-v3-random", 1), False),
-],
-                         ids=[
-                             "qwen3_eagle3", "llama3_eagle", "llama3_eagle3",
-                             "llama4_eagle", "llama4_eagle_mm",
-                             "deepseek_eagle"
-                         ])
+    test_prompts = get_test_prompts(mm_enabled=False)
+
+    ref_llm = LLM(model=model_name, max_model_len=1024)
+    ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+    del ref_llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
+
+    spec_llm = LLM(
+        model=model_name,
+        speculative_config={
+            "method": "ngram",
+            "prompt_lookup_max": 5,
+            "prompt_lookup_min": 3,
+            "num_speculative_tokens": 3,
+        },
+        max_model_len=1024,
+    )
+    spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+    matches = 0
+    misses = 0
+    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+        if ref_output.outputs[0].text == spec_output.outputs[0].text:
+            matches += 1
+        else:
+            misses += 1
+            print(f"ref_output: {ref_output.outputs[0].text}")
+            print(f"spec_output: {spec_output.outputs[0].text}")
+
+    # Heuristic: expect at least 66% of the prompts to match exactly
+    # Upon failure, inspect the outputs to check for inaccuracy.
+    assert matches >= int(0.66 * len(ref_outputs))
+    del spec_llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
+
+
+@pytest.mark.parametrize(
+    ["model_setup", "mm_enabled"],
+    [
+        (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
+        (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+          "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
+        (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
+          "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
+        pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+                      "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+                     False,
+                     marks=large_gpu_mark(min_gb=80)),  # works on 4x H100
+        pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+                      "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+                     True,
+                     marks=large_gpu_mark(min_gb=80)),  # works on 4x H100
+        (("eagle", "eagle618/deepseek-v3-random",
+          "eagle618/eagle-deepseek-v3-random", 1), False),
+    ],
+    ids=[
+        "qwen3_eagle3", "llama3_eagle", "llama3_eagle3", "llama4_eagle",
+        "llama4_eagle_mm", "deepseek_eagle"
+    ])
 @pytest.mark.parametrize("attn_backend",
                          get_attn_backend_list_based_on_platform())
 def test_eagle_correctness(
@@ -174,9 +171,14 @@ def test_eagle_correctness(
         model_setup: (method, model_name, eagle_model_name, tp_size)
     '''
     with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        m.setenv("VLLM_MLA_DISABLE", "1")
-        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+        if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
+            # Scout requires default backend selection
+            # because vision encoder has head_dim 88 being incompatible
+            # with FLASH_ATTN and needs to fall back to Flex Attn
+            pass
+        else:
+            m.setenv("VLLM_MLA_DISABLE", "1")
+            m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
 
         if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()):
             pytest.skip("TRITON_ATTN does not support "
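
With this change the Llama-4-Scout cases run whenever a sufficiently large GPU is available instead of being skipped unconditionally. The diff only shows the call site, large_gpu_mark(min_gb=80); the sketch below is a hypothetical illustration of how such a pytest helper could be built from the detected GPU memory — the actual helper in tests/utils may be implemented differently.

# Hypothetical sketch of a large_gpu_mark-style helper (assumption, not the
# real tests/utils implementation): build a skipif mark from GPU memory.
import pytest
import torch


def large_gpu_mark(min_gb: int) -> pytest.MarkDecorator:
    """Skip a test unless the first visible GPU has at least `min_gb` GiB."""
    if torch.cuda.is_available():
        total_gb = torch.cuda.get_device_properties(0).total_memory / 2**30
    else:
        total_gb = 0.0
    return pytest.mark.skipif(
        total_gb < min_gb,
        reason=f"Need at least {min_gb} GiB of GPU memory, "
               f"found {total_gb:.0f} GiB",
    )


# Usage mirrors the parametrize entries above:
# pytest.param(model_setup, mm_enabled, marks=large_gpu_mark(min_gb=80))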

vllm/v1/spec_decode/eagle.py

Lines changed: 14 additions & 0 deletions
@@ -804,6 +804,20 @@ def load_model(self, target_model: nn.Module) -> None:
 
         self.attn_layer_names = list(draft_attn_layer_names)
 
+        if self.is_multimodal_model:
+            # Even if the target model is multimodal, we can also use
+            # text-only draft models
+            try:
+                dummy_input_ids = torch.tensor([[1]],
+                                               device=self.input_ids.device)
+                self.model.get_input_embeddings(dummy_input_ids,
+                                                multimodal_embeddings=None)
+            except (NotImplementedError, AttributeError, TypeError):
+                logger.warning(
+                    "Draft model does not support multimodal inputs, "
+                    "falling back to text-only mode")
+                self.is_multimodal_model = False
+
         if supports_multimodal(target_model):
             # handle multimodality
             self.model.config.image_token_index = (
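
The added block probes the draft model once at load time: it calls get_input_embeddings with a single dummy token and multimodal_embeddings=None, and if the call is rejected the proposer falls back to text-only mode even when the target model is multimodal. Below is a self-contained sketch of the same probing pattern; the TextOnlyEmbeddings class and supports_mm_embeddings helper are illustrative stand-ins, not vLLM APIs.

# Self-contained sketch of the capability probe used in the diff above.
# TextOnlyEmbeddings and supports_mm_embeddings are hypothetical names.
import logging

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


class TextOnlyEmbeddings(nn.Module):
    """Stand-in for a text-only draft model's embedding interface."""

    def __init__(self, vocab_size: int = 32, hidden: int = 8) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        # No `multimodal_embeddings` keyword: probing with it raises TypeError.
        return self.embed(input_ids)


def supports_mm_embeddings(model: nn.Module, device: torch.device) -> bool:
    """Return False if get_input_embeddings rejects multimodal embeddings."""
    try:
        dummy_input_ids = torch.tensor([[1]], device=device)
        model.get_input_embeddings(dummy_input_ids, multimodal_embeddings=None)
        return True
    except (NotImplementedError, AttributeError, TypeError):
        logger.warning("Draft model does not support multimodal inputs, "
                       "falling back to text-only mode")
        return False


is_multimodal = supports_mm_embeddings(TextOnlyEmbeddings(),
                                        torch.device("cpu"))
print(is_multimodal)  # False: the draft model is treated as text-only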
