Skip to content
27 changes: 23 additions & 4 deletions vllm/model_executor/layers/attention/mla_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,18 +440,34 @@ def __post_init__(self):
A = TypeVar("A", bound=AttentionMetadata)


def is_deepseek_r1_mla_compatible(vllm_config: "VllmConfig") -> bool:
    """Return True when the model has DeepSeek-R1-compatible MLA dimensions.

    DeepSeek R1 uses qk_nope_head_dim=128, qk_rope_head_dim=64 and
    v_head_dim=128, which results in a query/key head dim of 192.

    Args:
        vllm_config: The current vLLM configuration; only its
            ``model_config.hf_text_config`` attributes are consulted.

    Returns:
        True iff all three MLA head dimensions match DeepSeek R1's.
    """
    # String annotation above keeps this consistent with the file's other
    # forward references and avoids requiring VllmConfig at import time.
    if vllm_config.model_config is None:
        return False
    hf_text_config = vllm_config.model_config.hf_text_config
    expected = {"qk_nope_head_dim": 128, "qk_rope_head_dim": 64, "v_head_dim": 128}
    # The getattr default of 1 never equals any expected dim, so models
    # missing these attributes (non-MLA configs) are rejected.
    return all(getattr(hf_text_config, name, 1) == dim for name, dim in expected.items())


def use_flashinfer_prefill() -> bool:
    """Whether the FlashInfer prefill path should be used for MLA.

    For Blackwell (SM 10.x family) default to FlashInfer prefill when it is
    available, since it is faster than FA2 there, unless it was explicitly
    disabled or the cuDNN prefill path was requested. Additionally gated on
    DeepSeek-R1-compatible MLA head dimensions (kernel constraint).
    """
    # NOTE(review): the pasted diff retained both the old ``return (`` and the
    # new ``if not (`` lines, which is a syntax error; this is the post-diff
    # version of the function.
    from vllm.config import get_current_vllm_config

    vllm_config = get_current_vllm_config()
    if not (
        not vllm_config.attention_config.disable_flashinfer_prefill
        and flashinfer_available
        and not vllm_config.attention_config.use_cudnn_prefill
        and current_platform.is_device_capability_family(100)
    ):
        return False

    return is_deepseek_r1_mla_compatible(vllm_config)


def use_cudnn_prefill() -> bool:
def use_trtllm_ragged_deepseek_prefill() -> bool:
    """Whether the TRT-LLM ragged DeepSeek prefill kernel should be used.

    Requires FlashInfer to be available, the explicit opt-in flag in the
    attention config, a Blackwell (SM 10.x family) device, and
    DeepSeek-R1-compatible MLA head dimensions (kernel constraint).
    """
    # NOTE(review): the pasted diff retained both the old ``return (`` and the
    # new ``if not (`` lines, which is a syntax error; this is the post-diff
    # version of the function.
    from vllm.config import get_current_vllm_config

    vllm_config = get_current_vllm_config()
    if not (
        flashinfer_available
        and vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
        and current_platform.is_device_capability_family(100)
    ):
        return False

    return is_deepseek_r1_mla_compatible(vllm_config)


@dataclass
Expand Down
25 changes: 19 additions & 6 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,28 +180,41 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
use_cutlass_mla = False
use_flashinfer_mla = False

from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported

if vllm_config.attention_config.backend is None:
# Default case
hf_text_config = model_config.hf_text_config
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
if (
cls.is_device_capability_family(100)
and not use_sparse
and qk_nope_head_dim == 128
):
# Blackwell => Force FlashInfer MLA (unless sparse, i.e. DSv3.2)
# and only if qk_nope_head_dim == 128 (kernel constraint)
use_flashinfer_mla = True
# Set the backend in AttentionConfig so it's used during
# backend selection
vllm_config.attention_config.backend = (
AttentionBackendEnum.FLASHINFER_MLA
)
elif cls.is_device_capability_family(100) and not use_sparse:
# Fall back to CUTLASS_MLA as 2nd priority on Blackwell
use_cutlass_mla = True
elif is_flashmla_dense_supported()[0]:
# Non-Blackwell with FlashMLA support
use_flashmla = True
else:
# Fallback: will use Triton MLA or other compatible backend
pass
else:
# Forced case
backend = vllm_config.attention_config.backend
use_flashmla = backend == AttentionBackendEnum.FLASHMLA
use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA
use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA

from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported

if (
use_flashmla
and is_flashmla_dense_supported()[0]
Expand Down
2 changes: 2 additions & 0 deletions vllm/transformers_utils/model_arch_config_convertor.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ def is_deepseek_mla(self) -> bool:
"deepseek_v3",
"deepseek_v32",
"deepseek_mtp",
"glm4_moe_lite",
"glm4_moe_lite_mtp",
"kimi_k2",
"kimi_linear",
"longcat_flash",
Expand Down
26 changes: 26 additions & 0 deletions vllm/v1/attention/backends/mla/flashinfer_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,32 @@ def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
    """Return True only for Blackwell-generation GPUs (major CC 10)."""
    supported_major = 10
    return capability.major == supported_major

@classmethod
def supports_combination(
    cls,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: CacheDType | None,
    block_size: int,
    use_mla: bool,
    has_sink: bool,
    use_sparse: bool,
    device_capability: DeviceCapability,
) -> str | None:
    """Validate the attention configuration for the FlashInfer MLA kernel.

    The kernel has a hard constraint of qk_nope_head_dim == 128; when the
    current model config violates it, an explanatory rejection string is
    returned, otherwise None (combination accepted).
    """
    from vllm.config import get_current_vllm_config

    model_config = get_current_vllm_config().model_config
    if model_config is None:
        # No model config to validate against; accept.
        return None
    nope_dim = getattr(model_config.hf_text_config, "qk_nope_head_dim", 1)
    if nope_dim == 128:
        return None
    return (
        f"FlashInfer MLA kernel requires qk_nope_head_dim == 128, "
        f"but got {nope_dim}"
    )

@classmethod
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
    """FlashInfer MLA kernels require the KV cache in HND layout."""
    required_layout = "HND"
    return required_layout
Expand Down