
Commit 410435a

Merge pull request #21 from huggingface/ed-fix-modeling
Update modelling to work with new checkpoints, exposes output_router_logits
2 parents 68fd833 + 968238c commit 410435a

File tree

2 files changed: +7, -5 lines


src/transformers/models/openai_moe/modeling_openai_moe.py

Lines changed: 5 additions & 3 deletions
@@ -102,7 +102,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
     _, token_idx = torch.where(expert_mask[expert_idx[0]])
     current_state = hidden_states[token_idx]
     gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
-    gate, up = gate_up.chunk(2, dim=-1)
+    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
     glu = gate * torch.sigmoid(gate * self.alpha)
     gated_output = (up + 1) * glu
     out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
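
The only functional change in this hunk is how gate and up are recovered from gate_up: .chunk(2, dim=-1) takes the two contiguous halves of the last dimension, while [..., ::2] / [..., 1::2] take the even and odd columns. A minimal sketch of the difference, assuming (as the change implies) that the new checkpoints store the gate and up columns interleaved rather than in two contiguous blocks; all names and sizes below are illustrative:

```python
import torch

# Illustrative sizes only.
num_tokens, interm_dim = 4, 3

# Build a gate_up tensor whose last dimension interleaves gate and up
# columns: [g0, u0, g1, u1, g2, u2].
gate_cols = torch.arange(num_tokens * interm_dim, dtype=torch.float32).view(num_tokens, interm_dim)
up_cols = -gate_cols - 1.0
gate_up = torch.stack([gate_cols, up_cols], dim=-1).flatten(-2)  # (num_tokens, 2 * interm_dim)

# Old code: assumes the first half of the last dim is gate and the second half is up.
gate_chunk, up_chunk = gate_up.chunk(2, dim=-1)

# New code: even columns are gate, odd columns are up, matching the interleaved layout.
gate, up = gate_up[..., ::2], gate_up[..., 1::2]

assert torch.equal(gate, gate_cols) and torch.equal(up, up_cols)
assert not torch.equal(gate_chunk, gate_cols)  # chunking mixes gate and up columns for this layout
```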
@@ -113,7 +113,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
     hidden_states = hidden_states.repeat(num_experts, 1)
     hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
     gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
-    gate, up = gate_up.chunk(2, dim=-1)
+    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
     glu = gate * torch.sigmoid(gate * self.alpha)
     next_states = torch.bmm(((up + 1) * glu), self.down_proj)
     next_states = next_states + self.down_proj_bias[..., None, :]
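
The second hunk makes the same change on the dense path, where every expert processes every token via torch.bmm. A standalone sketch of that path with placeholder values (num_experts, hidden_size, interm_dim, and alpha below are made up for illustration, not taken from the model config):

```python
import torch

# Placeholder sizes and alpha; the real values come from the model config.
num_experts, num_tokens, hidden_size, interm_dim = 2, 5, 8, 16
alpha = 1.0

hidden_states = torch.randn(num_tokens, hidden_size)
gate_up_proj = torch.randn(num_experts, hidden_size, 2 * interm_dim)
gate_up_proj_bias = torch.randn(num_experts, 2 * interm_dim)
down_proj = torch.randn(num_experts, interm_dim, hidden_size)
down_proj_bias = torch.randn(num_experts, hidden_size)

# Replicate every token for every expert, as in the diff.
h = hidden_states.repeat(num_experts, 1).view(num_experts, -1, hidden_size)

gate_up = torch.bmm(h, gate_up_proj) + gate_up_proj_bias[..., None, :]
gate, up = gate_up[..., ::2], gate_up[..., 1::2]  # interleaved layout, as in the new code
glu = gate * torch.sigmoid(gate * alpha)
next_states = torch.bmm((up + 1) * glu, down_proj) + down_proj_bias[..., None, :]
print(next_states.shape)  # torch.Size([2, 5, 8]) -> one output per expert per token
```

The comment removed in modular_openai_moe.py below ("not supported for DTensors") suggests .chunk was also a problem under tensor parallelism; plain strided indexing sidesteps that.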
@@ -666,7 +666,9 @@ def forward(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
-
+    output_router_logits = (
+        output_router_logits if output_router_logits is not None else self.config.output_router_logits
+    )
     outputs: MoeModelOutputWithPast = self.model(
         input_ids=input_ids,
         attention_mask=attention_mask,
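
The third hunk exposes output_router_logits at the CausalLM forward, resolving it against the config default before the inner model is called. A minimal sketch of that fallback idiom, using stand-in classes (DummyConfig and DummyModel are hypothetical, not part of transformers):

```python
class DummyConfig:
    output_router_logits = False  # stand-in for the config default

class DummyModel:
    def __init__(self, config):
        self.config = config

    def forward(self, output_router_logits=None):
        # Same pattern as the diff: an explicit per-call value wins,
        # otherwise fall back to the config default.
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        return {"router_logits": [] if output_router_logits else None}

model = DummyModel(DummyConfig())
print(model.forward())                           # router logits suppressed (config default)
print(model.forward(output_router_logits=True))  # router logits requested for this call only
```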

src/transformers/models/openai_moe/modular_openai_moe.py

Lines changed: 2 additions & 2 deletions
@@ -96,7 +96,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
     gate_up = (
         current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
     )  # (num_tokens, 2 * interm_dim)
-    gate, up = gate_up.chunk(2, dim=-1)  # (num_tokens, interm_dim)
+    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
     glu = gate * torch.sigmoid(gate * self.alpha)  # (num_tokens, interm_dim)
     gated_output = (up + 1) * glu  # (num_tokens, interm_dim)
     out = (
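
The unchanged lines around this hunk spell out the experts' gated activation: glu = gate * sigmoid(gate * alpha), followed by (up + 1) * glu. A small standalone version of that activation, with alpha as a free parameter (the value is a placeholder; the module uses its own self.alpha):

```python
import torch

def gated_activation(gate: torch.Tensor, up: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # glu = gate * sigmoid(gate * alpha), out = (up + 1) * glu, as in the diff context lines.
    glu = gate * torch.sigmoid(gate * alpha)  # (num_tokens, interm_dim)
    return (up + 1) * glu                     # (num_tokens, interm_dim)

gate_up = torch.randn(4, 2 * 3)               # (num_tokens, 2 * interm_dim)
gate, up = gate_up[..., ::2], gate_up[..., 1::2]
print(gated_activation(gate, up).shape)       # torch.Size([4, 3])
```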
@@ -109,7 +109,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
     hidden_states = hidden_states.repeat(num_experts, 1)
     hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
     gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
-    gate, up = gate_up.chunk(2, dim=-1)  # not supported for DTensors
+    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
     glu = gate * torch.sigmoid(gate * self.alpha)
     next_states = torch.bmm(((up + 1) * glu), self.down_proj)
     next_states = next_states + self.down_proj_bias[..., None, :]
