Commit 973e94d

jinzhen-lin authored and mgoin committed
[Bugfix] fix deepseek fp16 scale bug (vllm-project#14809)
Signed-off-by: Jinzhen Lin <[email protected]>
Co-authored-by: mgoin <[email protected]>
1 parent 934cf70 commit 973e94d

File tree: 1 file changed (+24, -10 lines)

vllm/model_executor/models/deepseek_v2.py

Lines changed: 24 additions & 10 deletions
@@ -161,14 +161,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                 hidden_states=hidden_states,
                 router_logits=router_logits) * self.routed_scaling_factor
         else:
-            # This is a special case to avoid FP16 overflow
+            # Fix FP16 overflow
+            # See DeepseekV2DecoderLayer for more details.
             final_hidden_states = self.experts(hidden_states=hidden_states,
                                                router_logits=router_logits)
         if shared_output is not None:
             if hidden_states.dtype != torch.float16:
                 final_hidden_states = final_hidden_states + shared_output
             else:
-                # This is a special case to avoid FP16 overflow
+                # Fix FP16 overflow
+                # See DeepseekV2DecoderLayer for more details.
                 final_hidden_states = final_hidden_states + shared_output \
                     * (1. / self.routed_scaling_factor)
         if self.tp_size > 1:
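
For context on the overflow this hunk guards against: FP16 tops out at 65504, so multiplying already-large expert outputs by the routed scaling factor can saturate to inf, which is why the FP16 path keeps the experts' output unscaled and divides the shared-expert contribution instead. A minimal sketch (the scaling factor value and tensor contents are illustrative, not taken from any real model config):

import torch

routed_scaling_factor = 2.5   # illustrative value, not read from a real config
x = torch.full((4,), 40000.0, dtype=torch.float16)   # activations near the FP16 max (65504)

# The non-FP16 path multiplies the routed experts' output by the factor;
# doing the same in FP16 overflows to inf.
print(x * routed_scaling_factor)   # tensor([inf, inf, inf, inf], dtype=torch.float16)

# The FP16 path leaves the experts' output unscaled and divides the
# shared-expert output, so both terms stay representable; the factor is
# reapplied implicitly via the pre-RMSNorm scaling in DeepseekV2DecoderLayer.
shared_output = torch.full((4,), 1000.0, dtype=torch.float16)
print(x + shared_output * (1. / routed_scaling_factor))   # finite values, no overflow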
@@ -500,6 +502,7 @@ def __init__(
         # DecoderLayers are created with `make_layers` which passes the prefix
         # with the layer's index.
         layer_idx = int(prefix.split(sep='.')[-1])
+        self.layer_idx = layer_idx
         if model_config.use_mla:
             attn_cls = DeepseekV2MLAAttention
         else:
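
The new self.layer_idx attribute is read later in forward() so the shared residual is rescaled only once. The index itself comes from the prefix that make_layers passes in; a quick sketch with a hypothetical prefix string:

# Hypothetical prefix of the form make_layers produces ("<parent>.<layer index>").
prefix = "model.layers.7"
layer_idx = int(prefix.split(sep='.')[-1])
print(layer_idx)   # 7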
@@ -562,19 +565,30 @@ def forward(
                 hidden_states=hidden_states,
             )
 
-        # Fully Connected
-        if isinstance(self.mlp, DeepseekV2MoE) and \
-                hidden_states.dtype == torch.float16:
-            # This is a special case to avoid FP16 overflow
+        if hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow
+            # We scale both hidden_states and residual before
+            # rmsnorm, and rmsnorm result would not affect by scale.
             hidden_states *= 1. / self.routed_scaling_factor
+            if self.layer_idx == 0:
+                # The residual is shared by all layers, we only scale it on
+                # first layer.
+                residual *= 1. / self.routed_scaling_factor
+
+        # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        if isinstance(self.mlp, DeepseekV2MLP) and \
-                hidden_states.dtype == torch.float16:
-            # This is a special case to avoid FP16 overflow
+
+        if isinstance(self.mlp,
+                      DeepseekV2MLP) and hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow
+            # Scaling the DeepseekV2MLP output, it is the input of
+            # input_layernorm of next decoder layer.
+            # The scaling of DeepseekV2MOE output would be done in the forward
+            # of DeepseekV2MOE
             hidden_states *= 1. / self.routed_scaling_factor
-            residual *= 1. / self.routed_scaling_factor
+
         return hidden_states, residual
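
The "rmsnorm result would not affect by scale" comment leans on RMSNorm being (up to its epsilon term) invariant to a uniform rescaling of its input, which is why dividing hidden_states and residual before post_attention_layernorm is numerically safe. A minimal check with a plain PyTorch RMSNorm, not vLLM's fused kernel:

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain RMSNorm: x / rms(x) * weight, no mean subtraction or bias.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

hidden = torch.randn(2, 8)
weight = torch.ones(8)
scale = 1. / 2.5   # 1 / routed_scaling_factor, illustrative value

out = rms_norm(hidden, weight)
out_scaled = rms_norm(hidden * scale, weight)

# Identical up to the tiny eps term in the denominator, so the pre-norm
# scaling only shrinks the FP16 values; it does not change the layer's math.
print(torch.allclose(out, out_scaled, atol=1e-4))   # True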