Skip to content

Commit 7e4be74

Browse files
authored
[Bug] Batch invariant: Fix flash attn MLA RuntimeError: scheduler_metadata must have shape (metadata_size) (#27884)
1 parent 380ba68 commit 7e4be74

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

vllm/model_executor/layers/batch_invariant.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import contextlib
+import functools
 import os
 from collections import namedtuple
 from collections.abc import Callable
@@ -846,6 +847,7 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
     return AttentionBlockSize(block_m=16, block_n=16)
 
 
+@functools.cache
 def vllm_is_batch_invariant():
     env_key = "VLLM_BATCH_INVARIANT"
     is_overridden = False

vllm/v1/attention/backends/mla/flashattn_mla.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,9 @@ def _build_decode(
         # we only set num_splits when using cuda graphs.
         max_num_splits = self.max_num_splits
 
+        if vllm_is_batch_invariant():
+            max_num_splits = 1
+
         scheduler_metadata = self._schedule_decode(
             num_reqs=seq_lens_cpu.numel(),
             cu_query_lens=query_start_loc_device,
@@ -188,9 +191,6 @@ def _build_decode(
             self.scheduler_metadata[n:] = 0
         scheduler_metadata = self.scheduler_metadata[:n]
 
-        if vllm_is_batch_invariant():
-            max_num_splits = 1
-
         metadata = FlashAttnMLADecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens_device,

0 commit comments

Comments
 (0)