|
232 | 232 | from vllm.model_executor.layers.rotary_embedding import ( |
233 | 233 | DeepseekScalingRotaryEmbedding, RotaryEmbedding) |
234 | 234 | from vllm.multimodal import MultiModalPlaceholderMap |
| 235 | +from vllm.platforms import current_platform |
235 | 236 | from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down |
236 | 237 |
|
237 | 238 | try: |
@@ -1371,18 +1372,35 @@ def _forward_prefill( |
1371 | 1372 | v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], |
1372 | 1373 | value=0) |
1373 | 1374 |
|
1374 | | - output = self.flash_attn_varlen_func( |
1375 | | - q=q, |
1376 | | - k=k, |
1377 | | - v=v_padded, |
1378 | | - cu_seqlens_q=prefill_metadata.query_start_loc, |
1379 | | - cu_seqlens_k=prefill_metadata.query_start_loc, |
1380 | | - max_seqlen_q=prefill_metadata.max_prefill_seq_len, |
1381 | | - max_seqlen_k=prefill_metadata.max_prefill_seq_len, |
1382 | | - softmax_scale=self.scale, |
1383 | | - causal=True, |
1384 | | - return_softmax_lse=has_context, |
1385 | | - ) |
| 1375 | + if has_context: |
| 1376 | + if not current_platform.is_cuda(): |
| 1377 | + raise NotImplementedError( |
| 1378 | + "Chunked Prefill for MLA is not currently supported on " |
| 1379 | + "non-cuda platforms") |
| 1380 | + output = self.flash_attn_varlen_func( |
| 1381 | + q=q, |
| 1382 | + k=k, |
| 1383 | + v=v_padded, |
| 1384 | + cu_seqlens_q=prefill_metadata.query_start_loc, |
| 1385 | + cu_seqlens_k=prefill_metadata.query_start_loc, |
| 1386 | + max_seqlen_q=prefill_metadata.max_prefill_seq_len, |
| 1387 | + max_seqlen_k=prefill_metadata.max_prefill_seq_len, |
| 1388 | + softmax_scale=self.scale, |
| 1389 | + causal=True, |
| 1390 | + return_softmax_lse=True, |
| 1391 | + ) |
| 1392 | + else: |
| 1393 | + output = self.flash_attn_varlen_func( |
| 1394 | + q=q, |
| 1395 | + k=k, |
| 1396 | + v=v_padded, |
| 1397 | + cu_seqlens_q=prefill_metadata.query_start_loc, |
| 1398 | + cu_seqlens_k=prefill_metadata.query_start_loc, |
| 1399 | + max_seqlen_q=prefill_metadata.max_prefill_seq_len, |
| 1400 | + max_seqlen_k=prefill_metadata.max_prefill_seq_len, |
| 1401 | + softmax_scale=self.scale, |
| 1402 | + causal=True, |
| 1403 | + ) |
1386 | 1404 |
|
1387 | 1405 | if has_context: |
1388 | 1406 | suffix_output, suffix_lse = output |
|
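For context on why the `has_context` branch asks for the log-sum-exp: with chunked prefill, attention over the cached context and attention over the newly arriving tokens are computed as two separate calls, and the two partial results are then combined using their per-query LSEs. Below is a minimal, self-contained sketch of that standard merge; the helper name and tensor layouts (`o*` as `[tokens, heads, head_dim]`, `lse*` as `[tokens, heads]`) are illustrative assumptions, not the actual vLLM kernels or layouts.

```python
import torch

def merge_attn_outputs(o1: torch.Tensor, lse1: torch.Tensor,
                       o2: torch.Tensor, lse2: torch.Tensor):
    """Merge two partial attention results computed over disjoint key sets.

    Sketch only: assumes o* are [tokens, heads, head_dim] and lse* are
    [tokens, heads]; real kernels may use a different memory layout.
    """
    # Combined log-sum-exp of the softmax denominators over both key sets.
    lse = torch.logaddexp(lse1, lse2)
    # Re-weight each partial output by its share of the combined softmax mass.
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse
```

This is why the non-context path can omit `return_softmax_lse` entirely: without cached context there is nothing to merge, and only the attention output is needed.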