
Commit a72c0c3

SageMoore authored and shreyankg committed
[ROCm] Enable chunked prefill/paged attention in MLA on ROCm (vllm-project#14316)
Signed-off-by: Sage Moore <[email protected]>
1 parent 9fa8fb3 commit a72c0c3

File tree: 2 files changed, 4 additions (+) and 18 deletions (-)


vllm/attention/backends/mla/common.py

Lines changed: 2 additions & 16 deletions
@@ -1327,21 +1327,7 @@ def _compute_prefill_context(
                     [0, q.shape[-1] - v.shape[-1]],
                     value=0)
 
-            if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN:
-                attn_output, attn_softmax_lse = self.triton_fa_func(
-                    q,
-                    k,
-                    v_padded,
-                    None,
-                    prefill_metadata.query_start_loc,
-                    prefill_metadata.context_chunk_cu_seq_lens[i],
-                    prefill_metadata.max_query_len,
-                    prefill_metadata.context_chunk_max_seq_lens[i],
-                    False,  # causal
-                    self.scale,
-                    None,  # attn_mask is None unless applying ALiBi mask
-                )
-            elif is_vllm_fa:
+            if is_vllm_fa:
                 attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
                     q=q,
                     k=k,

@@ -1416,7 +1402,7 @@ def _forward_prefill(
         v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                            value=0)
 
-        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN:
+        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context:
             output = self.triton_fa_func(
                 q,
                 k,
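
Note on the common.py change: the removed Triton flash-attention branch in _compute_prefill_context could not provide the softmax log-sum-exp (attn_softmax_lse) that the chunked-context loop relies on, so that path now always goes through flash_attn_varlen_func (or the other LSE-returning fallbacks), and _forward_prefill keeps the Triton kernel only when there is no context to merge. The snippet below is a minimal, self-contained sketch, not vLLM code, of why the LSE matters: partial attention outputs computed over separate context chunks can be combined exactly when each chunk also reports its per-query log-sum-exp. The name merge_attn_partials and the tensor shapes are illustrative assumptions.

import torch


def merge_attn_partials(out_a: torch.Tensor, lse_a: torch.Tensor,
                        out_b: torch.Tensor, lse_b: torch.Tensor) -> torch.Tensor:
    # out_*: [num_tokens, num_heads, head_dim], lse_*: [num_tokens, num_heads].
    # Each partial output is softmax(scores_chunk) @ V_chunk; the LSE lets us
    # re-weight the two partials as if softmax had been taken over all keys.
    max_lse = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - max_lse).unsqueeze(-1)  # per-query weight of chunk A
    w_b = torch.exp(lse_b - max_lse).unsqueeze(-1)  # per-query weight of chunk B
    return (out_a * w_a + out_b * w_b) / (w_a + w_b)


# Tiny usage example with random partials (4 tokens, 2 heads, head_dim 8).
out_a, out_b = torch.randn(4, 2, 8), torch.randn(4, 2, 8)
lse_a, lse_b = torch.randn(4, 2), torch.randn(4, 2)
merged = merge_attn_partials(out_a, lse_a, out_b, lse_b)
print(merged.shape)  # torch.Size([4, 2, 8])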

vllm/config.py

Lines changed: 2 additions & 2 deletions
@@ -3452,9 +3452,9 @@ def __post_init__(self):
             self.compilation_config.level = CompilationLevel.NO_COMPILATION
 
         if self.model_config and self.model_config.use_mla and \
-            not current_platform.is_cuda():
+            not (current_platform.is_cuda() or current_platform.is_rocm()):
             logger.info(
-                "MLA is enabled on a non-cuda platform; forcing chunked "
+                "MLA is enabled on a non-GPU platform; forcing chunked "
                 "prefill and prefix caching to be disabled.")
             self.scheduler_config.enable_chunked_prefill = False
             self.scheduler_config.chunked_prefill_enabled = False
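
Note on the config.py change: with MLA, chunked prefill and prefix caching were previously forced off on every non-CUDA platform; the guard now also exempts ROCm. Below is a rough, standalone sketch of the resulting behaviour; the function and argument names are simplified stand-ins for illustration, not the vLLM config API.

def resolve_mla_scheduler_flags(use_mla: bool, platform: str,
                                chunked_prefill: bool,
                                prefix_caching: bool) -> tuple[bool, bool]:
    # Mirrors the __post_init__ branch above: only non-GPU platforms fall back.
    if use_mla and platform not in ("cuda", "rocm"):
        return False, False
    return chunked_prefill, prefix_caching


assert resolve_mla_scheduler_flags(True, "rocm", True, True) == (True, True)
assert resolve_mla_scheduler_flags(True, "cpu", True, True) == (False, False)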
