From 07336d23dbd6952312d2d45c18a7c3b8ad31379c Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Tue, 25 Feb 2025 18:52:19 +0000
Subject: [PATCH 1/3] init

Signed-off-by: Sage Moore
---
 vllm/attention/backends/mla/common.py | 42 +++++++++++++++++++--------
 1 file changed, 30 insertions(+), 12 deletions(-)

diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py
index c3dbbdb86823..ac8e08b8e59f 100644
--- a/vllm/attention/backends/mla/common.py
+++ b/vllm/attention/backends/mla/common.py
@@ -232,6 +232,7 @@
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 from vllm.multimodal import MultiModalPlaceholderMap
+from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
 
 try:
@@ -1371,18 +1372,35 @@ def _forward_prefill(
         v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                            value=0)
 
-        output = self.flash_attn_varlen_func(
-            q=q,
-            k=k,
-            v=v_padded,
-            cu_seqlens_q=prefill_metadata.query_start_loc,
-            cu_seqlens_k=prefill_metadata.query_start_loc,
-            max_seqlen_q=prefill_metadata.max_prefill_seq_len,
-            max_seqlen_k=prefill_metadata.max_prefill_seq_len,
-            softmax_scale=self.scale,
-            causal=True,
-            return_softmax_lse=has_context,
-        )
+        if has_context:
+            if not current_platform.is_cuda():
+                raise NotImplementedError(
+                    "Chunked Prefill for MLA is not currently supported on ROCm"
+                )
+            output = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=v_padded,
+                cu_seqlens_q=prefill_metadata.query_start_loc,
+                cu_seqlens_k=prefill_metadata.query_start_loc,
+                max_seqlen_q=prefill_metadata.max_prefill_seq_len,
+                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
+                softmax_scale=self.scale,
+                causal=True,
+                return_softmax_lse=True,
+            )
+        else:
+            output = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=v_padded,
+                cu_seqlens_q=prefill_metadata.query_start_loc,
+                cu_seqlens_k=prefill_metadata.query_start_loc,
+                max_seqlen_q=prefill_metadata.max_prefill_seq_len,
+                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
+                softmax_scale=self.scale,
+                causal=True,
+            )
 
         if has_context:
             suffix_output, suffix_lse = output

From c226a3060421a5a9788603b4970704d66fbcb526 Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Tue, 25 Feb 2025 19:26:17 +0000
Subject: [PATCH 2/3] init

Signed-off-by: Sage Moore
---
 vllm/config.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/vllm/config.py b/vllm/config.py
index 6bcf34c3cff9..ede95d818189 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -3400,6 +3400,20 @@ def __post_init__(self):
                 "Disabling `torch.compile`.")
             self.compilation_config.level = CompilationLevel.NO_COMPILATION
 
+        if self.model_config and self.model_config.use_mla and \
+                not current_platform.is_cuda():
+            logger.info(
+                "MLA is enabled on ROCm; forcing chunked prefill and prefix "
+                "caching to be disabled.")
+            self.scheduler_config.enable_chunked_prefill = False
+            self.scheduler_config.chunked_prefill_enabled = False
+            self.scheduler_config.max_num_batched_tokens = max(
+                self.scheduler_config.max_model_len,
+                _DEFAULT_MAX_NUM_BATCHED_TOKENS)
+
+            if self.cache_config is not None:
+                self.cache_config.enable_prefix_caching = False
+
         current_platform.check_and_update_config(self)
 
         if not self.instance_id:

From ae3594eb4d5f4bbf1e5da1187b98f46232d18cb5 Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Tue, 25 Feb 2025 19:35:39 +0000
Subject: [PATCH 3/3] update logs

Signed-off-by: Sage Moore
---
 vllm/attention/backends/mla/common.py | 4 ++--
 vllm/config.py                        | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py
index ac8e08b8e59f..529ec3f3b11a 100644
--- a/vllm/attention/backends/mla/common.py
+++ b/vllm/attention/backends/mla/common.py
@@ -1375,8 +1375,8 @@ def _forward_prefill(
         if has_context:
             if not current_platform.is_cuda():
                 raise NotImplementedError(
-                    "Chunked Prefill for MLA is not currently supported on ROCm"
-                )
+                    "Chunked Prefill for MLA is not currently supported on "
+                    "non-CUDA platforms")
             output = self.flash_attn_varlen_func(
                 q=q,
                 k=k,
diff --git a/vllm/config.py b/vllm/config.py
index ede95d818189..111009131d83 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -3403,8 +3403,8 @@ def __post_init__(self):
         if self.model_config and self.model_config.use_mla and \
                 not current_platform.is_cuda():
             logger.info(
-                "MLA is enabled on ROCm; forcing chunked prefill and prefix "
-                "caching to be disabled.")
+                "MLA is enabled on a non-CUDA platform; forcing chunked "
+                "prefill and prefix caching to be disabled.")
             self.scheduler_config.enable_chunked_prefill = False
             self.scheduler_config.chunked_prefill_enabled = False
             self.scheduler_config.max_num_batched_tokens = max(