From 7799f2764d86da905095dd81834ca300466a53b4 Mon Sep 17 00:00:00 2001
From: yewentao256
Date: Tue, 2 Sep 2025 12:15:52 -0700
Subject: [PATCH 1/3] fix routed_scaling_factor double mul issue

Signed-off-by: yewentao256
---
 vllm/model_executor/models/deepseek_v2.py | 11 ++---------
 1 file changed, 2 insertions(+), 9 deletions(-)

diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 36c9427e474e..6155158ae9b4 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -186,15 +186,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        if hidden_states.dtype != torch.float16:
-            final_hidden_states = self.experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits) * self.routed_scaling_factor
-        else:
-            # Fix FP16 overflow
-            # See DeepseekV2DecoderLayer for more details.
-            final_hidden_states = self.experts(hidden_states=hidden_states,
-                                               router_logits=router_logits)
+        final_hidden_states = self.experts(hidden_states=hidden_states,
+                                           router_logits=router_logits)
         if shared_output is not None:
             if hidden_states.dtype != torch.float16:
                 final_hidden_states = final_hidden_states + shared_output

From 583afad501447ca59ce006bad3e373f8a5880ae8 Mon Sep 17 00:00:00 2001
From: yewentao256
Date: Tue, 2 Sep 2025 12:40:04 -0700
Subject: [PATCH 2/3] fix in another way

Signed-off-by: yewentao256
---
 vllm/model_executor/models/deepseek_v2.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 6155158ae9b4..aca8dbf66fb0 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -160,7 +160,7 @@ def __init__(
             topk_group=config.topk_group,
             prefix=f"{prefix}.experts",
             scoring_func=config.scoring_func,
-            routed_scaling_factor=self.routed_scaling_factor,
+            routed_scaling_factor=1.0,
             e_score_correction_bias=self.gate.e_score_correction_bias,
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts)
@@ -186,8 +186,15 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(hidden_states=hidden_states,
-                                           router_logits=router_logits)
+        if hidden_states.dtype != torch.float16:
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits) * self.routed_scaling_factor
+        else:
+            # Fix FP16 overflow
+            # See DeepseekV2DecoderLayer for more details.
+            final_hidden_states = self.experts(hidden_states=hidden_states,
+                                               router_logits=router_logits)
         if shared_output is not None:
             if hidden_states.dtype != torch.float16:
                 final_hidden_states = final_hidden_states + shared_output

From 3a80358601765aadf6d5636c31102d52065eb47d Mon Sep 17 00:00:00 2001
From: yewentao256
Date: Tue, 2 Sep 2025 12:57:10 -0700
Subject: [PATCH 3/3] add comments

Signed-off-by: yewentao256
---
 vllm/model_executor/models/deepseek_v2.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index aca8dbf66fb0..3a8eaf681733 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -160,6 +160,7 @@ def __init__(
             topk_group=config.topk_group,
             prefix=f"{prefix}.experts",
             scoring_func=config.scoring_func,
+            # we do scaling outside, set factor to 1.0 to avoid double mul
             routed_scaling_factor=1.0,
             e_score_correction_bias=self.gate.e_score_correction_bias,
             enable_eplb=self.enable_eplb,
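
The sketch below illustrates the double multiplication this series fixes. It is
not vllm code: experts() here is a hypothetical stand-in for the FusedMoE
forward, reduced to an identity so that only the scaling behavior is visible,
and the factor value is arbitrary. If the module applies routed_scaling_factor
internally and the caller multiplies again, the routed output is scaled by the
factor squared; passing 1.0 into the module leaves exactly one multiplication.

import torch


def experts(hidden_states: torch.Tensor,
            routed_scaling_factor: float) -> torch.Tensor:
    # Internal scaling, standing in for what the experts module applied
    # before this series set its factor to 1.0.
    return hidden_states * routed_scaling_factor


factor = 2.5
x = torch.ones(4, 8)

# Before: scaled inside the module *and* by the caller -> factor ** 2.
assert torch.allclose(experts(x, factor) * factor, x * factor**2)
# After: 1.0 inside the module, one multiplication by the caller -> factor.
assert torch.allclose(experts(x, 1.0) * factor, x * factor)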