
Conversation

Contributor

@ganyi1996ppo ganyi1996ppo commented Oct 13, 2025

Purpose

This PR adds DeepSeek v3.2 support on ROCm platforms. The main changes in this PR are:

  • Replace all hardcoded float8_e4m3fn uses with the platform-supported fp8 dtype, and add an FP8_E4M3FNUZ enum to the C++ kernels (a minimal sketch of this dtype abstraction follows the list).
  • Add a torch reference implementation for the DeepGEMM ops.
  • Add a rocm_aiter_mla_sparse backend and dispatch to it on the ROCm platform.
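For illustration only (not part of the diff): a minimal sketch of the dtype abstraction in the first bullet, assuming vLLM's current_platform.fp8_dtype() helper; the tensor and cast below are made up for the example.

import torch
from vllm.platforms import current_platform

# Before: a hardcoded OCP fp8 dtype, which is wrong on FNUZ-only hardware.
# fp8_dtype = torch.float8_e4m3fn

# After: ask the platform layer which fp8 dtype it supports
# (e.g. float8_e4m3fn or float8_e4m3fnuz, depending on the device).
fp8_dtype = current_platform.fp8_dtype()

x = torch.randn(4, 4, dtype=torch.bfloat16)
x_fp8 = x.to(fp8_dtype)  # cast using the platform's fp8 representation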

Test Plan

Verify its accuracy on gsm8k and wikitext

Test Result

# wikitext
| Tasks  |Version|Filter|n-shot|    Metric     |   |Value |   |Stderr|
|--------|------:|------|-----:|---------------|---|-----:|---|------|
|wikitext|      2|none  |     0|bits_per_byte  |↓  |0.3691|±  |   N/A|
|        |       |none  |     0|byte_perplexity|↓  |1.2915|±  |   N/A|
|        |       |none  |     0|word_perplexity|↓  |3.9272|±  |   N/A|

# gsm8k

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9629|±  |0.0052|
|     |       |strict-match    |     5|exact_match|↑  |0.9621|±  |0.0053|
 

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for Deepseek v3.2 and SparseMLA on ROCm platforms. The changes are well-structured and cover adding a new FP8 data type for ROCm, a new attention backend for sparse MLA, and PyTorch fallbacks for some custom operations. The code modifications correctly abstract away platform-specific details. I have one suggestion to improve performance in the newly added sparse MLA backend file by avoiding unnecessary tensor initializations.

Comment on lines 156 to 159
out = torch.randn([qs, qh, kv_lora_rank], dtype=torch.bfloat16, device=q.device)
block_table = torch.empty_like(topk_indices) # no use
num_seq = cu_seqlens_q.size(0) - 1
seqused_k = torch.randn([num_seq], device=q.device) # no use
Contributor

high

Using torch.randn to initialize tensors that are either output parameters (out) or unused (seqused_k) is inefficient. It populates the tensors with random values, which is an unnecessary computation that can impact performance. It's more performant to use torch.empty, which only allocates uninitialized memory.

Suggested change
- out = torch.randn([qs, qh, kv_lora_rank], dtype=torch.bfloat16, device=q.device)
- block_table = torch.empty_like(topk_indices) # no use
- num_seq = cu_seqlens_q.size(0) - 1
- seqused_k = torch.randn([num_seq], device=q.device) # no use
+ out = torch.empty([qs, qh, kv_lora_rank], dtype=torch.bfloat16, device=q.device)
+ block_table = torch.empty_like(topk_indices) # no use
+ num_seq = cu_seqlens_q.size(0) - 1
+ seqused_k = torch.empty([num_seq], device=q.device) # no use

Collaborator

same question

Contributor Author

Done

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 182 to 198
mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None]
mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None]
mask = mask_lo & mask_hi


P1 Badge Use input device instead of 'cuda' in DeepGEMM torch fallback

The new torch implementations of FP8 MQA logits create their temporary tensors with device='cuda' (e.g., the arange masks at fp8_mqa_logits_torch). When the caller runs on a non‑default device such as cuda:1 or HIP, these tensors land on a different device than q/kv, so subsequent ops like mask_lo & mask_hi or matrix multiplies raise a device‑mismatch error. The fallback is meant to run when the compiled kernel is unavailable, so it should follow the device of the inputs (q.device or kv.device) instead of assuming the global CUDA device. The same hardcoded device constant appears again in fp8_paged_mqa_logits_torch and the ROCm branch of fp8_paged_mqa_logits, so all temporaries there should be created on the input device as well.
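A minimal sketch of the suggested fix (the helper name is made up; the arguments mirror the quoted snippet): derive the device from the inputs so the fallback also works on cuda:1 or HIP.

import torch

def build_kv_window_mask(seq_len_kv: int,
                         cu_seqlen_ks: torch.Tensor,
                         cu_seqlen_ke: torch.Tensor) -> torch.Tensor:
    # Follow the device of the inputs instead of hardcoding 'cuda'.
    device = cu_seqlen_ks.device
    positions = torch.arange(0, seq_len_kv, device=device)[None, :]
    mask_lo = positions >= cu_seqlen_ks[:, None]
    mask_hi = positions < cu_seqlen_ke[:, None]
    return mask_lo & mask_hi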


@HAIAI HAIAI self-assigned this Oct 13, 2025
@HAIAI HAIAI self-requested a review October 13, 2025 18:25
Collaborator

@HAIAI HAIAI left a comment


Please revert the FP8_E4M3FNUZ part for the kv cache; we have used a unified fp8 KV cache format since March 2024, and OCP is the only format we support from Quark as well.

Collaborator

@HAIAI HAIAI left a comment


lint: a few more comments.
And please format accordingly.

@classmethod
def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
    if current_platform.is_rocm():
        return [MultipleOf(1)]
Collaborator

why 1

Contributor

I have the same question

Contributor Author

The Triton kernel deepgemm_fp8_paged_mqa_logits_stage1 that we use to replace DeepGEMM only supports the block-size=1 case, and so does our MLA kernel, although it is not used in DeepSeek v3.2. We already have a solution for this, but not in this PR.

)
if use_sparse:
    if kv_cache_dtype.startswith("fp8"):
        raise ValueError(f"ROCMAiterMLASparseBackend dose not support kv_cache_dtype == fp8.")
Collaborator

fix typo

Contributor Author

done




def ref_convert_to_gloabl(req_id: torch.Tensor,
Collaborator

typo

Contributor Author

done

Comment on lines 156 to 159
out = torch.randn([qs, qh, kv_lora_rank], dtype=torch.bfloat16, device=q.device)
block_table = torch.empty_like(topk_indices) # no use
num_seq = cu_seqlens_q.size(0) - 1
seqused_k = torch.randn([num_seq], device=q.device) # no use
Collaborator

same question


mergify bot commented Oct 16, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ganyi1996ppo.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 16, 2025
)
if use_sparse:
    if kv_cache_dtype.startswith("fp8"):
        raise ValueError(f"ROCMAiterMLASparseBackend dose not support kv_cache_dtype == fp8.")
Contributor

Nit: This looks like a RuntimeError

Contributor Author

Hi, this ValueError is meant to tell users that we don't support fp8 kv cache for now, so we throw a ValueError here.

@classmethod
def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
    if current_platform.is_rocm():
        return [MultipleOf(1)]
Contributor

I have the same question

kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
    -1, 1, kv_c_and_k_pe_cache.shape[-1])

# NOTE(Chen): kernel requires num_local_head to be a multiple of
Contributor

Nit: Let's update the comment or remove it.

Contributor Author

Done, irrelevant comments have been removed.

@wuhuikx wuhuikx mentioned this pull request Oct 18, 2025
vllm/envs.py Outdated
# By default is enabled.
"VLLM_ROCM_USE_TRITON_ROPE": lambda: (
-   os.getenv("VLLM_ROCM_USE_TRITON_ROPE", "False").lower() in ("true", "1")
+   os.getenv("VLLM_ROCM_USE_TRITON_ROPE", "True").lower() in ("true", "1")
Collaborator

@tjtanaa tjtanaa Oct 20, 2025

@ganyi1996ppo
To actually hit the AITER Triton rope code path, we also have to set both VLLM_ROCM_USE_TRITON_ROPE=1 and --compilation-config '{"custom_ops": ["+rotary_embedding"]}'. Only by setting custom_ops will the forward_hip function be called; otherwise forward_native is called by default. I think you will encounter a bug with the DeepSeek rotary embedding class when trying to use the AITER Triton rope.

To my understanding, this AITER Triton rope has not been validated for all RoPE classes. It was introduced and tested on Llama-3-405B only (#25135). So it is better to keep it set to False for now, or open another PR to fix this and validate it on all RoPE classes. An example of the flags involved is sketched below.
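For reference, a hedged sketch of exercising this path (the model name and parallelism flag are placeholders borrowed from the validation command later in this thread; the two RoPE-related settings are the ones named above):

# Assumption: both the env var and the custom_ops entry are needed to reach forward_hip.
export VLLM_ROCM_USE_AITER=1
export VLLM_ROCM_USE_TRITON_ROPE=1

vllm serve deepseek-ai/DeepSeek-V3.2-Exp \
  --tensor-parallel-size 8 \
  --compilation-config '{"custom_ops": ["+rotary_embedding"]}'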

Contributor Author

This PR only enables eager mode for dsv3.2; we haven't tested graph mode yet. The interesting fact is that MI300/MI350 (the only two we tested) only produce correct results when using the Triton rope.

And of course we will investigate this case. Since rope is broken up into small ops and fused by torch.inductor in graph mode, this change won't affect the default path in graph mode.

Contributor Author

FYI, piecewise cudagraph is now supported

Contributor

Agree with @tjtanaa, we should set it to False by default.

@ganyi1996ppo ganyi1996ppo force-pushed the ganyi/dsv3.2_rocm_support branch from 60b33e0 to 8d436d7 Compare October 20, 2025 12:22
@ganyi1996ppo ganyi1996ppo requested a review from HAIAI October 20, 2025 12:23
@mergify mergify bot removed the needs-rebase label Oct 20, 2025
Contributor Author

ganyi1996ppo commented Oct 20, 2025

Please revert the FP8_E4M3FNUZ part for the kv cache; we have used a unified fp8 KV cache format since March 2024, and OCP is the only format we support from Quark as well.

@HAIAI Thanks for the advice; all fp8_e4m3fnuz-related code has been removed from this PR. The fp8_e4m3fnuz-specific part in the C++ kernels is now handled with the __gfx942__ macro.

self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
    seq_lens, self.kv_cache_spec.block_size, self.num_sms
)
if current_platform.is_cuda():
Contributor

Looking at the code, it appears that we only need this when running with DeepGemm? Can you update the check to if has_deep_gemm()

Contributor Author

I'll refactor this place, thanks~


if output_scale is not None or output_block_scale is not None:
    raise NotImplementedError(
        "fused output quantization is not yet supported for MLACommonImpl"
Contributor

Nit: Can you update the string to reference this class

Contributor Author

Sure, thanks for pointing that out~


topk_indices = self.topk_indices_buffer[:num_actual_toks]

# Note: the above triton kernel may triggers some strange unexpected
Contributor

What kernel is this comment referring to?

Contributor Author

This refers to the Triton kernel triton_convert_req_index_to_global_index; I'm using a reference torch path to replace it for now. But I think we should go back to triton_convert_req_index_to_global_index eventually for performance reasons, and I'm working on that now.

Contributor Author

Hi @SageMoore, the Triton crash issue has already been fixed in the newest commit; please take a look.

_lazy_init()
if _fp8_mqa_logits_impl is None:
    return _missing()
return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)[0]
Contributor

Doing this dispatching inside of a file that's just for wrapping Deep Gemm seems incorrect. Let's move this dispatching somewhere else. Looks like indexer.py could be a good spot? CC @LucasWilkinson

Contributor Author

Sorry for the delay. Yes, I think we can make it a custom_op and dispatch it outside rather than putting it inside deep_gemm.py.
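A rough sketch of what that external dispatch could look like (function names follow this thread, but the import paths and helper name are assumptions, not the PR's actual code):

import torch
from vllm.platforms import current_platform
# Assumed import locations for the two paths discussed above.
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_mqa_logits_torch


def mqa_logits_dispatch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) -> torch.Tensor:
    # ROCm has no compiled DeepGEMM kernel here, so use the torch reference path;
    # elsewhere go through the DeepGEMM wrapper.
    if current_platform.is_rocm():
        return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)[0]
    return fp8_mqa_logits(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)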

Contributor Author

I have isolated the sparse_attn_indexer function out of the DeepSeek modeling file. This should look cleaner than the previous version; please take a look again.

`torch.float32`.
"""
_lazy_init()
if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
Contributor

Same idea here. Let's get this out of deep_gemm.py

Contributor Author

ok~

Contributor Author

Done

)

def dynamic_per_batched_tensor_quant(
    x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
Contributor

Why not torch.dtype = current_platform.fp8_dtype()?
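For illustration, a hedged sketch of the reviewer's suggestion: resolve the default at call time so the platform fp8 dtype is used (the quantization body is a generic dynamic per-tensor scheme, not the PR's implementation):

import torch
from vllm.platforms import current_platform


def dynamic_per_batched_tensor_quant(
    x: torch.Tensor, dtype: torch.dtype | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
    # Default to the platform's fp8 dtype instead of hardcoding float8_e4m3fn.
    dtype = dtype if dtype is not None else current_platform.fp8_dtype()
    fp8_max = torch.finfo(dtype).max
    scale = x.abs().amax().clamp(min=1e-12) / fp8_max
    x_q = (x / scale).clamp(-fp8_max, fp8_max).to(dtype)
    return x_q, scale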

Contributor Author

Sorry for the delay, got caught up with something else. Yes, it should use the platform-specific dtype; thanks for pointing that out!

Contributor Author

This function is unnecessary, removed.

Collaborator

tjtanaa commented Nov 18, 2025

Validated and reproduced results.

Server command:

export VLLM_USE_V1=1
export SAFETENSORS_FAST_GPU=1
export VLLM_ROCM_USE_AITER=1
export VLLM_ROCM_USE_AITER_MOE=1
export NCCL_DEBUG=WARN
export VLLM_LOGGING_LEVEL=DEBUG
export VLLM_RPC_TIMEOUT=18000000

model_path="deepseek-ai/DeepSeek-V3.2-Exp"

vllm serve $model_path \
  --tensor-parallel-size 8 \
  --data-parallel-size 1 \
  --max-num-batched-tokens 32768 \
  --trust-remote-code \
  --no-enable-prefix-caching \
  --disable-log-requests \
  --kv-cache-dtype bfloat16 \
  --gpu_memory_utilization 0.85 \
  --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \
  --block-size 1 \
2>&1 | tee server-deepseek-ai_DeepSeek-V3.2-Exp.log

lm_eval command

lm_eval --model local-completions  \
   --tasks wikitext  \
   --output_path ./results  \
   --log_samples \
   --model_args model=deepseek-ai/DeepSeek-V3.2-Exp,base_url=http://localhost:8000/v1/completions,num_concurrent=2,max_retries=3,timeout=3000,seed=1234,temperature=0 \
   | tee lmeval_server-deepseek-ai_DeepSeek-V3.2-Exp.log 2>&1

lm_eval \
--model local-completions \
--tasks gsm8k \
--model_args model=deepseek-ai/DeepSeek-V3.2-Exp,base_url=http://127.0.0.1:8000/v1/completions \
--batch_size 1 \
> lmeval_server-deepseek-ai_DeepSeek-V3.2-Exp-gsm8k.log 2>&1

Reproduced.

local-completions (model=deepseek-ai/DeepSeek-V3.2-Exp,base_url=http://localhost:8000/v1/completions,num_concurrent=2,max_retries=3,timeout=3000,seed=1234,temperature=0), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1
| Tasks  |Version|Filter|n-shot|    Metric     |   |Value |   |Stderr|
|--------|------:|------|-----:|---------------|---|-----:|---|------|
|wikitext|      2|none  |     0|bits_per_byte  |↓  |0.3654|±  |   N/A|
|        |       |none  |     0|byte_perplexity|↓  |1.2882|±  |   N/A|
|        |       |none  |     0|word_perplexity|↓  |3.8745|±  |   N/A|

local-completions (model=deepseek-ai/DeepSeek-V3.2-Exp,base_url=http://127.0.0.1:8000/v1/completions), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9598|±  |0.0054|
|     |       |strict-match    |     5|exact_match|↑  |0.9591|±  |0.0055|

@tjtanaa tjtanaa added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 18, 2025
@tjtanaa tjtanaa enabled auto-merge (squash) November 18, 2025 12:30
Collaborator

tjtanaa commented Nov 19, 2025

@ganyi1996ppo can you try rebasing?

auto-merge was automatically disabled November 19, 2025 01:46

Head branch was pushed to by a user without write access

@ganyi1996ppo ganyi1996ppo force-pushed the ganyi/dsv3.2_rocm_support branch from 1327e75 to 99df437 Compare November 19, 2025 01:46
Collaborator

tjtanaa commented Nov 20, 2025

@ganyi1996ppo can you try rebasing again. Thank you so much.

Contributor Author

@ganyi1996ppo can you try rebasing again. Thank you so much.

Sure, but the error seems mainly related to a Llama OOM issue; it shouldn't be caused by this PR, right?

@ganyi1996ppo ganyi1996ppo force-pushed the ganyi/dsv3.2_rocm_support branch from 4468193 to acec771 Compare November 20, 2025 03:00
@vllm-bot vllm-bot merged commit 06c20c9 into vllm-project:main Nov 20, 2025
86 of 88 checks passed
LuminolT pushed a commit to LuminolT/vllm that referenced this pull request Nov 21, 2025
lpapavassiliou pushed a commit to lpapavassiliou/vllm that referenced this pull request Nov 24, 2025
RunkaiTao pushed a commit to RunkaiTao/vllm that referenced this pull request Nov 24, 2025
bringlein pushed a commit to bringlein/vllm that referenced this pull request Nov 26, 2025
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
kitaekatt pushed a commit to kitaekatt/vllm that referenced this pull request Dec 1, 2025
charlotte12l pushed a commit to charlotte12l/vllm that referenced this pull request Dec 5, 2025
Zhathw pushed a commit to Zhathw/vllm that referenced this pull request Dec 6, 2025