 import pytest
 import torch
 
+from tests.utils import get_attn_backend_list_based_on_platform
 from tests.v1.attention.utils import (BatchSpec, _Backend,
                                       create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
@@ -120,17 +121,28 @@ def test_prepare_inputs():
     assert torch.equal(token_indices, expected_token_indices)
 
 
-@pytest.mark.parametrize("method,proposer_helper", [
-    ("eagle", lambda k: _create_proposer("eagle", k)),
-    ("eagle3", lambda k: _create_proposer("eagle3", k)),
-])
+@pytest.mark.parametrize("method", ["eagle", "eagle3"])
+@pytest.mark.parametrize("attn_backend",
+                         get_attn_backend_list_based_on_platform())
 @pytest.mark.parametrize("pp_size", [1, 2])
 @pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
 @mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
 @mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
 @mock.patch('vllm.v1.spec_decode.eagle.get_model')
 def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
-                    proposer_helper, pp_size, use_distinct_embed_tokens):
+                    attn_backend, pp_size, use_distinct_embed_tokens,
+                    monkeypatch):
+
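+    # Pin the attention backend under test for this parametrization.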
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+
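+    # Multi-token eagle spec decode with TRITON_ATTN_VLLM_V1 is only
+    # supported on ROCm, so skip on other platforms.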
+    if (attn_backend == "TRITON_ATTN_VLLM_V1"
+            and not current_platform.is_rocm()):
+        pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
+                    "multi-token eagle spec decode on current platform")
+
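+    # On ROCm, enable AITER, which backs the FLASH_ATTN_VLLM_V1 path.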
+    if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     # Setup draft model mock
     mock_model = mock.MagicMock()
     if use_distinct_embed_tokens:
@@ -177,7 +189,7 @@ class _TargetModelStub(LlamaForCausalLM):
     target_model.lm_head = mock.MagicMock()
 
     # Create proposer using the helper function
-    proposer = proposer_helper(k=8)
+    proposer = _create_proposer(method, k=8)
 
     # Call the method under test
     proposer.load_model(target_model)
@@ -201,10 +213,22 @@ class _TargetModelStub(LlamaForCausalLM):
         target_model.model.embed_tokens
 
 
+@pytest.mark.parametrize("method", ["eagle", "eagle3"])
+@pytest.mark.parametrize("attn_backend",
+                         get_attn_backend_list_based_on_platform())
 @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
-@pytest.mark.parametrize("backend",
-                         [_Backend.FLASH_ATTN_VLLM_V1, _Backend.TREE_ATTN])
-def test_propose(num_speculative_tokens, backend):
+def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
+
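+    # Select the attention backend for this run via the environment.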
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+
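+    # Apply the same platform gating as test_load_model above.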
+    if (attn_backend == "TRITON_ATTN_VLLM_V1"
+            and not current_platform.is_rocm()):
+        pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
+                    "multi-token eagle spec decode on current platform")
+
+    if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     # Use GPU device
     device = torch.device(current_platform.device_type)
 
@@ -303,7 +327,18 @@ def create_deterministic_logits(token_ids):
                                         device=device)
     sampling_metadata = mock.MagicMock()
 
-    attn_metadata_builder_cls, _ = get_attention_backend(backend)
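+    # Resolve the env-var backend name to its _Backend enum value.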
+    if attn_backend == "FLASH_ATTN_VLLM_V1":
+        attn_metadata_builder_cls, _ = get_attention_backend(
+            _Backend.FLASH_ATTN_VLLM_V1)
+    elif attn_backend == "TRITON_ATTN_VLLM_V1":
+        attn_metadata_builder_cls, _ = get_attention_backend(
+            _Backend.TRITON_ATTN_VLLM_V1)
+    elif attn_backend == "TREE_ATTN":
+        attn_metadata_builder_cls, _ = get_attention_backend(
+            _Backend.TREE_ATTN)
+    else:
+        raise ValueError(f"Unsupported attention backend: {attn_backend}")
+
     attn_metadata_builder = attn_metadata_builder_cls(
         kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
         layer_names=proposer.attn_layer_names,