@@ -102,7 +102,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
 _, token_idx = torch.where(expert_mask[expert_idx[0]])
 current_state = hidden_states[token_idx]
 gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
-gate, up = gate_up.chunk(2, dim=-1)
+gate, up = gate_up[..., ::2], gate_up[..., 1::2]
 glu = gate * torch.sigmoid(gate * self.alpha)
 gated_output = (up + 1) * glu
 out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
@@ -113,7 +113,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
 hidden_states = hidden_states.repeat(num_experts, 1)
 hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
 gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
-gate, up = gate_up.chunk(2, dim=-1)
+gate, up = gate_up[..., ::2], gate_up[..., 1::2]
 glu = gate * torch.sigmoid(gate * self.alpha)
 next_states = torch.bmm(((up + 1) * glu), self.down_proj)
 next_states = next_states + self.down_proj_bias[..., None, :]
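The only functional change in these two hunks is the gate/up split: `chunk(2, dim=-1)` takes the first and second halves of the last dimension, while `[..., ::2]` / `[..., 1::2]` take the even and odd columns, presumably because the checkpoint stores the gate and up projection columns interleaved. A minimal standalone sketch on a toy tensor (hypothetical values, not the model's real weights) showing that the two splits select different elements:

```python
import torch

# Toy tensor standing in for one row of gate_up (hypothetical values).
gate_up = torch.arange(8.0).reshape(1, 8)  # [[0., 1., 2., 3., 4., 5., 6., 7.]]

# Old split: contiguous halves of the last dimension.
gate_c, up_c = gate_up.chunk(2, dim=-1)               # [[0..3]], [[4..7]]

# New split: even vs. odd columns, i.e. interleaved gate/up pairs.
gate_i, up_i = gate_up[..., ::2], gate_up[..., 1::2]  # [[0, 2, 4, 6]], [[1, 3, 5, 7]]

print(gate_c, up_c)
print(gate_i, up_i)
```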
@@ -666,7 +666,9 @@ def forward(
 >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
 ```"""
-
+output_router_logits = (
+    output_router_logits if output_router_logits is not None else self.config.output_router_logits
+)
 outputs: MoeModelOutputWithPast = self.model(
     input_ids=input_ids,
     attention_mask=attention_mask,
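The third hunk adds the usual argument-or-config fallback for `output_router_logits` before calling the inner model. A minimal sketch of that idiom, using a hypothetical `DummyConfig` in place of the real model config:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class DummyConfig:
    # Hypothetical stand-in for the model config's default.
    output_router_logits: bool = False

def resolve_output_router_logits(arg: Optional[bool], config: DummyConfig) -> bool:
    # Same idiom as the diff: an explicit argument wins, otherwise use the config value.
    return arg if arg is not None else config.output_router_logits

cfg = DummyConfig()
print(resolve_output_router_logits(None, cfg))   # False -> falls back to config
print(resolve_output_router_logits(True, cfg))   # True  -> caller override
```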