Commit f9b30df

LucasWilkinson authored and tlrmchlsmth committed
[Attention] Default to FlashMLA backend for MLA (vllm-project#14451)
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
1 parent 68fe7a9 commit f9b30df
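
In practical terms, an MLA model now picks up FlashMLA automatically when `VLLM_ATTENTION_BACKEND` is left unset, and the variable can still force a specific backend ("FLASHMLA" and "TRITON_MLA" are the values referenced in the diff below). A minimal usage sketch under those assumptions; the DeepSeek checkpoint is only an illustrative stand-in for any model that uses MLA:

import os

# Optional override: leaving VLLM_ATTENTION_BACKEND unset now defaults MLA
# models to FlashMLA when the kernel is supported (and forces the kv-cache
# block size to 64); setting it to "TRITON_MLA" forces the Triton fallback.
# os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_MLA"

from vllm import LLM

# Illustrative MLA-based model (assumption); any MLA checkpoint behaves the same.
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True)
outputs = llm.generate("Give me a short greeting.")
print(outputs[0].outputs[0].text)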

File tree

1 file changed (+24 -16)


vllm/platforms/cuda.py

Lines changed: 24 additions & 16 deletions
@@ -112,6 +112,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
         compilation_config = vllm_config.compilation_config
+        model_config = vllm_config.model_config
 
         if parallel_config.worker_cls == "auto":
             if scheduler_config.is_multi_step:
@@ -142,14 +143,21 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         cache_config = vllm_config.cache_config
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
+
         # TODO(lucas): handle this more gracefully
-        if envs.VLLM_ATTENTION_BACKEND is not None \
-            and envs.VLLM_ATTENTION_BACKEND == "FLASHMLA" \
-            and cache_config.block_size != 64:
-            cache_config.block_size = 64
-            logger.info(
-                "FlashMLA: Forcing kv cache block size to 64 since this"
-                " is currently the only block size supported by the kernel.")
+        # Note: model_config may be None during testing
+        if model_config is not None and model_config.use_mla:
+            # if `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, then
+            # we default to FlashMLA backend, so we need to force the blocksize
+            # here
+            use_flashmla = (envs.VLLM_ATTENTION_BACKEND is None \
+                or envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
+            from vllm.attention.backends.flashmla import is_flashmla_supported
+            if use_flashmla and is_flashmla_supported()[0] \
+                and cache_config.block_size != 64:
+                cache_config.block_size = 64
+                logger.info(
+                    "Forcing kv cache block size to 64 for FlashMLA backend.")
 
         if (parallel_config.data_parallel_size > 1
                 and compilation_config.use_cudagraph):
@@ -173,7 +181,15 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
         if use_mla:
             # TODO(lucas): refactor to be more concise
             # we should probably consider factoring out V1 here
-            if selected_backend == _Backend.FLASHMLA:
+            if selected_backend == _Backend.TRITON_MLA or block_size != 64:
+                if use_v1:
+                    logger.info_once("Using Triton MLA backend on V1 engine.")
+                    return ("vllm.v1.attention.backends.mla."
+                            "triton_mla.TritonMLABackend")
+                else:
+                    logger.info("Using Triton MLA backend.")
+                    return "vllm.attention.backends.triton_mla.TritonMLABackend"
+            else:
                 from vllm.attention.backends.flashmla import (
                     is_flashmla_supported)
                 if not is_flashmla_supported()[0]:
@@ -195,14 +211,6 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                         logger.info("Using FlashMLA backend.")
                         return ("vllm.attention.backends."
                                 "flashmla.FlashMLABackend")
-
-            if use_v1:
-                logger.info_once("Using Triton MLA backend on V1 engine.")
-                return ("vllm.v1.attention.backends.mla."
-                        "triton_mla.TritonMLABackend")
-            else:
-                logger.info("Using Triton MLA backend.")
-                return "vllm.attention.backends.triton_mla.TritonMLABackend"
         if use_v1:
             logger.info_once("Using Flash Attention backend on V1 engine.")
             return ("vllm.v1.attention.backends.flash_attn."

0 commit comments
