# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
- """Tests for v1 MLA backends without GPUModelRunner dependency."""
+ """Tests for v1 MLA backends without GPUModelRunner dependency.
+
+ Known Issues:
+ - FLASH_ATTN_MLA backend occasionally produces NaN values in
+   test_backend_correctness[mixed_small] when run after
+   test_backend_correctness[small_prefill], but passes when run alone.
+ """

import pytest
import torch
)
from vllm import _custom_ops as ops
from vllm.attention.backends.registry import _Backend
+ from vllm.attention.ops.flashmla import is_flashmla_dense_supported
+ from vllm.config.vllm import set_current_vllm_config
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec

if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
    BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA)

+ # Remove FLASHMLA from the list if not supported
+ if not is_flashmla_dense_supported()[0]:
+     BACKENDS_TO_TEST.remove(_Backend.FLASHMLA)
+
torch.manual_seed(42)

@@ -66,6 +78,12 @@ def _convert_dtype_to_torch(dtype):
6678 "large_prefill" : BatchSpec (seq_lens = [4096 ] * 8 , query_lens = [32 ] * 8 ),
6779 "single_decode" : BatchSpec (seq_lens = [1024 ], query_lens = [1 ]),
6880 "single_prefill" : BatchSpec (seq_lens = [1024 ], query_lens = [64 ]),
81+ "spec_decode_small" : BatchSpec (
82+ seq_lens = [128 , 256 , 512 , 1024 ], query_lens = [4 , 4 , 4 , 4 ]
83+ ),
84+ "spec_decode_medium" : BatchSpec (
85+ seq_lens = [512 , 1024 , 2048 , 512 , 1024 , 2048 ], query_lens = [8 , 8 , 8 , 8 , 8 , 8 ]
86+ ),
6987}
7088
7189
@@ -239,61 +257,64 @@ def run_attention_backend(

    builder_cls, impl_cls = try_get_attention_backend(backend)

-     # Build metadata
-     builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
-     attn_metadata = builder.build(
-         common_prefix_len=0,
-         common_attn_metadata=common_attn_metadata,
-     )
+     # Set the current vllm config so that get_current_vllm_config() works
+     # in the backend implementations
+     with set_current_vllm_config(vllm_config):
+         # Build metadata
+         builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
+         attn_metadata = builder.build(
+             common_prefix_len=0,
+             common_attn_metadata=common_attn_metadata,
+         )

-     # Instantiate MLA implementation
-     num_heads = vllm_config.model_config.get_num_attention_heads(
-         vllm_config.parallel_config
-     )
-     num_kv_heads = vllm_config.model_config.get_num_kv_heads(
-         vllm_config.parallel_config
-     )
-     head_size = vllm_config.model_config.get_head_size()
-     scale = 1.0 / (head_size**0.5)
-     impl = impl_cls(
-         num_heads=num_heads,
-         head_size=head_size,
-         scale=scale,
-         num_kv_heads=num_kv_heads,
-         alibi_slopes=None,
-         sliding_window=None,
-         kv_cache_dtype="auto",
-         logits_soft_cap=None,
-         attn_type="decoder",
-         kv_sharing_target_layer_name=None,
-         q_lora_rank=None,
-         kv_lora_rank=kv_lora_rank,
-         qk_nope_head_dim=qk_nope_head_dim,
-         qk_rope_head_dim=qk_rope_head_dim,
-         qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
-         v_head_dim=v_head_dim,
-         kv_b_proj=mock_kv_b_proj,
-     )
+         # Instantiate MLA implementation
+         num_heads = vllm_config.model_config.get_num_attention_heads(
+             vllm_config.parallel_config
+         )
+         num_kv_heads = vllm_config.model_config.get_num_kv_heads(
+             vllm_config.parallel_config
+         )
+         head_size = vllm_config.model_config.get_head_size()
+         scale = 1.0 / (head_size**0.5)
+         impl = impl_cls(
+             num_heads=num_heads,
+             head_size=head_size,
+             scale=scale,
+             num_kv_heads=num_kv_heads,
+             alibi_slopes=None,
+             sliding_window=None,
+             kv_cache_dtype="auto",
+             logits_soft_cap=None,
+             attn_type="decoder",
+             kv_sharing_target_layer_name=None,
+             q_lora_rank=None,
+             kv_lora_rank=kv_lora_rank,
+             qk_nope_head_dim=qk_nope_head_dim,
+             qk_rope_head_dim=qk_rope_head_dim,
+             qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
+             v_head_dim=v_head_dim,
+             kv_b_proj=mock_kv_b_proj,
+         )

-     # Process weights to create W_UK_T and W_UV attributes needed by MLA
-     act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
-     impl.process_weights_after_loading(act_dtype)
+         # Process weights to create W_UK_T and W_UV attributes needed by MLA
+         act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
+         impl.process_weights_after_loading(act_dtype)

-     # Create mock layer and output buffer
-     mock_layer = MockAttentionLayer(device)
-     num_tokens = query.shape[0]
-     output = torch.empty(
-         num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device
-     )
+         # Create mock layer and output buffer
+         mock_layer = MockAttentionLayer(device)
+         num_tokens = query.shape[0]
+         output = torch.empty(
+             num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device
+         )

-     # Run forward pass
-     # NOTE: The query, key, and value are already shaped correctly
-     # in the calling test function.
-     output = impl.forward(
-         mock_layer, query, kv_c, k_pe, kv_cache, attn_metadata, output=output
-     )
+         # Run forward pass
+         # NOTE: The query, key, and value are already shaped correctly
+         # in the calling test function.
+         output = impl.forward(
+             mock_layer, query, kv_c, k_pe, kv_cache, attn_metadata, output=output
+         )

-     return output
+         return output


@pytest.mark.parametrize(
@@ -309,6 +330,8 @@ def run_attention_backend(
309330 "large_prefill" ,
310331 "single_decode" ,
311332 "single_prefill" ,
333+ "spec_decode_small" ,
334+ "spec_decode_medium" ,
312335 ],
313336)
314337@pytest .mark .parametrize ("model" , ["deepseek-ai/DeepSeek-V2-Lite-Chat" ])
@@ -328,10 +351,39 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
       simulated paged KV cache.
    5. Comparing the vLLM backend's output to the ground-truth SDPA output.
    """
+     from vllm.v1.attention.backends.mla.common import QueryLenSupport
+
    batch_spec = BATCH_SPECS[batch_spec_name]
+     is_spec_decode_test = batch_spec_name.startswith("spec_decode")
+     spec_decode_backends = {_Backend.FLASH_ATTN_MLA, _Backend.FLASHMLA}
+
+     block_size = 16
+     required_blocks = sum(
+         (seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
+     )
+     # Add 1 for null block at index 0, and some buffer
+     num_gpu_blocks = required_blocks + 1 + 100
+
    vllm_config = create_vllm_config(
-         model_name=model, max_model_len=max(batch_spec.seq_lens), num_gpu_blocks=2048
+         model_name=model,
+         max_model_len=max(batch_spec.seq_lens),
+         num_gpu_blocks=num_gpu_blocks,
+         block_size=block_size,
    )
+
+     # For spec decode tests, add a speculative_config to set the reorder_batch_threshold
+     if is_spec_decode_test:
+         from vllm.config import SpeculativeConfig
+
+         # Get the query length from the batch spec (they should all be uniform)
+         query_len = batch_spec.query_lens[0]
+         # Set num_speculative_tokens to query_len - 1
+         # (since threshold is 1 + num_spec_tokens)
+         # Use ngram method which doesn't require a draft model
+         vllm_config.speculative_config = SpeculativeConfig(
+             method="ngram", num_speculative_tokens=query_len - 1
+         )
+
    device = torch.device("cuda:0")

    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
@@ -395,11 +447,37 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
        # K_PE (rope component): [s_len, 1, qk_rope_head_dim]
        k_pe_full = torch.randn(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device)

-         # Determine if this is decode or prefill
+         # Determine if this sequence uses the decode pipeline or prefill
+         # pipeline for each backend
+         # NOTE: For spec decode tests with uniform query_len > 1, backends that
+         # support spec decode (FLASH_ATTN_MLA with varlen support, FLASHMLA with
+         # uniform support) will use the decode pipeline (MQA-style), while
+         # backends that only support single-token queries will use the prefill
+         # pipeline (MHA-style). This ensures the reference implementation
+         # matches each backend's actual decode/prefill pipeline path.
        is_decode = []
-         for i, backend in enumerate(BACKENDS_TO_TEST):
+         for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
            builder_cls, _ = try_get_attention_backend(backend)
-             is_decode.append(q_len <= builder_cls.reorder_batch_threshold)
+             if is_spec_decode_test:
+                 query_len_support = getattr(
+                     builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
+                 )
+                 supports_spec = query_len_support != QueryLenSupport.SINGLE_ONLY
+                 is_decode.append(supports_spec)
+             else:
+                 threshold = getattr(builder_cls, "reorder_batch_threshold", None)
+                 query_len_support = getattr(
+                     builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
+                 )
+                 within_threshold = q_len <= threshold if threshold else False
+                 if (
+                     within_threshold
+                     and query_len_support == QueryLenSupport.UNIFORM
+                     and i > 0
+                 ):
+                     first_q_len = query_lens[0]
+                     within_threshold = q_len == first_q_len
+                 is_decode.append(within_threshold)

        # Split q into nope and rope components
        q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
@@ -478,11 +556,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
        sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0)
        sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2)

-         for i, backend in enumerate(BACKENDS_TO_TEST):
-             if is_decode[i]:
-                 all_sdpa_outputs[i].append(sdpa_out_i_decode)
+         for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
+             if is_decode[backend_idx]:
+                 all_sdpa_outputs[backend_idx].append(sdpa_out_i_decode)
            else:
-                 all_sdpa_outputs[i].append(sdpa_out_i_prefill)
+                 all_sdpa_outputs[backend_idx].append(sdpa_out_i_prefill)

        # Inputs for vLLM MLA backends are just the new tokens
        all_q_vllm.append(q_c)
@@ -497,9 +575,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
    query_vllm = torch.cat(all_q_vllm, dim=0)
    kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
    k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
-     sdpa_outputs = []
-     for i, backend in enumerate(BACKENDS_TO_TEST):
-         sdpa_outputs.append(torch.cat(all_sdpa_outputs[i], dim=0))
+     sdpa_outputs = {}
+     for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
+         sdpa_outputs[backend] = torch.cat(all_sdpa_outputs[backend_idx], dim=0)

    # Create mock kv_b_proj using the same weights as reference implementation
    from vllm.model_executor.layers.linear import ColumnParallelLinear
@@ -516,7 +594,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
    kv_b_proj_weight = kv_b_proj_weight.view(
        kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim)
    )
-     mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T)
+     mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T, requires_grad=False)

    # Create metadata using original batch spec
    common_attn_metadata = create_common_attn_metadata(
@@ -537,7 +615,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
    )

    # 4. Run vLLM backends and compare
-     for i, backend_name in enumerate(BACKENDS_TO_TEST):
+     for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST):
+         # Skip backends that don't support spec decode for spec decode tests
+         if is_spec_decode_test and backend_name not in spec_decode_backends:
+             continue
+
        backend_output = run_attention_backend(
            backend_name,
            kv_cache_spec,
@@ -556,14 +638,17 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
            mock_kv_b_proj,
        )

+         # Look up the SDPA reference output computed for this backend
+         expected_output = sdpa_outputs[backend_name]
+
        # Check shape and dtype consistency
-         assert backend_output.shape == sdpa_outputs[i].shape, (
+         assert backend_output.shape == expected_output.shape, (
            f"[{backend_name}] shape {backend_output.shape} != "
-             f"SDPA shape {sdpa_outputs[i].shape}"
+             f"SDPA shape {expected_output.shape}"
        )
-         assert backend_output.dtype == sdpa_outputs[i].dtype, (
+         assert backend_output.dtype == expected_output.dtype, (
            f"[{backend_name}] dtype {backend_output.dtype} != "
-             f"SDPA dtype {sdpa_outputs[i].dtype}"
+             f"SDPA dtype {expected_output.dtype}"
        )

        assert torch.isfinite(backend_output).all(), (
@@ -574,12 +659,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
        rtol = 1e-2
        atol = 5e-1

-         max_diff = torch.max(torch.abs(backend_output - sdpa_outputs[i])).item()
+         max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
        max_rel_diff = torch.max(
-             torch.abs(backend_output - sdpa_outputs[i]) / torch.abs(sdpa_outputs[i])
+             torch.abs(backend_output - expected_output) / torch.abs(expected_output)
        ).item()
        all_close = torch.allclose(
-             backend_output, sdpa_outputs[i], rtol=rtol, atol=atol
+             backend_output, expected_output, rtol=rtol, atol=atol
        )

        assert all_close, (