@@ -159,6 +159,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
 
         self.experts = SharedFusedMoE(
             shared_experts=self.shared_expert,
+            gate=self.gate,
             num_experts=self.n_routed_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
@@ -181,11 +182,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.is_sequence_parallel:
             hidden_states = sequence_parallel_chunk(hidden_states)
 
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
+        if self.experts.is_internal_router:
+            # In this case, the gate/router runs inside the FusedMoE class.
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states, router_logits=hidden_states
+            )
+        else:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )
 
         if self.shared_expert is not None:
             final_hidden_states = final_hidden_states[0] + final_hidden_states[1]