
Commit 3acf03b

SageMoore authored and shreyankg committed
[ROCm] Disable chunked prefill/prefix caching when running MLA on non-cuda platforms (vllm-project#13844)
Signed-off-by: Sage Moore <[email protected]>
1 parent 8295415 commit 3acf03b

File tree

2 files changed: +44 −12 lines changed


vllm/attention/backends/mla/common.py

Lines changed: 30 additions & 12 deletions

@@ -232,6 +232,7 @@
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 from vllm.multimodal import MultiModalPlaceholderMap
+from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
 
 try:
@@ -1371,18 +1372,35 @@ def _forward_prefill(
         v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                            value=0)
 
-        output = self.flash_attn_varlen_func(
-            q=q,
-            k=k,
-            v=v_padded,
-            cu_seqlens_q=prefill_metadata.query_start_loc,
-            cu_seqlens_k=prefill_metadata.query_start_loc,
-            max_seqlen_q=prefill_metadata.max_prefill_seq_len,
-            max_seqlen_k=prefill_metadata.max_prefill_seq_len,
-            softmax_scale=self.scale,
-            causal=True,
-            return_softmax_lse=has_context,
-        )
+        if has_context:
+            if not current_platform.is_cuda():
+                raise NotImplementedError(
+                    "Chunked Prefill for MLA is not currently supported on "
+                    "non-cuda platforms")
+            output = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=v_padded,
+                cu_seqlens_q=prefill_metadata.query_start_loc,
+                cu_seqlens_k=prefill_metadata.query_start_loc,
+                max_seqlen_q=prefill_metadata.max_prefill_seq_len,
+                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
+                softmax_scale=self.scale,
+                causal=True,
+                return_softmax_lse=True,
+            )
+        else:
+            output = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=v_padded,
+                cu_seqlens_q=prefill_metadata.query_start_loc,
+                cu_seqlens_k=prefill_metadata.query_start_loc,
+                max_seqlen_q=prefill_metadata.max_prefill_seq_len,
+                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
+                softmax_scale=self.scale,
+                causal=True,
+            )
 
         if has_context:
             suffix_output, suffix_lse = output
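The CUDA-only branch passes return_softmax_lse=True because, when a prefill request has cached context, the attention computed over the new tokens (the suffix_output above) must later be merged with attention computed over that context, and the per-row log-sum-exp is what makes that merge exact. Below is a minimal, self-contained sketch of such an LSE-based merge, not vLLM's implementation; the function name and tensor shapes are assumptions made for illustration.

# Minimal sketch (not vLLM's code): merge two partial attention results for
# the same queries using their log-sum-exp (LSE) values. Assumed shapes:
# outputs are [num_tokens, num_heads, head_dim], LSEs are [num_heads, num_tokens].
import torch


def merge_attn_outputs(out_a: torch.Tensor, lse_a: torch.Tensor,
                       out_b: torch.Tensor, lse_b: torch.Tensor) -> torch.Tensor:
    # Combined LSE is the element-wise log-sum-exp of the two partial LSEs.
    lse = torch.logaddexp(lse_a, lse_b)
    # Re-weight each partial output by the fraction of the total softmax
    # mass that its chunk of keys contributed.
    w_a = torch.exp(lse_a - lse).transpose(0, 1).unsqueeze(-1)
    w_b = torch.exp(lse_b - lse).transpose(0, 1).unsqueeze(-1)
    return out_a * w_a + out_b * w_b

Because the two weights sum to one for every query/head pair, the merged tensor equals a single softmax attention over the union of both key sets.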

vllm/config.py

Lines changed: 14 additions & 0 deletions

@@ -3424,6 +3424,20 @@ def __post_init__(self):
                 "Disabling `torch.compile`.")
             self.compilation_config.level = CompilationLevel.NO_COMPILATION
 
+        if self.model_config and self.model_config.use_mla and \
+            not current_platform.is_cuda():
+            logger.info(
+                "MLA is enabled on a non-cuda platform; forcing chunked "
+                "prefill and prefix caching to be disabled.")
+            self.scheduler_config.enable_chunked_prefill = False
+            self.scheduler_config.chunked_prefill_enabled = False
+            self.scheduler_config.max_num_batched_tokens = max(
+                self.scheduler_config.max_model_len,
+                _DEFAULT_MAX_NUM_BATCHED_TOKENS)
+
+            if self.cache_config is not None:
+                self.cache_config.enable_prefix_caching = False
+
         current_platform.check_and_update_config(self)
 
         if not self.instance_id:
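A side note on this hunk: once chunked prefill is forced off, the scheduler has to fit an entire prompt into one batch, which is why max_num_batched_tokens is raised to at least max_model_len. The sketch below illustrates that invariant in plain Python; the function and the default value of 2048 are hypothetical, not vLLM's scheduler.

# Hypothetical sketch of the token-budget invariant, not vLLM's scheduler.
_ASSUMED_DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048  # placeholder default


def effective_token_budget(max_model_len: int,
                           enable_chunked_prefill: bool,
                           requested_budget: int) -> int:
    if enable_chunked_prefill:
        # Long prompts can be split across scheduler steps, so any budget works.
        return requested_budget
    # Without chunking, one batch must hold the longest possible prompt.
    return max(max_model_len, _ASSUMED_DEFAULT_MAX_NUM_BATCHED_TOKENS)


# A 32k-context model with chunked prefill disabled needs a 32k-token budget
# even though the requested/default budget is much smaller.
assert effective_token_budget(32768, False, 2048) == 32768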
