File tree Expand file tree Collapse file tree 2 files changed +12
-2
lines changed
src/transformers/models/zamba2 Expand file tree Collapse file tree 2 files changed +12
-2
lines changed Original file line number Diff line number Diff line change @@ -1339,7 +1339,12 @@ def forward(
13391339 past_key_values = past_key_values ,
13401340 position_ids = position_ids ,
13411341 )
1342- position_embeddings = self .rotary_emb (hidden_states , position_ids = position_ids )
1342+
1343+ # create position embeddings to be shared across the decoder layers
1344+ if self .config .use_mem_rope :
1345+ position_embeddings = self .rotary_emb (hidden_states , position_ids = position_ids )
1346+ else :
1347+ position_embeddings = None
13431348
13441349 all_hidden_states = () if output_hidden_states else None
13451350 all_self_attns = () if output_attentions else None
Original file line number Diff line number Diff line change @@ -1057,7 +1057,12 @@ def forward(
10571057 past_key_values = past_key_values ,
10581058 position_ids = position_ids ,
10591059 )
1060- position_embeddings = self .rotary_emb (hidden_states , position_ids = position_ids )
1060+
1061+ # create position embeddings to be shared across the decoder layers
1062+ if self .config .use_mem_rope :
1063+ position_embeddings = self .rotary_emb (hidden_states , position_ids = position_ids )
1064+ else :
1065+ position_embeddings = None
10611066
10621067 all_hidden_states = () if output_hidden_states else None
10631068 all_self_attns = () if output_attentions else None
You can’t perform that action at this time.
0 commit comments