[ROCm] Add AMD GPU support on Deepseek v3.2 and SparseMLA #26670
@@ -0,0 +1,210 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib.util
from functools import lru_cache

import torch

from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)


# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84
def fp8_mqa_logits_torch(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: Weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """
    kv, scale = kv
    seq_len_kv = kv.shape[0]
    k = kv.to(torch.bfloat16)
    q = q.to(torch.bfloat16)

    # Per-query visibility window over the KV sequence: query position m may
    # attend K indices in [cu_seqlen_ks[m], cu_seqlen_ke[m]).
    mask_lo = (
        torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
    )
    mask_hi = (
        torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
    )
    mask = mask_lo & mask_hi

    score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float("-inf"))

    return logits

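For orientation, here is a minimal, illustrative sketch of calling the unpaged reference; the shapes and values are our own, not from the PR:

```python
# Illustrative shapes: M queries, H heads, head dim D, KV length N.
M, H, D, N = 4, 16, 64, 8
q = torch.randn(M, H, D, device="cuda").to(torch.float8_e4m3fn)
k_fp8 = torch.randn(N, D, device="cuda").to(torch.float8_e4m3fn)
k_scales = torch.rand(N, device="cuda", dtype=torch.float32)
weights = torch.rand(M, H, device="cuda", dtype=torch.float32)
# Query m may attend K positions in [ks[m], ke[m]): here a growing causal window.
cu_seqlen_ks = torch.zeros(M, dtype=torch.int32, device="cuda")
cu_seqlen_ke = torch.arange(N - M + 1, N + 1, dtype=torch.int32, device="cuda")

logits = fp8_mqa_logits_torch(
    q, (k_fp8, k_scales), weights, cu_seqlen_ks, cu_seqlen_ke
)
assert logits.shape == (M, N)
```
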
def rocm_fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: Weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """

    # TODO(ganyi): Temporary workaround; remove the module check and the
    # reference path once aiter merges this kernel into main.
    @lru_cache
    def has_mqa_logits_module():
        return importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None

    if rocm_aiter_ops.is_enabled() and has_mqa_logits_module():
        from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits

        kv, scale = kv
        return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
    else:
        return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)

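Because the dispatcher falls back to the torch reference, the two paths can be cross-checked. A parity sketch of our own, reusing the inputs from the previous example:

```python
# Hypothetical parity check; masked entries are -inf in both outputs, so
# compare only the finite (in-window) positions. Tolerances are loose for FP8.
ref = fp8_mqa_logits_torch(q, (k_fp8, k_scales), weights, cu_seqlen_ks, cu_seqlen_ke)
out = rocm_fp8_mqa_logits(q, (k_fp8, k_scales), weights, cu_seqlen_ks, cu_seqlen_ke)
finite = torch.isfinite(ref)
torch.testing.assert_close(out[finite], ref[finite], rtol=2e-2, atol=2e-2)
```
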
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156
def fp8_paged_mqa_logits_torch(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    max_model_len: int,
):
    from vllm.utils.math_utils import cdiv

    fp8_dtype = current_platform.fp8_dtype()
    batch_size, next_n, _, dim = q.size()
    # Unpack the cache: the first `dim` bytes per position hold the FP8 key,
    # the trailing 4 bytes its float32 dequant scale.
    kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:]
    scale = scale.contiguous().view(torch.float)
    q = q.float()
    kv_cache = kv_cache.view(fp8_dtype).float() * scale
    num_block, block_size, _, dim = kv_cache.size()
    logits = torch.full(
        [batch_size * next_n, max_model_len],
        float("-inf"),
        device=q.device,
        dtype=torch.float32,
    )
    context_lens = context_lens.tolist()
    for i in range(batch_size):
        context_len = context_lens[i]
        q_offsets = torch.arange(context_len - next_n, context_len, device="cuda")
        weight_slice = (
            weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
        )
        for block_rk in range(cdiv(context_len, block_size)):
            block_idx = block_tables[i][block_rk]
            qx, kx = q[i], kv_cache[block_idx]
            k_offsets = torch.arange(
                block_rk * block_size, (block_rk + 1) * block_size, device="cuda"
            )
            mask = (k_offsets[None, :] < context_len) & (
                k_offsets[None, :] <= q_offsets[:, None]
            )
            s = torch.where(
                mask[None, :, :],
                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
                    logits.dtype
                ),
                float("-inf"),
            )
            s = torch.relu(s) * weight_slice[..., None]
            s = s.sum(dim=0)
            logits[
                i * next_n : (i + 1) * next_n,
                block_rk * block_size : (block_rk + 1) * block_size,
            ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
    return logits

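The packed uint8 layout this reference expects can be produced as follows. This is a sketch under our own assumptions; the real cache is built elsewhere in vLLM:

```python
# Hypothetical packing of an FP8 paged KV cache with a trailing float32
# scale per (block, position): uint8 layout [num_blocks, block_size, 1, D+4].
num_blocks, block_size, dim = 4, 1, 64  # block size 1, matching the ROCm assert
k = torch.randn(num_blocks, block_size, 1, dim, device="cuda")
scale = k.abs().amax(dim=-1, keepdim=True) / 448.0  # e4m3 max magnitude is 448
k_fp8 = (k / scale).to(torch.float8_e4m3fn)
kv_cache_fp8 = torch.cat(
    [k_fp8.view(torch.uint8), scale.view(torch.uint8)], dim=-1
)
assert kv_cache_fp8.shape == (num_blocks, block_size, 1, dim + 4)
```
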
def rocm_fp8_paged_mqa_logits(
    q_fp8: torch.Tensor,
    kv_cache_fp8: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    schedule_metadata: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """Compute FP8 MQA logits using a paged KV-cache.

    Args:
        q_fp8: Query tensor of shape [B, next_n, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
            [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
            4 bytes per (block, pos) store the `float` dequant scale.
        weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
        context_lens: Tensor of shape [B], dtype int32; effective context
            length for each batch element.
        block_tables: Tensor of shape [B, max_blocks], dtype int32; maps
            logical block indices to physical blocks in the paged cache.
        schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
            used to distribute work across SMs.
        max_model_len: Maximum sequence length used to size the logits output.

    Returns:
        Logits tensor of shape [B * next_n, max_model_len], dtype
        `torch.float32`.
    """

    if rocm_aiter_ops.is_enabled():
        from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits_stage1

        batch_size, next_n, heads, _ = q_fp8.shape
        out_qk = torch.full(
            (heads, batch_size * next_n, max_model_len),
            float("-inf"),
            device="cuda",
            dtype=torch.float32,
        )
        # Stage 1 writes per-head partial logits into `out_qk`; the head
        # reduction happens below. NOTE: `schedule_metadata` is not consumed
        # by the Triton stage-1 path.
        deepgemm_fp8_paged_mqa_logits_stage1(
            q_fp8,
            kv_cache_fp8,
            weights,
            out_qk,
            context_lens,
            block_tables,
            max_model_len,
        )
        return out_qk.sum(dim=0)
    else:
        return fp8_paged_mqa_logits_torch(
            q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len
        )
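Continuing the packing sketch above, an end-to-end call of this entry point might look like the following; all shapes are illustrative, and since neither ROCm path consumes `schedule_metadata`, the sketch passes `None`:

```python
# Illustrative decode-style call: B=2 sequences, next_n=1 query step each.
batch_size, next_n, heads = 2, 1, 16
q_fp8 = torch.randn(batch_size, next_n, heads, dim, device="cuda").to(
    torch.float8_e4m3fn
)
weights = torch.rand(batch_size * next_n, heads, device="cuda")
context_lens = torch.tensor([1, 2], dtype=torch.int32, device="cuda")
block_tables = torch.tensor([[0, 1], [2, 3]], dtype=torch.int32, device="cuda")

logits = rocm_fp8_paged_mqa_logits(
    q_fp8, kv_cache_fp8, weights, context_lens, block_tables,
    schedule_metadata=None, max_model_len=8,
)
assert logits.shape == (batch_size * next_n, 8)
```
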
@@ -225,7 +225,18 @@ def get_attn_backend_cls(
     from vllm.attention.backends.registry import AttentionBackendEnum

     if use_sparse:
-        raise NotImplementedError("Sparse Attention is not supported on ROCm.")
+        if kv_cache_dtype.startswith("fp8"):
+            raise ValueError(
+                "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
+            )
+        assert block_size == 1, (
+            "Sparse MLA backend on ROCm only supports block size 1 for now."
+        )
+        logger.info_once("Using Sparse MLA backend on V1 engine.")
+        return (
+            "vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse."
+            "ROCMAiterMLASparseBackend"
+        )

     if use_mla:
         if selected_backend is None:

Review discussion on this hunk:

Collaborator: @ganyi1996ppo what do you think about adding an assertion of block-size 1 when we are using this backend?

Author: Sounds good

Author: Done, please take a look.
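Finally, an illustrative way to reach this backend path; the model name and the aiter switch are our assumptions, not part of the diff:

```python
# Hypothetical end-user sketch: enable aiter kernels and use block size 1,
# as the new assertion requires.
import os

os.environ["VLLM_ROCM_USE_AITER"] = "1"  # assumed switch for aiter ops

from vllm import LLM

llm = LLM(model="deepseek-ai/DeepSeek-V3.2-Exp", block_size=1)
```
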