@@ -155,11 +155,21 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits) * self.routed_scaling_factor
+        if hidden_states.dtype != torch.float16:
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits) * self.routed_scaling_factor
+        else:
+            # This is a special case to avoid FP16 overflow
+            final_hidden_states = self.experts(hidden_states=hidden_states,
+                                               router_logits=router_logits)
         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+            if hidden_states.dtype != torch.float16:
+                final_hidden_states = final_hidden_states + shared_output
+            else:
+                # This is a special case to avoid FP16 overflow
+                final_hidden_states = final_hidden_states + shared_output \
+                    * (1. / self.routed_scaling_factor)
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
@@ -531,6 +541,7 @@ def __init__(
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
+        self.routed_scaling_factor = config.routed_scaling_factor

     def forward(
         self,
@@ -551,9 +562,18 @@ def forward(
         )

         # Fully Connected
+        if isinstance(self.mlp, DeepseekV2MoE) and \
+                hidden_states.dtype == torch.float16:
+            # This is a special case to avoid FP16 overflow
+            hidden_states *= 1. / self.routed_scaling_factor
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
+        if isinstance(self.mlp, DeepseekV2MLP) and \
+                hidden_states.dtype == torch.float16:
+            # This is a special case to avoid FP16 overflow
+            hidden_states *= 1. / self.routed_scaling_factor
+            residual *= 1. / self.routed_scaling_factor
         return hidden_states, residual

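For context (not part of the commit), here is a minimal sketch of the FP16 overflow the special-cased path avoids. The routed_scaling_factor value and activation magnitudes below are made-up assumptions for illustration, not values taken from the DeepSeek-V2 config:

import torch

routed_scaling_factor = 16.0    # illustrative assumption
routed_out = torch.full((4,), 8000.0, dtype=torch.float16)
shared_out = torch.full((4,), 4096.0, dtype=torch.float16)

print(torch.finfo(torch.float16).max)    # 65504.0, the FP16 ceiling

# Pre-change path: scaling the routed expert output up overflows FP16.
naive = routed_out * routed_scaling_factor + shared_out
print(naive)     # tensor([inf, inf, inf, inf], dtype=torch.float16)

# Post-change path: leave the routed output unscaled and scale the shared
# output down instead; the sum equals naive / routed_scaling_factor and
# stays representable.
stable = routed_out + shared_out * (1. / routed_scaling_factor)
print(stable)    # tensor([8256., 8256., 8256., 8256.], dtype=torch.float16)

The decoder-layer hunks apply the same 1. / routed_scaling_factor to hidden_states (and, for dense DeepseekV2MLP layers, to the residual as well) so that the FP16 path stays on a consistent scale across layers.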