Commit 482ef90

MatthewBonanni authored and devpatelio committed
[Attention] Fix FlashMLA metadata builder arguments for q_len > 1 (vllm-project#27368)
Signed-off-by: Matthew Bonanni <[email protected]>
1 parent 8556272 commit 482ef90

1 file changed, +5 -1 lines changed

vllm/v1/attention/backends/mla/flashmla.py

Lines changed: 5 additions & 1 deletion
@@ -120,9 +120,13 @@ def _build_decode(
         num_decode_tokens: int,
         dcp_tot_seq_lens_device: torch.Tensor | None,
     ) -> FlashMLADecodeMetadata:
+        query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+        # we use the max but all should be the same due to uniform length requirement
+        max_query_len = query_lens_cpu.max().item()
+        num_q_tokens_per_head_k = max_query_len * self.num_q_heads // 1
         tile_scheduler_metadata, num_splits = get_mla_metadata(
             seq_lens_device,
-            self.num_q_heads,
+            num_q_tokens_per_head_k,
             1,  # MQA for the decode path
             is_fp8_kvcache=self.is_fp8_kvcache,
         )
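
The arithmetic in the added lines can be checked in isolation. The following is a minimal sketch, not code from vLLM: plain Python lists stand in for the CPU tensor `query_start_loc_cpu`, the function names are hypothetical, and the `num_kv_heads` parameter reflects an assumption that the `// 1` divisor in the diff corresponds to the single KV head of the MQA decode path. The head count of 16 is an illustrative value only.

```python
def query_lens_from_start_loc(query_start_loc):
    """Per-request query lengths from cumulative start offsets.

    Mirrors `query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]` from the
    diff, using a plain list instead of a torch tensor (assumption made so
    the sketch runs without torch).
    """
    return [
        end - start
        for start, end in zip(query_start_loc[:-1], query_start_loc[1:])
    ]


def num_q_tokens_per_head_k(query_start_loc, num_q_heads, num_kv_heads=1):
    """Query tokens per KV head, as computed in the patched `_build_decode`.

    `num_kv_heads=1` is an assumed reading of the `// 1` in the diff.
    """
    query_lens = query_lens_from_start_loc(query_start_loc)
    # The diff takes the max; all lengths are expected to be equal because
    # of the uniform length requirement noted in the source comment.
    max_query_len = max(query_lens)
    return max_query_len * num_q_heads // num_kv_heads


if __name__ == "__main__":
    heads = 16  # illustrative value, not taken from the commit

    # q_len == 1: the result equals num_q_heads, the value passed before the fix.
    print(num_q_tokens_per_head_k([0, 1, 2, 3], heads))

    # q_len == 2: the result is twice num_q_heads, which the old argument missed.
    print(num_q_tokens_per_head_k([0, 2, 4, 6], heads))
```

With a query length of 1 the new value coincides with `self.num_q_heads`, which suggests why the previous argument only misbehaved when `q_len > 1`.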
