
Conversation

@yewentao256 (Member) commented Sep 2, 2025

Purpose

Fixes #24118, which was introduced by #23123.

Test

```shell
lm_eval --model local-completions --model_args "base_url=http://127.0.0.1:9256/v1/completions,model=deepseek-ai/DeepSeek-R1,num_concurrent=256" --tasks gsm8k
```

Origin

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.6626|±  |0.0130|
|     |       |strict-match    |     5|exact_match||0.5201|±  |0.0138|

Now

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.9530|±  |0.0058|
|     |       |strict-match    |     5|exact_match||0.9522|±  |0.0059|

@mergify mergify bot added the deepseek Related to DeepSeek models label Sep 2, 2025
@gemini-code-assist (bot) left a comment

Code Review

This pull request correctly fixes a bug in the DeepseekV2MoE forward pass where routed_scaling_factor was applied twice for non-fp16 data types. The self.experts layer (FusedMoE) already applies this scaling factor internally, so removing the redundant multiplication restores the intended scaling. I have also added a review comment highlighting a potential related scaling issue in the handling of shared_output that may require further investigation.
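The double application described above can be reduced to a toy sketch (hypothetical names, not vLLM's actual modules): if the expert layer already folds the scaling factor into its output, a second multiplication in the caller scales by the factor squared.

```python
# Toy illustration of the double-scaling bug; the factor value is illustrative.
ROUTED_SCALING_FACTOR = 2.5

def experts(hidden_states: float, factor: float) -> float:
    # Stand-in for FusedMoE, which applies the scaling factor internally.
    return hidden_states * factor

def forward_buggy(hidden_states: float) -> float:
    out = experts(hidden_states, ROUTED_SCALING_FACTOR)
    return out * ROUTED_SCALING_FACTOR  # redundant second multiplication

def forward_fixed(hidden_states: float) -> float:
    return experts(hidden_states, ROUTED_SCALING_FACTOR)

print(forward_buggy(1.0))  # 6.25: scaled twice (2.5 ** 2)
print(forward_fixed(1.0))  # 2.5: scaled once, as intended
```

The squared factor (6.25x instead of 2.5x) is consistent with the large accuracy drop shown in the tables above.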

@yewentao256 yewentao256 changed the title [Bug] Fix routed_scaling_factor Double Mul Issue [Bug] R1 Accuracy: Fix routed_scaling_factor Double Mul Issue Sep 2, 2025
Comment on lines 189 to 190:

```python
final_hidden_states = self.experts(hidden_states=hidden_states,
                                   router_logits=router_logits)
```
A member commented:

This is handled here, right?

```python
if routed_scaling_factor != 1.0:
    topk_weights = topk_weights * routed_scaling_factor
```

Why does main not apply the routed_scaling_factor in the fp16 case? Is there an accuracy issue we need to handle?

A collaborator replied:

It looks like it's there to prevent an fp16 overflow issue for DeepSeek V2: #13232

By multiplying by the routed scaling factor in all cases, #23123 might have introduced an accuracy regression for V2 as well; we should check this.
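The overflow concern can be demonstrated in isolation (the magnitudes here are illustrative, not taken from the model): multiplying a large fp16 intermediate by the scaling factor can exceed the fp16 maximum of ~65504 and produce inf, while the same math stays finite in fp32.

```python
import numpy as np

factor = np.float16(16.0)
big = np.float16(50000.0)  # stored as the nearest representable fp16 value

overflowed = big * factor                   # exceeds fp16 max (~65504) -> inf
ok = np.float32(big) * np.float32(factor)   # same math in fp32 stays finite

print(np.isinf(overflowed))  # True
print(np.isfinite(ok))       # True
```

This is presumably why the fp16 path deferred the multiplication, per the linked #13232.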

The author replied:

Yes, so I took a different approach to the fix, since DeepSeek V2 has a lot of routed_scaling_factor logic outside the experts layer.

Signed-off-by: yewentao256 <[email protected]>
@yewentao256 yewentao256 added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 2, 2025
@sarckk (Collaborator) left a comment

Thanks for the fix.

I think we also need this fix for dots1 and glm4_moe; both use `use_grouped_topk=True`, so they should be affected by the double-mul issue as well:

```python
router_logits=router_logits) * self.routed_scaling_factor
```

```python
router_logits=router_logits) * self.routed_scaling_factor
```
Signed-off-by: yewentao256 <[email protected]>
@yewentao256 (author) replied:

@sarckk Nice find! Feel free to open a PR for this.
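For context, the invariant the fix establishes can be sketched as follows (hypothetical helper, not vLLM's actual code): after top-k selection, the routed scaling factor is folded into the per-expert weights exactly once, so callers of the fused experts must not multiply the combined output again.

```python
# Hypothetical sketch: fold the routed scaling factor into topk_weights once.
def scale_topk_weights(topk_weights: list[float],
                       routed_scaling_factor: float) -> list[float]:
    if routed_scaling_factor != 1.0:
        topk_weights = [w * routed_scaling_factor for w in topk_weights]
    return topk_weights

print(scale_topk_weights([0.6, 0.4], 2.5))  # [1.5, 1.0]
```

Any model whose forward pass still multiplies the experts' output by `self.routed_scaling_factor` (as in the dots1 and glm4_moe snippets above) double-applies the factor.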

@houseroad (Collaborator) left a comment

Thanks for the fix.

@houseroad houseroad enabled auto-merge (squash) September 2, 2025 20:03
@facebook-github-bot commented:

@sarckk has imported this pull request. If you are a Meta employee, you can view this in D81519055.

@houseroad houseroad merged commit 930a241 into vllm-project:main Sep 2, 2025
42 checks passed
@yewentao256 yewentao256 deleted the wye-fix-routed-scaling-factor-double-mul branch September 2, 2025 22:23
@WoosukKwon (Collaborator) commented Sep 2, 2025

@yewentao256 Thanks for the fix. BTW, can you please add a test in our CI? I think we've seen similar accuracy issues 5+ times on DeepSeek...

@houseroad (Collaborator) replied:

Besides running the full DeepSeek model, is it possible to run a smaller model to detect such issues, like DeepSeek V2, etc.?
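One lightweight shape such a CI gate could take (a hypothetical helper; the threshold is chosen with slack below the 0.9522 strict-match value observed in this PR):

```python
# Hypothetical CI accuracy gate: fail if gsm8k strict-match accuracy drops
# below a threshold set with slack under the value observed in this PR.
THRESHOLD = 0.90  # slack below the observed 0.9522

def check_accuracy(results: dict) -> None:
    acc = results["gsm8k"]["strict-match"]["exact_match"]
    assert acc >= THRESHOLD, f"accuracy regression: {acc:.4f} < {THRESHOLD}"

# e.g. fed from a parsed lm_eval results file
check_accuracy({"gsm8k": {"strict-match": {"exact_match": 0.9522}}})
print("accuracy gate passed")
```

A gate like this on a small MoE checkpoint would catch a squared-factor regression, since the accuracy drop it causes is far larger than run-to-run noise.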

845473182 pushed a commit to 845473182/vllm that referenced this pull request Sep 3, 2025
eicherseiji pushed a commit to eicherseiji/vllm that referenced this pull request Sep 9, 2025
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025


Development

Successfully merging this pull request may close these issues.

[Bug]: R1 Accuracy Issue in Main for deepep_high_througput

6 participants