diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py
index 91e1cad01f4f..6747cf7743b1 100644
--- a/vllm/attention/backends/registry.py
+++ b/vllm/attention/backends/registry.py
@@ -52,6 +52,9 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
     ROCM_AITER_FA = (
         "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
     )
+    ROCM_AITER_MLA_SPARSE = (
+        "vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse.ROCMAiterMLASparseBackend"
+    )
     TORCH_SDPA = ""  # this tag is only used for ViT
     FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
     FLASHINFER_MLA = (
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 1a2f9226ddce..f9005fd7d044 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -233,10 +233,7 @@ def get_attn_backend_cls(
                     "Sparse MLA backend on ROCm only supports block size 1 for now."
                 )
             logger.info_once("Using Sparse MLA backend on V1 engine.")
-            return (
-                "vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse."
-                "ROCMAiterMLASparseBackend"
-            )
+            return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()

         if use_mla:
             if selected_backend is None: