[ROCm] Add AMD GPU support on Deepseek v3.2 and SparseMLA #26670
Conversation
Code Review
This pull request adds support for Deepseek v3.2 and SparseMLA on ROCm platforms. The changes are well-structured and cover adding a new FP8 data type for ROCm, a new attention backend for sparse MLA, and PyTorch fallbacks for some custom operations. The code modifications correctly abstract away platform-specific details. I have one suggestion to improve performance in the newly added sparse MLA backend file by avoiding unnecessary tensor initializations.
out = torch.randn([qs, qh, kv_lora_rank], dtype=torch.bfloat16, device=q.device)
block_table = torch.empty_like(topk_indices)  # no use
num_seq = cu_seqlens_q.size(0) - 1
seqused_k = torch.randn([num_seq], device=q.device)  # no use
Using torch.randn to initialize tensors that are either output parameters (out) or unused (seqused_k) is inefficient. It populates the tensors with random values, which is an unnecessary computation that can impact performance. It's more performant to use torch.empty, which only allocates uninitialized memory.
Suggested change:
out = torch.empty([qs, qh, kv_lora_rank], dtype=torch.bfloat16, device=q.device)
block_table = torch.empty_like(topk_indices)  # no use
num_seq = cu_seqlens_q.size(0) - 1
seqused_k = torch.empty([num_seq], device=q.device)  # no use
same question
Done
💡 Codex Review
Here are some automated review suggestions for this pull request.
vllm/utils/deep_gemm.py
Outdated
mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None]
mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None]
mask = mask_lo & mask_hi
Use input device instead of 'cuda' in DeepGEMM torch fallback
The new torch implementations of FP8 MQA logits create their temporary tensors with device='cuda' (e.g., the arange masks at fp8_mqa_logits_torch). When the caller runs on a non‑default device such as cuda:1 or HIP, these tensors land on a different device than q/kv, so subsequent ops like mask_lo & mask_hi or matrix multiplies raise a device‑mismatch error. The fallback is meant to run when the compiled kernel is unavailable, so it should follow the device of the inputs (q.device or kv.device) instead of assuming the global CUDA device. The same hardcoded device constant appears again in fp8_paged_mqa_logits_torch and the ROCm branch of fp8_paged_mqa_logits, so all temporaries there should be created on the input device as well.
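As a rough illustration of the suggested fix, here is a minimal sketch of building the mask on the inputs' device instead of the global default; the helper name below is hypothetical, only the argument names follow the quoted code:

```python
import torch

def _mqa_logits_mask(cu_seqlen_ks: torch.Tensor,
                     cu_seqlen_ke: torch.Tensor,
                     seq_len_kv: int) -> torch.Tensor:
    # Build the per-row [ks, ke) window mask on the same device as the input
    # tensors, so the fallback also works on cuda:1 or HIP devices.
    positions = torch.arange(seq_len_kv, device=cu_seqlen_ks.device)[None, :]
    mask_lo = positions >= cu_seqlen_ks[:, None]
    mask_hi = positions < cu_seqlen_ke[:, None]
    return mask_lo & mask_hi
```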
HAIAI
left a comment
Please revert the FP8_E4M3FNUZ part for the kv cache; we have unified the fp8 KV cache format since March 2024, and OCP is the only format we support from Quark as well.
HAIAI
left a comment
Lint: a few more comments. And please format accordingly.
@classmethod
def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
    if current_platform.is_rocm():
        return [MultipleOf(1)]
Why 1?
I have the same question
The current triton kernel deepgemm_fp8_paged_mqa_logits_stage1, which we use to replace deepgemm, only supports the block-size=1 case, and so does our MLA kernel, although it is not used in deepseekv3.2. We already have a solution for this, but not in this PR.
vllm/platforms/rocm.py
Outdated
)
if use_sparse:
    if kv_cache_dtype.startswith("fp8"):
        raise ValueError(f"ROCMAiterMLASparseBackend dose not support kv_cache_dtype == fp8.")
fix typo
done
def ref_convert_to_gloabl(req_id: torch.Tensor,
typo
done
out = torch.randn([qs, qh, kv_lora_rank], dtype=torch.bfloat16, device=q.device)
block_table = torch.empty_like(topk_indices)  # no use
num_seq = cu_seqlens_q.size(0) - 1
seqused_k = torch.randn([num_seq], device=q.device)  # no use
same question
This pull request has merge conflicts that must be resolved before it can be merged.
vllm/platforms/rocm.py
Outdated
)
if use_sparse:
    if kv_cache_dtype.startswith("fp8"):
        raise ValueError(f"ROCMAiterMLASparseBackend dose not support kv_cache_dtype == fp8.")
Nit: This looks like a RuntimeError
Hi, this ValueError is meant to tell users that we don't support fp8 kv cache for now, so we raise a ValueError here.
@classmethod
def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
    if current_platform.is_rocm():
        return [MultipleOf(1)]
I have the same question
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
    -1, 1, kv_c_and_k_pe_cache.shape[-1])

# NOTE(Chen): kernel requires num_local_head to be a multiple of
Nit: Let's update the comment or remove it.
Done, irrelevant comments have been removed.
vllm/envs.py
Outdated
  # By default is enabled.
  "VLLM_ROCM_USE_TRITON_ROPE": lambda: (
-     os.getenv("VLLM_ROCM_USE_TRITON_ROPE", "False").lower() in ("true", "1")
+     os.getenv("VLLM_ROCM_USE_TRITON_ROPE", "True").lower() in ("true", "1")
@ganyi1996ppo
To actually call the AITER triton rope code path, we also have to set both VLLM_ROCM_USE_TRITON_ROPE=1 and --compilation-config '{"custom_ops": ["+rotary_embedding"]}'. Only when custom_ops is set will the forward_hip function be called; otherwise forward_native is called by default. I think you will encounter a bug with the DeepSeek rotary embedding class when trying to use the AITER triton rope.
To my understanding, this AITER triton rope has not been validated for all RoPE classes; it was introduced and tested on Llama-3-405B only (#25135). So it is better to set it to False for now, or open another PR to fix the bug and validate it on all RoPE classes.
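To illustrate the dispatch being described, here is a toy sketch; it is not vLLM's actual CustomOp implementation, only the forward_hip/forward_native naming and the flags in the comments come from this discussion:

```python
import torch

class RotaryEmbeddingLike:
    """Toy model of the custom-op dispatch described above."""

    def __init__(self, custom_op_enabled: bool):
        # custom_op_enabled stands for "+rotary_embedding" being listed in
        # --compilation-config '{"custom_ops": [...]}'.
        self.custom_op_enabled = custom_op_enabled

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return x  # placeholder for the pure-PyTorch rope

    def forward_hip(self, x: torch.Tensor) -> torch.Tensor:
        # Placeholder; on ROCm this is where VLLM_ROCM_USE_TRITON_ROPE could
        # select the AITER Triton rope kernel.
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Only when the custom op is enabled does the platform-specific path
        # run; otherwise the native path is used by default.
        if self.custom_op_enabled:
            return self.forward_hip(x)
        return self.forward_native(x)
```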
This PR only enables eager mode for dsv3.2; we haven't tested graph mode yet. The interesting fact is that MI300/MI350 (we only tested these two) only produce correct results when using the triton rope.
And of course we will investigate this case; since rope is broken up into small ops and fused by torch.inductor in graph mode, this change actually won't affect the default path in graph mode.
FYI, piecewise cudagraph is now supported
Agree with @tjtanaa, we should set it to False by default.
Force-pushed from 60b33e0 to 8d436d7
@HAIAI Thanks for the advice, all …
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
    seq_lens, self.kv_cache_spec.block_size, self.num_sms
)
if current_platform.is_cuda():
Looking at the code, it appears that we only need this when running with DeepGEMM. Can you update the check to if has_deep_gemm()?
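A minimal sketch of that suggestion, assuming has_deep_gemm() and get_paged_mqa_logits_metadata() are the helpers exposed by vllm/utils/deep_gemm.py, and using stand-in arguments for the attributes in the quoted code:

```python
import torch
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, has_deep_gemm

def build_scheduler_metadata(scheduler_metadata_buffer: torch.Tensor,
                             seq_lens: torch.Tensor,
                             block_size: int,
                             num_sms: int) -> None:
    # Gate the DeepGEMM-only metadata computation on kernel availability
    # rather than on a platform check.
    if has_deep_gemm():
        scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
            seq_lens, block_size, num_sms
        )
```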
I'll refactor this place, thanks~
if output_scale is not None or output_block_scale is not None:
    raise NotImplementedError(
        "fused output quantization is not yet supported for MLACommonImpl"
Nit: Can you update the string to reference this class
Sure, thanks for pointing that out~
topk_indices = self.topk_indices_buffer[:num_actual_toks]

# Note: the above triton kernel may triggers some strange unexpected
What kernel is this comment referring to?
This refers to the triton kernel triton_convert_req_index_to_global_index; I'm using a reference torch path to replace it for now. But I think we should go back to triton_convert_req_index_to_global_index eventually for performance reasons, and I'm working on that now.
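For readers unfamiliar with this conversion, here is a hedged sketch of what a torch reference for the request-index-to-global-index mapping could look like; the shapes, the negative-padding convention, and the helper name are assumptions for illustration, not the PR's actual code:

```python
import torch

def ref_convert_to_global(req_id: torch.Tensor,       # [num_tokens]
                          block_table: torch.Tensor,  # [num_reqs, max_blocks]
                          topk_indices: torch.Tensor, # [num_tokens, topk]
                          block_size: int) -> torch.Tensor:
    # Map each request-local KV position to a global paged-KV-cache slot:
    # slot = block_table[req, pos // block_size] * block_size + pos % block_size
    blocks = block_table[req_id].long()               # [num_tokens, max_blocks]
    block_idx = (topk_indices // block_size).clamp(min=0).long()
    block_off = (topk_indices % block_size).clamp(min=0).long()
    global_blocks = torch.gather(blocks, 1, block_idx)
    global_indices = global_blocks * block_size + block_off
    # Preserve negative entries as padding (assumed convention).
    return torch.where(topk_indices < 0, topk_indices.long(), global_indices)
```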
Hi @SageMoore, the triton crash issue has already been fixed in the newest commit, please take a look.
vllm/utils/deep_gemm.py
Outdated
_lazy_init()
if _fp8_mqa_logits_impl is None:
    return _missing()
return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)[0]
Doing this dispatching inside of a file that's just for wrapping Deep Gemm seems incorrect. Let's move this dispatching somewhere else. Looks like indexer.py could be a good spot? CC @LucasWilkinson
Sorry for the delay. Yes, I think we can make it a custom_op and dispatch it outside rather than putting it inside deep_gemm.py.
I have isolated the sparse_attn_indexer function out of the deepseek modeling file. This should look cleaner than the previous version; please take a look again.
vllm/utils/deep_gemm.py
Outdated
    `torch.float32`.
    """
    _lazy_init()
    if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
Same idea here. Let's get this out of deep_gemm.py
ok~
Done
)

def dynamic_per_batched_tensor_quant(
    x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
Why not torch.dtype = current_platform.fp8_dtype()?
Sorry for the delay, I got caught up with something else. Yes, this should use the platform-related dtype, thanks for pointing that out!
This function is unnecessary, removed.
Validated and reproduced results. Server command: … lm_eval command: … Reproduced.
@ganyi1996ppo can you try rebasing?
Head branch was pushed to by a user without write access
Force-pushed from 1327e75 to 99df437
@ganyi1996ppo can you try rebasing again? Thank you so much.
Sure, but the error seems mainly related to llama and an OOM issue; it should not be caused by this PR, right?
Force-pushed from 4468193 to acec771
Purpose
This PR adds Deepseek v3.2 support on ROCm platforms. The main changes in this PR include:
- a new FP8 data type for ROCm
- a new attention backend for sparse MLA
- PyTorch fallbacks for some custom operations
Test Plan
Verify its accuracy on gsm8k and wikitext
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.