
Commit b7e4da0

Fix zamba2 rotary embedding call when use_mem_rope is False (#44551)

* only call rotary_emb when config.use_mem_rope is True
* add fix in modular

1 parent: 5a098a1

File tree

2 files changed: +12 additions, -2 deletions

src/transformers/models/zamba2/modeling_zamba2.py

Lines changed: 6 additions & 1 deletion

@@ -1339,7 +1339,12 @@ def forward(
                 past_key_values=past_key_values,
                 position_ids=position_ids,
             )
-        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
+
+        # create position embeddings to be shared across the decoder layers
+        if self.config.use_mem_rope:
+            position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
+        else:
+            position_embeddings = None
 
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None

src/transformers/models/zamba2/modular_zamba2.py

Lines changed: 6 additions & 1 deletion

@@ -1057,7 +1057,12 @@ def forward(
                 past_key_values=past_key_values,
                 position_ids=position_ids,
             )
-        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
+
+        # create position embeddings to be shared across the decoder layers
+        if self.config.use_mem_rope:
+            position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
+        else:
+            position_embeddings = None
 
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
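The pattern the commit applies can be sketched in isolation: compute the shared rotary position embeddings only when the config flag enables them, and otherwise pass `None` so downstream layers skip RoPE. The sketch below uses hypothetical names (`ToyConfig`, `toy_rotary_emb`, `compute_position_embeddings`) and a pure-Python rotary table in place of the model's `rotary_emb` module; it only illustrates the config-gated call, not Zamba2's actual implementation.

```python
import math

class ToyConfig:
    """Hypothetical stand-in for the model config; only the flag matters here."""
    def __init__(self, use_mem_rope: bool, head_dim: int = 8):
        self.use_mem_rope = use_mem_rope
        self.head_dim = head_dim

def toy_rotary_emb(position_ids, head_dim, base=10000.0):
    """Return (cos, sin) tables: one row per position, one column per frequency pair."""
    cos, sin = [], []
    for pos in position_ids:
        row_c, row_s = [], []
        for i in range(0, head_dim, 2):
            freq = pos / (base ** (i / head_dim))
            row_c.append(math.cos(freq))
            row_s.append(math.sin(freq))
        cos.append(row_c)
        sin.append(row_s)
    return cos, sin

def compute_position_embeddings(config, position_ids):
    # The gist of the fix: only call the rotary embedding when the config
    # asks for it; a model built without mem-RoPE has no rotary module to call.
    if config.use_mem_rope:
        return toy_rotary_emb(position_ids, config.head_dim)
    return None

print(compute_position_embeddings(ToyConfig(use_mem_rope=False), [0, 1]))  # None
cos, sin = compute_position_embeddings(ToyConfig(use_mem_rope=True), [0, 1])
print(len(cos), len(cos[0]))  # 2 positions, head_dim // 2 frequency pairs
```

Gating at the call site (rather than inside the rotary module) keeps the `position_embeddings = None` contract explicit for every decoder layer that receives it.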
