File tree: 2 files changed, +5 −3 lines changed
  vllm/model_executor/layers/batch_invariant.py (imports and vllm_is_batch_invariant)
  v1/attention/backends/mla (FlashAttn MLA decode metadata builder)
lines changed Original file line number Diff line number Diff line change 11# SPDX-License-Identifier: Apache-2.0
22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33import contextlib
4+ import functools
45import os
56from collections import namedtuple
67from collections .abc import Callable
@@ -846,6 +847,7 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
     return AttentionBlockSize(block_m=16, block_n=16)
 
 
+@functools.cache
 def vllm_is_batch_invariant():
     env_key = "VLLM_BATCH_INVARIANT"
     is_overridden = False
@@ -163,6 +163,9 @@ def _build_decode(
         # we only set num_splits when using cuda graphs.
         max_num_splits = self.max_num_splits
 
+        if vllm_is_batch_invariant():
+            max_num_splits = 1
+
         scheduler_metadata = self._schedule_decode(
             num_reqs=seq_lens_cpu.numel(),
             cu_query_lens=query_start_loc_device,
@@ -188,9 +191,6 @@ def _build_decode(
             self.scheduler_metadata[n:] = 0
             scheduler_metadata = self.scheduler_metadata[:n]
 
-        if vllm_is_batch_invariant():
-            max_num_splits = 1
-
         metadata = FlashAttnMLADecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens_device,
You can’t perform that action at this time.
0 commit comments