Skip to content

Commit 67f13fd

Browse files
authored
Merge branch 'main' into fix_biogpt_test1
2 parents a407b24 + 52e9d05 commit 67f13fd

File tree

5 files changed

+21
-12
lines changed

5 files changed

+21
-12
lines changed

src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1705,6 +1705,8 @@ def prepare_inputs_for_generation(
17051705
past_key_values=None,
17061706
image_grid_thw=None,
17071707
video_grid_thw=None,
1708+
use_cache=True,
1709+
is_first_iteration=False,
17081710
# Intentionally ignore position ids to force custom cache logic
17091711
position_ids=None,
17101712
**kwargs,
@@ -1717,6 +1719,8 @@ def prepare_inputs_for_generation(
17171719
past_key_values=past_key_values,
17181720
image_grid_thw=image_grid_thw,
17191721
video_grid_thw=video_grid_thw,
1722+
use_cache=use_cache,
1723+
is_first_iteration=is_first_iteration,
17201724
**kwargs,
17211725
)
17221726

@@ -1732,7 +1736,7 @@ def prepare_inputs_for_generation(
17321736
mm_token_type_ids=model_inputs.get("mm_token_type_ids"),
17331737
)
17341738

1735-
if model_inputs["cache_position"][0] != 0:
1739+
if not is_first_iteration and use_cache:
17361740
model_inputs["pixel_values"] = None
17371741
model_inputs["pixel_values_videos"] = None
17381742
model_inputs["mm_token_type_ids"] = None

src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1393,6 +1393,8 @@ def prepare_inputs_for_generation(
13931393
past_key_values=None,
13941394
image_grid_thw=None,
13951395
video_grid_thw=None,
1396+
use_cache=True,
1397+
is_first_iteration=False,
13961398
# Intentionally ignore position ids to force custom cache logic
13971399
position_ids=None,
13981400
**kwargs,
@@ -1405,6 +1407,8 @@ def prepare_inputs_for_generation(
14051407
past_key_values=past_key_values,
14061408
image_grid_thw=image_grid_thw,
14071409
video_grid_thw=video_grid_thw,
1410+
use_cache=use_cache,
1411+
is_first_iteration=is_first_iteration,
14081412
**kwargs,
14091413
)
14101414

@@ -1420,7 +1424,7 @@ def prepare_inputs_for_generation(
14201424
mm_token_type_ids=model_inputs.get("mm_token_type_ids"),
14211425
)
14221426

1423-
if model_inputs["cache_position"][0] != 0:
1427+
if not is_first_iteration and use_cache:
14241428
model_inputs["pixel_values"] = None
14251429
model_inputs["pixel_values_videos"] = None
14261430
model_inputs["mm_token_type_ids"] = None

src/transformers/models/gpt_oss/modeling_gpt_oss.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
8888
8989
Args:
9090
hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
91-
selected_experts (torch.Tensor): (batch_size * token_num, top_k)
92-
routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
91+
selected_experts (torch.Tensor): (batch_size * seq_len, top_k)
92+
routing_weights (torch.Tensor): (batch_size * seq_len, top_k)
9393
Returns:
9494
torch.Tensor
9595
"""
@@ -159,8 +159,8 @@ def __init__(self, config):
159159

160160
def forward(self, hidden_states):
161161
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
162-
router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts)
163-
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
162+
router_logits = F.linear(hidden_states, self.weight, self.bias) # (num_tokens, num_experts)
163+
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (num_tokens, top_k)
164164
router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
165165
router_scores = router_top_value
166166
return router_logits, router_scores, router_indices
@@ -434,7 +434,7 @@ class GptOssPreTrainedModel(PreTrainedModel):
434434
_skip_keys_device_placement = ["past_key_values"]
435435
_supports_flash_attn = True
436436
_supports_sdpa = False
437-
_supports_flex_attn = False
437+
_supports_flex_attn = True
438438

439439
_can_compile_fullgraph = True
440440
_supports_attention_backend = True

src/transformers/models/gpt_oss/modular_gpt_oss.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
8686
8787
Args:
8888
hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
89-
selected_experts (torch.Tensor): (batch_size * token_num, top_k)
90-
routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
89+
selected_experts (torch.Tensor): (batch_size * seq_len, top_k)
90+
routing_weights (torch.Tensor): (batch_size * seq_len, top_k)
9191
Returns:
9292
torch.Tensor
9393
"""
@@ -157,8 +157,8 @@ def __init__(self, config):
157157

158158
def forward(self, hidden_states):
159159
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
160-
router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts)
161-
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
160+
router_logits = F.linear(hidden_states, self.weight, self.bias) # (num_tokens, num_experts)
161+
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (num_tokens, top_k)
162162
router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
163163
router_scores = router_top_value
164164
return router_logits, router_scores, router_indices
@@ -354,7 +354,6 @@ def forward(
354354
class GptOssPreTrainedModel(LlamaPreTrainedModel):
355355
_keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"]
356356
_supports_sdpa = False
357-
_supports_flex_attn = False
358357
_can_record_outputs = {
359358
"router_logits": OutputRecorder(GptOssTopKRouter, index=0),
360359
"hidden_states": GptOssDecoderLayer,

tests/models/ernie4_5_vl_moe/test_modeling_ernie4_5_vl_moe.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ def load_model(self, dtype, attn_implementation="sdpa"):
313313
device_map="auto",
314314
dtype=dtype,
315315
attn_implementation=attn_implementation,
316+
experts_implementation="eager",
316317
revision="refs/pr/10",
317318
)
318319

@@ -549,6 +550,7 @@ def load_model(self, dtype, attn_implementation="sdpa"):
549550
device_map="auto",
550551
dtype=dtype,
551552
attn_implementation=attn_implementation,
553+
experts_implementation="eager",
552554
)
553555

554556
def test_small_model_integration_test(self):

0 commit comments

Comments (0)