
Conversation

@wenscarl
Contributor

@wenscarl wenscarl commented Oct 2, 2025

Purpose

This matches TRT-LLM usage: use all experts' input scales to compute alpha and quantize. This change doesn't affect perf.
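
For intuition, here is a minimal sketch of the shared-scale idea. Tensor names, shapes, and the alpha formula below are illustrative assumptions, not vLLM's actual parameters:

```python
import torch

# Hypothetical per-expert input scales, standing in for the
# w13_input_scale / w2_input_scale tensors in the checkpoint.
per_expert_input_scale = torch.tensor([0.9, 1.1, 1.0, 1.3])  # [num_experts]

# Use all experts' input scales: reduce with max so that no expert's
# activations overflow the shared quantization range.
global_input_scale = per_expert_input_scale.max()  # 1.3 for every expert

# alpha is then derived from the single global input scale together with
# the per-expert weight scales (the exact formula in the real kernels differs;
# this only shows that alpha depends on the shared scale).
per_expert_weight_scale_2 = torch.tensor([0.5, 0.6, 0.4, 0.7])
alpha = global_input_scale * per_expert_weight_scale_2  # [num_experts]
```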

Test Plan

VLLM_WORKER_MULTIPROC_METHOD="spawn" \
VLLM_USE_STANDALONE_COMPILE=0 \
VLLM_USE_FLASHINFER_MOE_FP4=1 \
VLLM_FLASHINFER_MOE_BACKEND="throughput" \
lm_eval --model vllm --model_args pretrained=nvidia/DeepSeek-R1-FP4,quantization=modelopt_fp4,data_parallel_size=8,enable_expert_parallel=True,tensor_parallel_size=1,max_model_len=2048,enforce_eager=True --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto

Test Result

After change:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9431|±  |0.0064|
|     |       |strict-match    |     5|exact_match|↑  |0.9401|±  |0.0065|

Before change:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9340|±  |0.0068|
|     |       |strict-match    |     5|exact_match|↑  |0.9310|±  |0.0070|
---
<details>
<summary> Essential Elements of an Effective PR Description Checklist </summary>

- [ ] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
- [ ] The test plan, such as providing test command.
- [ ] The test results, such as pasting the results comparison before and after, or e2e results
- [ ] (Optional) The necessary documentation update, such as updating `supported_models.md` and `examples` for a new model.
- [ ] (Optional) Release notes update. If your change is user facing, please update the release notes draft in the [Google Doc](https://docs.google.com/document/d/1YyVqrgX4gHTtrstbq8oWUImOyPCKSGnJ7xtTpmXzlRs/edit?tab=t.0).
</details>

@wenscarl wenscarl changed the title from "Load w13/w2_input_scale for all experts" to "[ModelOpt] Load w13/w2_input_scale for all experts, nvfp4" Oct 3, 2025
@mergify

mergify bot commented Oct 7, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wenscarl.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 7, 2025
@wenscarl wenscarl marked this pull request as ready for review October 7, 2025 02:53

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


@mergify mergify bot removed the needs-rebase label Oct 7, 2025
@wenscarl wenscarl force-pushed the nvfp4_glb_sf branch 2 times, most recently from 49543bd to cbfdd6d, October 14, 2025 03:54
@wenscarl wenscarl requested a review from pavanimajety October 17, 2025 14:53
@leejnau
Contributor

leejnau commented Oct 17, 2025

For the nvidia/DeepSeek-R1-0528-FP4-v2 model, in both TP4 and DP4 modes, with the FlashInfer backend, this PR raises the accuracy from ~2% to ~95%.

server (TP4):

VLLM_USE_FLASHINFER_MOE_FP4=1 VLLM_FLASHINFER_MOE_BACKEND="throughput" vllm serve nvidia/DeepSeek-R1-0528-FP4-v2 --quantization="modelopt_fp4" --trust-remote-code --gpu_memory_utilization=0.8 --tensor-parallel-size 4 --data-parallel-size 1

client:

python3 tests/evals/gsm8k/gsm8k_eval.py
- commit 29350922c64a808a6de3b0e31fbadc2aebd6ba3f (before this PR): Accuracy: 0.025
- commit 281be34de010fd4b106341e8aa3996f01f121c61 (with this PR): Accuracy: 0.955

server (DP4):

VLLM_USE_FLASHINFER_MOE_FP4=1 VLLM_FLASHINFER_MOE_BACKEND="throughput" vllm serve nvidia/DeepSeek-R1-0528-FP4-v2 --quantization="modelopt_fp4" --trust-remote-code --gpu_memory_utilization=0.8 --tensor-parallel-size 1 --data-parallel-size 4

client:

python3 tests/evals/gsm8k/gsm8k_eval.py
- commit 29350922c64a808a6de3b0e31fbadc2aebd6ba3f (before this PR): Accuracy: 0.022
- commit 281be34de010fd4b106341e8aa3996f01f121c61 (with this PR): Accuracy: 0.955

Signed-off-by: Shu Wang. <[email protected]>
@mgoin mgoin added bug Something isn't working quantization ready ONLY add when PR is ready to merge/full CI is needed deepseek Related to DeepSeek models labels Oct 20, 2025
Comment on lines 1639 to 1641
# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
    # this is needed for compressed-tensors only


Nit: looks like these comments are out of date
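
Per the nit, a possible refresh of those stale comments might read as follows (illustrative wording only, not the merged code):

```python
# Case input_scale: per-expert input scales; with this PR they are also
# loaded for NVFP4 MoE, where FlashInfer backends reduce them to a single
# global scale shared by all experts.
if "input_scale" in weight_name:
    ...
```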

Comment on lines +1557 to +1565
allow_flashinfer = getattr(self.quant_method, "allow_flashinfer", False)
moe_backend = getattr(self.quant_method, "flashinfer_moe_backend", None)

use_global_sf = (
    allow_flashinfer
    and is_flashinfer_supporting_global_sf(moe_backend)
    and "input_scale" in weight_name
    and quant_method_name == "ModelOptNvFp4FusedMoE"
)


It would be good to put these three lines together and leave a comment on what use_global_sf means in this case, since we are in fused_moe/layer.py.
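
One way the suggestion could look, as a sketch rather than the committed code:

```python
allow_flashinfer = getattr(self.quant_method, "allow_flashinfer", False)
moe_backend = getattr(self.quant_method, "flashinfer_moe_backend", None)

# use_global_sf: for ModelOpt NVFP4 fused MoE on a FlashInfer backend that
# supports global scale factors, the per-expert input_scale tensors are
# replaced by one scale shared across all experts.
use_global_sf = (
    allow_flashinfer
    and is_flashinfer_supporting_global_sf(moe_backend)
    and "input_scale" in weight_name
    and quant_method_name == "ModelOptNvFp4FusedMoE"
)
```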


@mgoin mgoin left a comment


LGTM. Although checking for local attrs in fused_moe/layer.py doesn't feel good to keep doing, I don't have a better option atm.

@mgoin mgoin merged commit f95da13 into vllm-project:main Oct 21, 2025
57 checks passed
Zhuul pushed a commit to Zhuul/vllm that referenced this pull request Oct 21, 2025
baonudesifeizhai pushed a commit to baonudesifeizhai/vllm that referenced this pull request Oct 21, 2025
albertoperdomo2 pushed a commit to albertoperdomo2/vllm that referenced this pull request Oct 23, 2025
…ct#26135)

Signed-off-by: Shu Wang <[email protected]>
Signed-off-by: Shu Wang. <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Signed-off-by: Alberto Perdomo <[email protected]>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
…ct#26135)

Signed-off-by: Shu Wang <[email protected]>
Signed-off-by: Shu Wang. <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Signed-off-by: 0xrushi <[email protected]>
Chenyaaang pushed a commit to Chenyaaang/vllm that referenced this pull request Oct 28, 2025
ilmarkov pushed a commit to neuralmagic/vllm that referenced this pull request Nov 7, 2025
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025