2 files changed: 3 additions (+3) and 3 deletions (-3).
lines changed Original file line number Diff line number Diff line change @@ -585,6 +585,7 @@ def apply(
585 585   fused_marlin_moe)
586 586
587 587   # The input must currently be float16
    588 + orig_dtype = x.dtype
588 589   x = x.half()
589 590
590 591   topk_weights, topk_ids = FusedMoE.select_experts(
@@ -610,4 +611,4 @@ def apply(
610 611   topk_ids,
611 612   w1_scale=layer.w13_scales,
612 613   w2_scale=layer.w2_scales,
613     - )
    614 + ).to(orig_dtype)
Columns: original file line number | new file line number | diff line change
@@ -95,12 +95,11 @@ def __init__(self,
 95  95   def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 96  96   # NOTE: hidden_states can have either 1D or 2D shape.
 97  97   orig_shape = hidden_states.shape
 98     - orig_dtype = hidden_states.dtype
 99  98   hidden_states = hidden_states.view(-1, self.hidden_size)
100  99   # router_logits: (num_tokens, n_experts)
101 100   router_logits, _ = self.gate(hidden_states)
102 101   final_hidden_states = self.experts(hidden_states, router_logits)
103     - return final_hidden_states.view(orig_shape).to(orig_dtype)
    102 + return final_hidden_states.view(orig_shape)
104 103
105 104
106 105   class MixtralAttention(nn.Module):
You can’t perform that action at this time.
0 commit comments