 import pytest
 import torch

-from tests.utils import get_attn_backend_list_based_on_platform
+from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
 from vllm import LLM, SamplingParams
 from vllm.assets.base import VLLM_S3_BUCKET_URL
 from vllm.assets.image import VLM_IMAGES_DIR
@@ -88,69 +88,66 @@ def test_ngram_correctness(
     Compare the outputs of an original LLM and a speculative LLM
     should be the same when using ngram speculative decoding.
     '''
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        test_prompts = get_test_prompts(mm_enabled=False)
-
-        ref_llm = LLM(model=model_name, max_model_len=1024)
-        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
-        del ref_llm
-        torch.cuda.empty_cache()
-        cleanup_dist_env_and_memory()
-
-        spec_llm = LLM(
-            model=model_name,
-            speculative_config={
-                "method": "ngram",
-                "prompt_lookup_max": 5,
-                "prompt_lookup_min": 3,
-                "num_speculative_tokens": 3,
-            },
-            max_model_len=1024,
-        )
-        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
-        matches = 0
-        misses = 0
-        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
-            if ref_output.outputs[0].text == spec_output.outputs[0].text:
-                matches += 1
-            else:
-                misses += 1
-                print(f"ref_output: {ref_output.outputs[0].text}")
-                print(f"spec_output: {spec_output.outputs[0].text}")
-
-        # Heuristic: expect at least 66% of the prompts to match exactly
-        # Upon failure, inspect the outputs to check for inaccuracy.
-        assert matches >= int(0.66 * len(ref_outputs))
-        del spec_llm
-        torch.cuda.empty_cache()
-        cleanup_dist_env_and_memory()
-
-
-@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
-    (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
-    (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
-      "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
-    (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
-      "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
-    pytest.param(
-        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-        False,
-        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-    pytest.param(
-        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-        True,
-        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-    (("eagle", "eagle618/deepseek-v3-random",
-      "eagle618/eagle-deepseek-v3-random", 1), False),
-],
-    ids=[
-        "qwen3_eagle3", "llama3_eagle", "llama3_eagle3",
-        "llama4_eagle", "llama4_eagle_mm",
-        "deepseek_eagle"
-    ])
+    test_prompts = get_test_prompts(mm_enabled=False)
+
+    ref_llm = LLM(model=model_name, max_model_len=1024)
+    ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+    del ref_llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
+
+    spec_llm = LLM(
+        model=model_name,
+        speculative_config={
+            "method": "ngram",
+            "prompt_lookup_max": 5,
+            "prompt_lookup_min": 3,
+            "num_speculative_tokens": 3,
+        },
+        max_model_len=1024,
+    )
+    spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+    matches = 0
+    misses = 0
+    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+        if ref_output.outputs[0].text == spec_output.outputs[0].text:
+            matches += 1
+        else:
+            misses += 1
+            print(f"ref_output: {ref_output.outputs[0].text}")
+            print(f"spec_output: {spec_output.outputs[0].text}")
+
+    # Heuristic: expect at least 66% of the prompts to match exactly
+    # Upon failure, inspect the outputs to check for inaccuracy.
+    assert matches >= int(0.66 * len(ref_outputs))
+    del spec_llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
+
+
+@pytest.mark.parametrize(
+    ["model_setup", "mm_enabled"],
+    [
+        (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
+        (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+          "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
+        (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
+          "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
+        pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+                      "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+                     False,
+                     marks=large_gpu_mark(min_gb=80)),  # works on 4x H100
+        pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+                      "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+                     True,
+                     marks=large_gpu_mark(min_gb=80)),  # works on 4x H100
+        (("eagle", "eagle618/deepseek-v3-random",
+          "eagle618/eagle-deepseek-v3-random", 1), False),
+    ],
+    ids=[
+        "qwen3_eagle3", "llama3_eagle", "llama3_eagle3", "llama4_eagle",
+        "llama4_eagle_mm", "deepseek_eagle"
+    ])
 @pytest.mark.parametrize("attn_backend",
                          get_attn_backend_list_based_on_platform())
 def test_eagle_correctness(
@@ -174,9 +171,14 @@ def test_eagle_correctness(
     model_setup: (method, model_name, eagle_model_name, tp_size)
     '''
     with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        m.setenv("VLLM_MLA_DISABLE", "1")
-        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+        if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
+            # Scout requires default backend selection
+            # because vision encoder has head_dim 88 being incompatible
+            # with FLASH_ATTN and needs to fall back to Flex Attn
+            pass
+        else:
+            m.setenv("VLLM_MLA_DISABLE", "1")
+            m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

         if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()):
             pytest.skip("TRITON_ATTN does not support "
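Note on the new import: the Llama-4-Scout cases are now gated on large_gpu_mark(min_gb=80) from tests.utils, so they are collected and skipped at runtime on small GPUs instead of being skipped unconditionally for CI OOM reasons. A minimal sketch of what a helper like large_gpu_mark could look like, assuming it simply wraps pytest.mark.skipif on the reported memory of the first CUDA device (the real helper in tests/utils.py may do more than this):

import pytest
import torch


def large_gpu_mark(min_gb: int):
    """Build a skipif mark that runs the test only on GPUs with >= min_gb GiB."""
    total_gb = 0.0
    if torch.cuda.is_available():
        # total_memory is reported in bytes; convert to GiB.
        total_gb = torch.cuda.get_device_properties(0).total_memory / 2**30
    return pytest.mark.skipif(
        total_gb < min_gb,
        reason=f"Needs a GPU with at least {min_gb} GiB of memory",
    )

Applied as marks=large_gpu_mark(min_gb=80) above, such a mark skips the Llama-4-Scout entries on machines below 80 GiB (the inline comment notes they work on 4x H100) rather than disabling them everywhere.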