@@ -112,6 +112,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
         compilation_config = vllm_config.compilation_config
+        model_config = vllm_config.model_config
 
         if parallel_config.worker_cls == "auto":
             if scheduler_config.is_multi_step:
@@ -142,14 +143,21 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         cache_config = vllm_config.cache_config
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
+
         # TODO(lucas): handle this more gracefully
-        if envs.VLLM_ATTENTION_BACKEND is not None \
-            and envs.VLLM_ATTENTION_BACKEND == "FLASHMLA" \
-            and cache_config.block_size != 64:
-            cache_config.block_size = 64
-            logger.info(
-                "FlashMLA: Forcing kv cache block size to 64 since this"
-                " is currently the only block size supported by the kernel.")
+        # Note: model_config may be None during testing
+        if model_config is not None and model_config.use_mla:
+            # if `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
+            # then we default to the FlashMLA backend, so we need to force
+            # the block size here
+            use_flashmla = (envs.VLLM_ATTENTION_BACKEND is None
+                            or envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
+            from vllm.attention.backends.flashmla import is_flashmla_supported
+            if use_flashmla and is_flashmla_supported()[0] \
+                    and cache_config.block_size != 64:
+                cache_config.block_size = 64
+                logger.info(
+                    "Forcing kv cache block size to 64 for FlashMLA backend.")
 
         if (parallel_config.data_parallel_size > 1
                 and compilation_config.use_cudagraph):
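
For reference, here is a minimal standalone sketch of the rule this hunk adds: the kv-cache block size is forced to 64 only when the model uses MLA, FlashMLA is either explicitly selected or would be chosen by default, and the kernel is actually supported. The `_CacheConfig`/`_ModelConfig` dataclasses, the `flashmla_supported` flag, and the function name below are illustrative stand-ins, not vLLM's real config classes or `is_flashmla_supported()` helper.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class _CacheConfig:
    """Stand-in for vllm's CacheConfig (only the field used here)."""
    block_size: Optional[int] = None


@dataclass
class _ModelConfig:
    """Stand-in for vllm's ModelConfig (only the field used here)."""
    use_mla: bool = False


def maybe_force_flashmla_block_size(model_config: Optional[_ModelConfig],
                                    cache_config: _CacheConfig,
                                    attention_backend_env: Optional[str],
                                    flashmla_supported: bool) -> None:
    """Force block_size=64 when FlashMLA is (or would be) the MLA backend."""
    if model_config is None or not model_config.use_mla:
        return  # not an MLA model (or no model config, e.g. in tests)
    # FlashMLA is the default MLA backend, so an unset env var counts too.
    use_flashmla = attention_backend_env in (None, "FLASHMLA")
    if (use_flashmla and flashmla_supported
            and cache_config.block_size != 64):
        # 64 is currently the only block size the FlashMLA kernel supports.
        cache_config.block_size = 64


cache = _CacheConfig(block_size=16)
maybe_force_flashmla_block_size(_ModelConfig(use_mla=True), cache,
                                attention_backend_env=None,
                                flashmla_supported=True)
assert cache.block_size == 64
```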
@@ -173,7 +181,15 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
         if use_mla:
             # TODO(lucas): refactor to be more concise
             # we should probably consider factoring out V1 here
-            if selected_backend == _Backend.FLASHMLA:
+            if selected_backend == _Backend.TRITON_MLA or block_size != 64:
+                if use_v1:
+                    logger.info_once("Using Triton MLA backend on V1 engine.")
+                    return ("vllm.v1.attention.backends.mla."
+                            "triton_mla.TritonMLABackend")
+                else:
+                    logger.info("Using Triton MLA backend.")
+                    return "vllm.attention.backends.triton_mla.TritonMLABackend"
+            else:
                 from vllm.attention.backends.flashmla import (
                     is_flashmla_supported)
                 if not is_flashmla_supported()[0]:
@@ -195,14 +211,6 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                 logger.info("Using FlashMLA backend.")
                 return ("vllm.attention.backends."
                         "flashmla.FlashMLABackend")
-
-            if use_v1:
-                logger.info_once("Using Triton MLA backend on V1 engine.")
-                return ("vllm.v1.attention.backends.mla."
-                        "triton_mla.TritonMLABackend")
-            else:
-                logger.info("Using Triton MLA backend.")
-                return "vllm.attention.backends.triton_mla.TritonMLABackend"
         if use_v1:
             logger.info_once("Using Flash Attention backend on V1 engine.")
             return ("vllm.v1.attention.backends.flash_attn."
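
Taken together, the last two hunks reorder the MLA dispatch in `get_attn_backend_cls`: Triton MLA becomes the fallback whenever it is explicitly requested or the block size is not 64, and FlashMLA is used otherwise. A hedged sketch of that decision follows; the return strings are the ones from the diff, while the enum and function wrapper are simplified stand-ins, and the handling of an unsupported FlashMLA install is elided here as it is in the hunk.

```python
from enum import Enum, auto
from typing import Optional


class _Backend(Enum):
    """Stand-in for vllm's _Backend enum (only the member used here)."""
    TRITON_MLA = auto()


def pick_mla_backend(selected_backend: Optional[_Backend],
                     block_size: int, use_v1: bool) -> str:
    """Sketch of the MLA branch of get_attn_backend_cls after this change."""
    if selected_backend == _Backend.TRITON_MLA or block_size != 64:
        # Triton MLA: explicitly requested, or block size rules out FlashMLA.
        if use_v1:
            return ("vllm.v1.attention.backends.mla."
                    "triton_mla.TritonMLABackend")
        return "vllm.attention.backends.triton_mla.TritonMLABackend"
    # Otherwise FlashMLA (the unsupported-FlashMLA fallback is not shown).
    return "vllm.attention.backends.flashmla.FlashMLABackend"


assert "triton_mla" in pick_mla_backend(None, block_size=16, use_v1=True)
assert "flashmla" in pick_mla_backend(None, block_size=64, use_v1=False)
```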