We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 8556272 commit 482ef90 — Copy full SHA for 482ef90
vllm/v1/attention/backends/mla/flashmla.py
@@ -120,9 +120,13 @@ def _build_decode(
120
num_decode_tokens: int,
121
dcp_tot_seq_lens_device: torch.Tensor | None,
122
) -> FlashMLADecodeMetadata:
123
+ query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
124
+ # we take the max, but all query lengths should be equal due to the uniform-length requirement
125
+ max_query_len = query_lens_cpu.max().item()
126
+ num_q_tokens_per_head_k = max_query_len * self.num_q_heads // 1
127
tile_scheduler_metadata, num_splits = get_mla_metadata(
128
seq_lens_device,
- self.num_q_heads,
129
+ num_q_tokens_per_head_k,
130
1, # MQA for the decode path
131
is_fp8_kvcache=self.is_fp8_kvcache,
132
)
0 commit comments