diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index 38eb3ce8eb82..059011629586 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -158,7 +158,7 @@ def forward(self, hidden_states): hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_scores, router_logits = self.router(hidden_states) routed_in = hidden_states.repeat(router_scores.shape[1], 1) - routed_in = routed_in * router_scores.reshape(-1, 1) + routed_in = routed_in * router_scores.transpose(0, 1).reshape(-1, 1) routed_out = self.experts(routed_in) out = self.shared_expert(hidden_states) out.add_(routed_out.reshape(router_scores.shape[1], -1, routed_out.shape[-1]).sum(dim=0))