From 3a58cfb9fe5cbaafdae4c8f9216ffc0552f3bfef Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 9 Oct 2025 17:11:20 +0000 Subject: [PATCH 01/13] first pass Signed-off-by: Matthew Bonanni --- tests/v1/attention/test_mla_backends.py | 30 +++++++++++++++++++++- vllm/v1/attention/backends/mla/flashmla.py | 20 +++++++++++++-- 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index 3b6a9115435c..e4b33e1dfef0 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -68,6 +68,12 @@ def _convert_dtype_to_torch(dtype): "large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8), "single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]), "single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]), + "spec_decode_small": BatchSpec( + seq_lens=[128, 256, 512, 1024], query_lens=[4, 4, 4, 4] + ), + "spec_decode_medium": BatchSpec( + seq_lens=[512, 1024, 2048, 512, 1024, 2048], query_lens=[8, 8, 8, 8, 8, 8] + ), } @@ -311,6 +317,8 @@ def run_attention_backend( "large_prefill", "single_decode", "single_prefill", + "spec_decode_small", + "spec_decode_medium", ], ) @pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-V2-Lite-Chat"]) @@ -331,6 +339,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): 5. Comparing the vLLM backend's output to the ground-truth SDPA output. """ batch_spec = BATCH_SPECS[batch_spec_name] + is_spec_decode_test = batch_spec_name.startswith("spec_decode") + spec_decode_backends = {_Backend.FLASH_ATTN_MLA, _Backend.FLASHMLA} + vllm_config = create_vllm_config( model_name=model, max_model_len=max(batch_spec.seq_lens), num_gpu_blocks=2048 ) @@ -398,10 +409,23 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): k_pe_full = torch.randn(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device) # Determine if this is decode or prefill + # NOTE: For spec decode tests with uniform query_len > 1, backends that + # support uniform spec decode (FLASH_ATTN_MLA, FLASHMLA) will use the + # decode path (MQA-style), while others will use prefill path (MHA-style). + # This ensures the reference implementation matches each backend's actual path. is_decode = [] for i, backend in enumerate(BACKENDS_TO_TEST): builder_cls, _ = try_get_attention_backend(backend) - is_decode.append(q_len <= builder_cls.reorder_batch_threshold) + # For spec decode tests, check if backend supports uniform spec decode + if is_spec_decode_test: + supports_spec = getattr( + builder_cls, "supports_uniform_spec_as_decode", False + ) + is_decode.append(supports_spec) + else: + # For non-spec-decode tests, use the class-level threshold + threshold = getattr(builder_cls, "reorder_batch_threshold", None) + is_decode.append(q_len <= threshold if threshold else False) # Split q into nope and rope components q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) @@ -540,6 +564,10 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): # 4. 
Run vLLM backends and compare for i, backend_name in enumerate(BACKENDS_TO_TEST): + # Skip backends that don't support spec decode for spec decode tests + if is_spec_decode_test and backend_name not in spec_decode_backends: + continue + backend_output = run_attention_backend( backend_name, kv_cache_spec, diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index e0f4a7f0382b..0b3506f337fe 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -21,7 +21,11 @@ MLACommonMetadata, MLACommonMetadataBuilder, ) -from vllm.v1.attention.backends.utils import AttentionCGSupport +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + reshape_attn_output_for_spec_decode, + reshape_query_for_spec_decode, +) from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) @@ -62,6 +66,7 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + supports_uniform_spec_as_decode: ClassVar[bool] = True def __init__( self, @@ -98,6 +103,11 @@ def __init__( dtype=torch.int32, ) + supports_spec_as_decode = self.supports_uniform_spec_as_decode + self._init_reorder_batch_threshold( + self.reorder_batch_threshold, supports_spec_as_decode + ) + def _build_decode( self, block_table_tensor: torch.Tensor, @@ -216,8 +226,12 @@ def _forward_decode( q = torch.cat(q, dim=-1) assert isinstance(q, torch.Tensor) + + num_decodes = attn_metadata.num_decodes + q = reshape_query_for_spec_decode(q, num_decodes) + o, lse = flash_mla_with_kvcache( - q=q.unsqueeze(1), # Add seqlen dim of 1 (decode) + q=q, k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 block_table=attn_metadata.decode.block_table, cache_seqlens=attn_metadata.decode.seq_lens, @@ -230,4 +244,6 @@ def _forward_decode( descale_k=layer._k_scale.reshape(1), ) + o = reshape_attn_output_for_spec_decode(o) + return o, lse From 79bb4c0d28afa56a32b2c9b6b01fea5c58ee3b0c Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 9 Oct 2025 17:35:31 +0000 Subject: [PATCH 02/13] fix test_mla_backends - vllm_config context and flashmla support Signed-off-by: Matthew Bonanni --- tests/v1/attention/test_mla_backends.py | 111 +++++++++++++----------- 1 file changed, 60 insertions(+), 51 deletions(-) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index e4b33e1dfef0..fea30d780adb 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -16,6 +16,8 @@ ) from vllm import _custom_ops as ops from vllm.attention.backends.registry import _Backend +from vllm.attention.ops.flashmla import is_flashmla_dense_supported +from vllm.config.vllm import set_current_vllm_config from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import FullAttentionSpec @@ -31,6 +33,10 @@ if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10: BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA) +# Remove FLASHMLA from the list if not supported +if not is_flashmla_dense_supported()[0]: + BACKENDS_TO_TEST.remove(_Backend.FLASHMLA) + torch.manual_seed(42) @@ -247,61 +253,64 @@ def run_attention_backend( builder_cls, impl_cls = try_get_attention_backend(backend) - # Build metadata - builder = 
builder_cls(kv_cache_spec, layer_names, vllm_config, device) - attn_metadata = builder.build( - common_prefix_len=0, - common_attn_metadata=common_attn_metadata, - ) + # Set the current vllm config so that get_current_vllm_config() works + # in the backend implementations + with set_current_vllm_config(vllm_config): + # Build metadata + builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) + attn_metadata = builder.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) - # Instantiate MLA implementation - num_heads = vllm_config.model_config.get_num_attention_heads( - vllm_config.parallel_config - ) - num_kv_heads = vllm_config.model_config.get_num_kv_heads( - vllm_config.parallel_config - ) - head_size = vllm_config.model_config.get_head_size() - scale = 1.0 / (head_size**0.5) - impl = impl_cls( - num_heads=num_heads, - head_size=head_size, - scale=scale, - num_kv_heads=num_kv_heads, - alibi_slopes=None, - sliding_window=None, - kv_cache_dtype="auto", - logits_soft_cap=None, - attn_type="decoder", - kv_sharing_target_layer_name=None, - q_lora_rank=None, - kv_lora_rank=kv_lora_rank, - qk_nope_head_dim=qk_nope_head_dim, - qk_rope_head_dim=qk_rope_head_dim, - qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, - v_head_dim=v_head_dim, - kv_b_proj=mock_kv_b_proj, - ) + # Instantiate MLA implementation + num_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config + ) + num_kv_heads = vllm_config.model_config.get_num_kv_heads( + vllm_config.parallel_config + ) + head_size = vllm_config.model_config.get_head_size() + scale = 1.0 / (head_size**0.5) + impl = impl_cls( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="auto", + logits_soft_cap=None, + attn_type="decoder", + kv_sharing_target_layer_name=None, + q_lora_rank=None, + kv_lora_rank=kv_lora_rank, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, + v_head_dim=v_head_dim, + kv_b_proj=mock_kv_b_proj, + ) - # Process weights to create W_UK_T and W_UV attributes needed by MLA - act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) - impl.process_weights_after_loading(act_dtype) + # Process weights to create W_UK_T and W_UV attributes needed by MLA + act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) + impl.process_weights_after_loading(act_dtype) - # Create mock layer and output buffer - mock_layer = MockAttentionLayer(device) - num_tokens = query.shape[0] - output = torch.empty( - num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device - ) + # Create mock layer and output buffer + mock_layer = MockAttentionLayer(device) + num_tokens = query.shape[0] + output = torch.empty( + num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device + ) - # Run forward pass - # NOTE: The query, key, and value are already shaped correctly - # in the calling test function. - output = impl.forward( - mock_layer, query, kv_c, k_pe, kv_cache, attn_metadata, output=output - ) + # Run forward pass + # NOTE: The query, key, and value are already shaped correctly + # in the calling test function. 
+ output = impl.forward( + mock_layer, query, kv_c, k_pe, kv_cache, attn_metadata, output=output + ) - return output + return output @pytest.mark.parametrize( @@ -542,7 +551,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): kv_b_proj_weight = kv_b_proj_weight.view( kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim) ) - mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T) + mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T, requires_grad=False) # Create metadata using original batch spec common_attn_metadata = create_common_attn_metadata( From a76ad7968f910a3e8c6d3c9e0080db940dc88c79 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 9 Oct 2025 17:56:01 +0000 Subject: [PATCH 03/13] fix allocated block count Signed-off-by: Matthew Bonanni --- tests/v1/attention/test_mla_backends.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index fea30d780adb..eb2f7679ae29 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -351,8 +351,18 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): is_spec_decode_test = batch_spec_name.startswith("spec_decode") spec_decode_backends = {_Backend.FLASH_ATTN_MLA, _Backend.FLASHMLA} + block_size = 16 + required_blocks = sum( + (seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens + ) + # Add 1 for null block at index 0, and some buffer + num_gpu_blocks = required_blocks + 1 + 100 + vllm_config = create_vllm_config( - model_name=model, max_model_len=max(batch_spec.seq_lens), num_gpu_blocks=2048 + model_name=model, + max_model_len=max(batch_spec.seq_lens), + num_gpu_blocks=num_gpu_blocks, + block_size=block_size, ) device = torch.device("cuda:0") From 3115b2d605dacea5c43daa9e554d60adf4aaa59e Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 9 Oct 2025 20:34:40 +0000 Subject: [PATCH 04/13] fix index collision Signed-off-by: Matthew Bonanni --- tests/v1/attention/test_mla_backends.py | 40 +++++++++++++++---------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index eb2f7679ae29..dfd383762144 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -347,6 +347,13 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): simulated paged KV cache. 5. Comparing the vLLM backend's output to the ground-truth SDPA output. """ + # Reset random seeds to ensure deterministic behavior across test runs + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + + # Clear CUDA cache to avoid reusing stale memory from previous tests + torch.cuda.empty_cache() + batch_spec = BATCH_SPECS[batch_spec_name] is_spec_decode_test = batch_spec_name.startswith("spec_decode") spec_decode_backends = {_Backend.FLASH_ATTN_MLA, _Backend.FLASHMLA} @@ -433,7 +440,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): # decode path (MQA-style), while others will use prefill path (MHA-style). # This ensures the reference implementation matches each backend's actual path. 
is_decode = [] - for i, backend in enumerate(BACKENDS_TO_TEST): + for backend_idx, backend in enumerate(BACKENDS_TO_TEST): builder_cls, _ = try_get_attention_backend(backend) # For spec decode tests, check if backend supports uniform spec decode if is_spec_decode_test: @@ -523,11 +530,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0) sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2) - for i, backend in enumerate(BACKENDS_TO_TEST): - if is_decode[i]: - all_sdpa_outputs[i].append(sdpa_out_i_decode) + for backend_idx, backend in enumerate(BACKENDS_TO_TEST): + if is_decode[backend_idx]: + all_sdpa_outputs[backend_idx].append(sdpa_out_i_decode) else: - all_sdpa_outputs[i].append(sdpa_out_i_prefill) + all_sdpa_outputs[backend_idx].append(sdpa_out_i_prefill) # Inputs for vLLM MLA backends are just the new tokens all_q_vllm.append(q_c) @@ -543,8 +550,8 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0) k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0) sdpa_outputs = [] - for i, backend in enumerate(BACKENDS_TO_TEST): - sdpa_outputs.append(torch.cat(all_sdpa_outputs[i], dim=0)) + for backend_idx, backend in enumerate(BACKENDS_TO_TEST): + sdpa_outputs.append(torch.cat(all_sdpa_outputs[backend_idx], dim=0)) # Create mock kv_b_proj using the same weights as reference implementation from vllm.model_executor.layers.linear import ColumnParallelLinear @@ -582,7 +589,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): ) # 4. Run vLLM backends and compare - for i, backend_name in enumerate(BACKENDS_TO_TEST): + for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST): # Skip backends that don't support spec decode for spec decode tests if is_spec_decode_test and backend_name not in spec_decode_backends: continue @@ -606,13 +613,13 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): ) # Check shape and dtype consistency - assert backend_output.shape == sdpa_outputs[i].shape, ( + assert backend_output.shape == sdpa_outputs[backend_idx].shape, ( f"[{backend_name}] shape {backend_output.shape} != " - f"SDPA shape {sdpa_outputs[i].shape}" + f"SDPA shape {sdpa_outputs[backend_idx].shape}" ) - assert backend_output.dtype == sdpa_outputs[i].dtype, ( + assert backend_output.dtype == sdpa_outputs[backend_idx].dtype, ( f"[{backend_name}] dtype {backend_output.dtype} != " - f"SDPA dtype {sdpa_outputs[i].dtype}" + f"SDPA dtype {sdpa_outputs[backend_idx].dtype}" ) assert torch.isfinite(backend_output).all(), ( @@ -623,12 +630,15 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): rtol = 1e-2 atol = 5e-1 - max_diff = torch.max(torch.abs(backend_output - sdpa_outputs[i])).item() + max_diff = torch.max( + torch.abs(backend_output - sdpa_outputs[backend_idx]) + ).item() max_rel_diff = torch.max( - torch.abs(backend_output - sdpa_outputs[i]) / torch.abs(sdpa_outputs[i]) + torch.abs(backend_output - sdpa_outputs[backend_idx]) + / torch.abs(sdpa_outputs[backend_idx]) ).item() all_close = torch.allclose( - backend_output, sdpa_outputs[i], rtol=rtol, atol=atol + backend_output, sdpa_outputs[backend_idx], rtol=rtol, atol=atol ) assert all_close, ( From e3edffe2f1a6471b7f18ce988dd9b5a42d0a63b6 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 9 Oct 2025 21:21:53 +0000 Subject: [PATCH 05/13] add note about flakiness Signed-off-by: Matthew Bonanni --- 
tests/v1/attention/test_mla_backends.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index dfd383762144..285df0302163 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -1,6 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests for v1 MLA backends without GPUModelRunner dependency.""" +"""Tests for v1 MLA backends without GPUModelRunner dependency. + +Known Issues: +- FLASH_ATTN_MLA backend produces NaN values in test_backend_correctness[mixed_small] + when run after test_backend_correctness[small_prefill], but passes when run alone. +""" from typing import Optional, Union @@ -347,13 +352,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): simulated paged KV cache. 5. Comparing the vLLM backend's output to the ground-truth SDPA output. """ - # Reset random seeds to ensure deterministic behavior across test runs - torch.manual_seed(42) - torch.cuda.manual_seed_all(42) - - # Clear CUDA cache to avoid reusing stale memory from previous tests - torch.cuda.empty_cache() - batch_spec = BATCH_SPECS[batch_spec_name] is_spec_decode_test = batch_spec_name.startswith("spec_decode") spec_decode_backends = {_Backend.FLASH_ATTN_MLA, _Backend.FLASHMLA} From 6ccfaa3273052be057ae03734de4e83aa35d41de Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 9 Oct 2025 21:48:31 +0000 Subject: [PATCH 06/13] mark flashattention mla as supporting spec decode Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/flashattn_mla.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index c043990ffcc6..a6b7661a3dd1 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -66,6 +66,7 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]): class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + supports_uniform_spec_as_decode: ClassVar[bool] = True reorder_batch_threshold: int = 512 From e8269dab22c3ab75c9ee0ca7d177187f0d5f8f54 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 9 Oct 2025 22:09:50 +0000 Subject: [PATCH 07/13] add speculative config Signed-off-by: Matthew Bonanni --- tests/v1/attention/test_mla_backends.py | 38 +++++++++++++++++-------- 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index 285df0302163..6d20b2c10b40 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -369,6 +369,20 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): num_gpu_blocks=num_gpu_blocks, block_size=block_size, ) + + # For spec decode tests, add a speculative_config to set the reorder_batch_threshold + if is_spec_decode_test: + from vllm.config import SpeculativeConfig + + # Get the query length from the batch spec (they should all be uniform) + query_len = batch_spec.query_lens[0] + # Set num_speculative_tokens to query_len - 1 + # (since threshold is 1 + num_spec_tokens) + # Use ngram method which doesn't require a draft model + vllm_config.speculative_config = SpeculativeConfig( + 
method="ngram", num_speculative_tokens=query_len - 1 + ) + device = torch.device("cuda:0") kv_cache_spec = create_standard_kv_cache_spec(vllm_config) @@ -547,9 +561,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): query_vllm = torch.cat(all_q_vllm, dim=0) kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0) k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0) - sdpa_outputs = [] + sdpa_outputs = {} for backend_idx, backend in enumerate(BACKENDS_TO_TEST): - sdpa_outputs.append(torch.cat(all_sdpa_outputs[backend_idx], dim=0)) + sdpa_outputs[backend] = torch.cat(all_sdpa_outputs[backend_idx], dim=0) # Create mock kv_b_proj using the same weights as reference implementation from vllm.model_executor.layers.linear import ColumnParallelLinear @@ -610,14 +624,17 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): mock_kv_b_proj, ) + # Use backend_idx to get the correct SDPA output for this backend + expected_output = sdpa_outputs[backend_name] + # Check shape and dtype consistency - assert backend_output.shape == sdpa_outputs[backend_idx].shape, ( + assert backend_output.shape == expected_output.shape, ( f"[{backend_name}] shape {backend_output.shape} != " - f"SDPA shape {sdpa_outputs[backend_idx].shape}" + f"SDPA shape {expected_output.shape}" ) - assert backend_output.dtype == sdpa_outputs[backend_idx].dtype, ( + assert backend_output.dtype == expected_output.dtype, ( f"[{backend_name}] dtype {backend_output.dtype} != " - f"SDPA dtype {sdpa_outputs[backend_idx].dtype}" + f"SDPA dtype {expected_output.dtype}" ) assert torch.isfinite(backend_output).all(), ( @@ -628,15 +645,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): rtol = 1e-2 atol = 5e-1 - max_diff = torch.max( - torch.abs(backend_output - sdpa_outputs[backend_idx]) - ).item() + max_diff = torch.max(torch.abs(backend_output - expected_output)).item() max_rel_diff = torch.max( - torch.abs(backend_output - sdpa_outputs[backend_idx]) - / torch.abs(sdpa_outputs[backend_idx]) + torch.abs(backend_output - expected_output) / torch.abs(expected_output) ).item() all_close = torch.allclose( - backend_output, sdpa_outputs[backend_idx], rtol=rtol, atol=atol + backend_output, expected_output, rtol=rtol, atol=atol ) assert all_close, ( From b8303e969b4ffed7e84262a645443a9ff775b476 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 9 Oct 2025 22:28:02 +0000 Subject: [PATCH 08/13] implement QueryLenSupport, tests pass Signed-off-by: Matthew Bonanni --- tests/v1/attention/test_mla_backends.py | 35 +++++++++---- vllm/v1/attention/backends/mla/common.py | 52 +++++++++++++++---- .../attention/backends/mla/flashattn_mla.py | 4 +- vllm/v1/attention/backends/mla/flashmla.py | 9 ++-- 4 files changed, 72 insertions(+), 28 deletions(-) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index 6d20b2c10b40..002db840103e 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -446,24 +446,41 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): # K_PE (rope component): [s_len, 1, qk_rope_head_dim] k_pe_full = torch.randn(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device) - # Determine if this is decode or prefill + # Determine if this sequence uses the decode pipeline or prefill + # pipeline for each backend # NOTE: For spec decode tests with uniform query_len > 1, backends that - # support uniform spec decode (FLASH_ATTN_MLA, FLASHMLA) will use the - # decode path 
(MQA-style), while others will use prefill path (MHA-style). - # This ensures the reference implementation matches each backend's actual path. + # support spec decode (FLASH_ATTN_MLA with varlen support, FLASHMLA with + # uniform support) will use the decode pipeline (MQA-style), while + # backends that only support single-token queries will use the prefill + # pipeline (MHA-style). This ensures the reference implementation + # matches each backend's actual decode/prefill pipeline path. is_decode = [] for backend_idx, backend in enumerate(BACKENDS_TO_TEST): builder_cls, _ = try_get_attention_backend(backend) - # For spec decode tests, check if backend supports uniform spec decode if is_spec_decode_test: - supports_spec = getattr( - builder_cls, "supports_uniform_spec_as_decode", False + from vllm.v1.attention.backends.mla.common import QueryLenSupport + + query_len_support = getattr( + builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY ) + supports_spec = query_len_support != QueryLenSupport.SINGLE_ONLY is_decode.append(supports_spec) else: - # For non-spec-decode tests, use the class-level threshold + from vllm.v1.attention.backends.mla.common import QueryLenSupport + threshold = getattr(builder_cls, "reorder_batch_threshold", None) - is_decode.append(q_len <= threshold if threshold else False) + query_len_support = getattr( + builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY + ) + within_threshold = q_len <= threshold if threshold else False + if ( + within_threshold + and query_len_support == QueryLenSupport.UNIFORM + and i > 0 + ): + first_q_len = query_lens[0] + within_threshold = q_len == first_q_len + is_decode.append(within_threshold) # Split q into nope and rope components q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index af396c2b4103..014436a21b86 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -190,6 +190,7 @@ import functools from abc import abstractmethod from dataclasses import dataclass, field +from enum import Enum from typing import ClassVar, Generic, Optional, TypeVar, Union import torch @@ -227,6 +228,24 @@ ) from vllm.v1.kv_cache_interface import AttentionSpec + +class QueryLenSupport(Enum): + """Defines the level of query length support for an attention backend's + decode pipeline. + + - SINGLE_ONLY: Decode pipeline only supports single-token queries + (query_len=1) + - UNIFORM: Decode pipeline supports uniform multi-token queries + (all requests must have same query_len > 1) + - VARLEN: Decode pipeline supports variable-length queries + (mixed query lengths in same batch) + """ + + SINGLE_ONLY = "single_only" + UNIFORM = "uniform" + VARLEN = "varlen" + + try: from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -455,14 +474,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): understand this class """ - # Whether the backend supports reordering the batch such that - # short sequences (i.e. verification for speculative decoding) are - # classified as decode requests. - # If True, this will increase `reorder_batch_threshold` (below) when - # speculative decoding is enabled, and set `require_uniform=True` when - # when reordering the batch. Non-uniform decode requests will - # fall back to prefill in this case. - supports_uniform_spec_as_decode: ClassVar[bool] = False + # Defines the level of query length support for this backend. 
+ # - SINGLE_ONLY: Only single-token queries (no spec decode support) + # - UNIFORM: Supports uniform multi-token queries (spec decode with uniform lengths) + # - VARLEN: Supports variable-length queries (spec decode with mixed lengths) + # If set to UNIFORM or VARLEN, this will increase `reorder_batch_threshold` when + # speculative decoding is enabled. + query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.SINGLE_ONLY # The threshold for reordering the batch into decode and prefill requests. # If > 1, the batch will be reordered such that requests with @@ -594,11 +612,23 @@ def __init__( device=device, ) - supports_spec_as_decode = self.supports_uniform_spec_as_decode + supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY self._init_reorder_batch_threshold( - self.reorder_batch_threshold, supports_spec_as_decode + self.reorder_batch_threshold, supports_spec_decode ) + # Validate consistency between query_len_support and reorder_batch_threshold + if self.query_len_support == QueryLenSupport.SINGLE_ONLY: + assert self.reorder_batch_threshold == 1, ( + f"reorder_batch_threshold must be 1 when query_len_support is " + f"SINGLE_ONLY, got {self.reorder_batch_threshold}" + ) + else: + assert self.reorder_batch_threshold > 1, ( + f"reorder_batch_threshold must be > 1 when query_len_support " + f"is not SINGLE_ONLY, got {self.reorder_batch_threshold}" + ) + def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): qo_indptr = prefill.query_start_loc @@ -740,7 +770,7 @@ def build( split_decodes_and_prefills( common_attn_metadata, decode_threshold=self.reorder_batch_threshold, - require_uniform=self.supports_uniform_spec_as_decode, + require_uniform=(self.query_len_support != QueryLenSupport.VARLEN), ) ) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index a6b7661a3dd1..48f4a3c44f61 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -24,6 +24,7 @@ MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder, + QueryLenSupport, ) from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec @@ -66,8 +67,7 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]): class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH - supports_uniform_spec_as_decode: ClassVar[bool] = True - + query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN reorder_batch_threshold: int = 512 def __init__( diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 0b3506f337fe..f37da8bc3d0c 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -20,6 +20,7 @@ MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder, + QueryLenSupport, ) from vllm.v1.attention.backends.utils import ( AttentionCGSupport, @@ -66,7 +67,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH - supports_uniform_spec_as_decode: ClassVar[bool] = True + query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM + reorder_batch_threshold: int = 512 # TODO(matt): tune this def __init__( self, @@ 
-103,11 +105,6 @@ def __init__( dtype=torch.int32, ) - supports_spec_as_decode = self.supports_uniform_spec_as_decode - self._init_reorder_batch_threshold( - self.reorder_batch_threshold, supports_spec_as_decode - ) - def _build_decode( self, block_table_tensor: torch.Tensor, From fce2c63569b41241c5dba6b2b34eeeb197fc581e Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 9 Oct 2025 22:30:37 +0000 Subject: [PATCH 09/13] update comment Signed-off-by: Matthew Bonanni --- tests/v1/attention/test_mla_backends.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index 002db840103e..c0da6f00e971 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -3,8 +3,9 @@ """Tests for v1 MLA backends without GPUModelRunner dependency. Known Issues: -- FLASH_ATTN_MLA backend produces NaN values in test_backend_correctness[mixed_small] - when run after test_backend_correctness[small_prefill], but passes when run alone. +- FLASH_ATTN_MLA backend occasionally produces NaN values in + test_backend_correctness[mixed_small] when run after + test_backend_correctness[small_prefill], but passes when run alone. """ from typing import Optional, Union From b11a56e93726033dc42c88b9b15b953bee22372e Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 9 Oct 2025 22:41:35 +0000 Subject: [PATCH 10/13] address comment Signed-off-by: Matthew Bonanni --- tests/v1/attention/test_mla_backends.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index c0da6f00e971..23f26753f994 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -353,6 +353,8 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): simulated paged KV cache. 5. Comparing the vLLM backend's output to the ground-truth SDPA output. 
""" + from vllm.v1.attention.backends.mla.common import QueryLenSupport + batch_spec = BATCH_SPECS[batch_spec_name] is_spec_decode_test = batch_spec_name.startswith("spec_decode") spec_decode_backends = {_Backend.FLASH_ATTN_MLA, _Backend.FLASHMLA} @@ -459,16 +461,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): for backend_idx, backend in enumerate(BACKENDS_TO_TEST): builder_cls, _ = try_get_attention_backend(backend) if is_spec_decode_test: - from vllm.v1.attention.backends.mla.common import QueryLenSupport - query_len_support = getattr( builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY ) supports_spec = query_len_support != QueryLenSupport.SINGLE_ONLY is_decode.append(supports_spec) else: - from vllm.v1.attention.backends.mla.common import QueryLenSupport - threshold = getattr(builder_cls, "reorder_batch_threshold", None) query_len_support = getattr( builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY From 1eaab573139770669d1ae6752d168d09352b1d66 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Fri, 10 Oct 2025 16:10:23 +0000 Subject: [PATCH 11/13] add query_len_support to flashinfer_mla Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/common.py | 2 +- vllm/v1/attention/backends/mla/flashinfer_mla.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 014436a21b86..f288e58ef61a 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -485,7 +485,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): # The threshold for reordering the batch into decode and prefill requests. # If > 1, the batch will be reordered such that requests with # query length <= threshold are classified as decode requests. - # Use `supports_uniform_spec_as_decode` (above) to set this automatically + # Use `query_len_support` (above) to set this automatically # when speculative decoding is enabled. 
reorder_batch_threshold: int = 1 diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index 206f96ea366a..5e01febe07c7 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -13,6 +13,7 @@ MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder, + QueryLenSupport, ) from vllm.v1.attention.backends.utils import AttentionCGSupport @@ -22,11 +23,8 @@ class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]): - # enable spec-as-decode optimization - supports_uniform_spec_as_decode: ClassVar[bool] = True - - # enable full CUDA Graph support for decode-only capture cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM class FlashInferMLABackend(MLACommonBackend): From 1a7ca493b8d35350f6a0bdf5c9e7e56bcf8ac768 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 14 Oct 2025 12:59:07 -0400 Subject: [PATCH 12/13] remove assertion Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/common.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index a7fde82b229d..f7e6f12363ad 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -628,11 +628,6 @@ def __init__( f"reorder_batch_threshold must be 1 when query_len_support is " f"SINGLE_ONLY, got {self.reorder_batch_threshold}" ) - else: - assert self.reorder_batch_threshold > 1, ( - f"reorder_batch_threshold must be > 1 when query_len_support " - f"is not SINGLE_ONLY, got {self.reorder_batch_threshold}" - ) def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): qo_indptr = prefill.query_start_loc From 1afd4e0612d90d133b889ff02796c72e70fdb0d3 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 14 Oct 2025 13:00:56 -0400 Subject: [PATCH 13/13] comments Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/flashattn_mla.py | 2 +- vllm/v1/attention/backends/mla/flashmla.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index cbb9ec7473bb..446f1c4f1f96 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -68,7 +68,7 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]): class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN - reorder_batch_threshold: int = 512 + reorder_batch_threshold: int = 512 # process small prefills with decode pathway def __init__( self, diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 3854ea818e4e..b15c09294c6b 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -68,7 +68,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM - reorder_batch_threshold: int = 512 # 
TODO(matt): tune this + reorder_batch_threshold: int = 512 # process small prefills with decode pathway + # ^ TODO(matt): tune this def __init__( self,
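
For reference, a minimal sketch (separate from the patches above) of how `query_len_support` and `reorder_batch_threshold` decide whether a request takes the MQA-style decode path. `uses_decode_pipeline` is a hypothetical helper written only for illustration; the enum body simply mirrors the `QueryLenSupport` added to `mla/common.py` in PATCH 08, and the classification follows the logic exercised in the test's `is_decode` loop.

```python
from enum import Enum


class QueryLenSupport(Enum):
    # Mirrors the enum added to vllm/v1/attention/backends/mla/common.py.
    SINGLE_ONLY = "single_only"
    UNIFORM = "uniform"
    VARLEN = "varlen"


def uses_decode_pipeline(
    q_len: int,
    first_q_len: int,
    reorder_batch_threshold: int,
    support: QueryLenSupport,
) -> bool:
    """Hypothetical helper: does a request take the MQA-style decode path?

    q_len: query length of this request.
    first_q_len: query length of the first decode request in the batch
        (only relevant for UNIFORM support).
    """
    if q_len > reorder_batch_threshold:
        return False  # long queries always go through the prefill path
    if support is QueryLenSupport.SINGLE_ONLY:
        return q_len == 1  # threshold is asserted to be 1 in this case
    if support is QueryLenSupport.UNIFORM:
        # uniform spec decode: every decode request must share one query length
        return q_len == first_q_len
    return True  # VARLEN: any length within the threshold is fine


# e.g. FlashMLA (UNIFORM, threshold 512) with uniform 4-token spec-decode queries:
assert uses_decode_pipeline(4, 4, 512, QueryLenSupport.UNIFORM)
# ...while a request whose length differs from the batch falls back to prefill:
assert not uses_decode_pipeline(8, 4, 512, QueryLenSupport.UNIFORM)
```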