@@ -161,14 +161,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                 hidden_states=hidden_states,
                 router_logits=router_logits) * self.routed_scaling_factor
         else:
-            # This is a special case to avoid FP16 overflow
+            # Fix FP16 overflow
+            # See DeepseekV2DecoderLayer for more details.
             final_hidden_states = self.experts(hidden_states=hidden_states,
                                                router_logits=router_logits)
         if shared_output is not None:
             if hidden_states.dtype != torch.float16:
                 final_hidden_states = final_hidden_states + shared_output
             else:
-                # This is a special case to avoid FP16 overflow
+                # Fix FP16 overflow
+                # See DeepseekV2DecoderLayer for more details.
                 final_hidden_states = final_hidden_states + shared_output \
                     * (1. / self.routed_scaling_factor)
         if self.tp_size > 1:
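Why the FP16 branch skips the `* self.routed_scaling_factor` multiply: float16 saturates at 65504, so scaling the routed-expert output up can overflow to `inf`; instead the output is left scaled down by `routed_scaling_factor` and the shared-expert contribution is divided to match. A minimal standalone sketch of both behaviors (the tensor values and the scaling factor below are hypothetical, not taken from the model config):

```python
import torch

# float16 saturates at 65504, so multiplying moderately large expert
# outputs by routed_scaling_factor overflows to inf.
routed_scaling_factor = 16.0  # hypothetical value for illustration
expert_out = torch.tensor([5000.0], dtype=torch.float16)
shared_output = torch.tensor([4000.0], dtype=torch.float16)

# Old behavior: scale the routed output up -> inf in FP16.
print(expert_out * routed_scaling_factor)  # tensor([inf], dtype=torch.float16)

# New behavior: keep the routed output scaled down and divide the
# shared-expert term instead, so the sum stays finite while the ratio
# between the two branches is preserved.
print(expert_out + shared_output * (1. / routed_scaling_factor))  # finite
```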
@@ -500,6 +502,7 @@ def __init__(
         # DecoderLayers are created with `make_layers` which passes the prefix
         # with the layer's index.
         layer_idx = int(prefix.split(sep='.')[-1])
+        self.layer_idx = layer_idx
         if model_config.use_mla:
             attn_cls = DeepseekV2MLAAttention
         else:
@@ -562,19 +565,30 @@ def forward(
                 hidden_states=hidden_states,
             )
 
-        # Fully Connected
-        if isinstance(self.mlp, DeepseekV2MoE) and \
-                hidden_states.dtype == torch.float16:
-            # This is a special case to avoid FP16 overflow
+        if hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow
+            # We scale both hidden_states and residual before rmsnorm;
+            # the rmsnorm result is not affected by a uniform scale.
             hidden_states *= 1. / self.routed_scaling_factor
+            if self.layer_idx == 0:
+                # The residual is shared across layers, so we only scale
+                # it in the first layer.
+                residual *= 1. / self.routed_scaling_factor
+
+        # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        if isinstance(self.mlp, DeepseekV2MLP) and \
-                hidden_states.dtype == torch.float16:
-            # This is a special case to avoid FP16 overflow
+
+        if isinstance(self.mlp,
+                      DeepseekV2MLP) and hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow
+            # Scale the DeepseekV2MLP output, which is the input to the
+            # input_layernorm of the next decoder layer.
+            # The DeepseekV2MoE output is scaled in the forward of
+            # DeepseekV2MoE.
             hidden_states *= 1. / self.routed_scaling_factor
-            residual *= 1. / self.routed_scaling_factor
+
         return hidden_states, residual
 
 
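Why pre-scaling is safe: RMSNorm divides its input by the input's root-mean-square, so a uniform scale applied to both `hidden_states` and `residual` cancels in the normalized output (up to the `eps` term), while the residual stream itself keeps carrying the 1/`routed_scaling_factor` factor that prevents overflow. A minimal sketch under that assumption; `rms_norm` below is a plain stand-in, not vLLM's fused add+RMSNorm kernel:

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # Plain RMSNorm: normalize by the root-mean-square of the input.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

torch.manual_seed(0)
weight = torch.ones(8)
hidden = torch.randn(8) * 1000.0    # large FP16-scale activations
residual = torch.randn(8) * 1000.0
s = 16.0                            # stand-in for routed_scaling_factor

# Fused add+norm without scaling vs. with both inputs pre-scaled by 1/s.
out = rms_norm(hidden + residual, weight)
out_scaled = rms_norm(hidden / s + residual / s, weight)

# The uniform 1/s factor cancels inside the norm (eps is negligible here).
print(torch.allclose(out, out_scaled, atol=1e-5))  # True
```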