
Commit c215523

add back the position ids (#32554)
* add back the position ids
* fix failing test
1 parent f3c8b18

File tree

1 file changed (+12 additions, -6 deletions)


src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py

Lines changed: 12 additions & 6 deletions
@@ -818,6 +818,7 @@ def get_decoder(self):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
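
With position_ids restored to the signature, callers can pass explicit positions to forward again. A minimal usage sketch (the checkpoint id, prompt, and shapes are illustrative assumptions, not part of this commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint id chosen for illustration only.
tokenizer = AutoTokenizer.from_pretrained("google/recurrentgemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/recurrentgemma-2b")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
seq_len = inputs["input_ids"].shape[1]

# Explicit positions, equivalent to the default 0..seq_len-1 range.
position_ids = torch.arange(seq_len).unsqueeze(0)

outputs = model(**inputs, position_ids=position_ids)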
@@ -858,6 +859,7 @@ def forward(
         output_hidden_states = True
         outputs = self.model(
             input_ids=input_ids,
+            position_ids=position_ids,
             cache_position=cache_position,
             attention_mask=attention_mask,
             inputs_embeds=inputs_embeds,
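
This hunk only threads the restored argument through to self.model. When callers omit position_ids, decoder models in Transformers conventionally derive positions from cache_position; a hedged sketch of that convention (an assumption about the surrounding code, not shown in this diff):

import torch

# Assumption: the positions of the tokens currently being fed equal
# their cache slots, so position_ids can fall back to cache_position.
cache_position = torch.arange(0, 6)         # first forward pass, 6 tokens
position_ids = cache_position.unsqueeze(0)  # shape (1, 6), batch-broadcastable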
@@ -913,13 +915,17 @@ def prepare_inputs_for_generation(
         if past_length > 0:
             position_ids = position_ids[:, past_length:]

-        if inputs_embeds is not None:
-            model_inputs = {"inputs_embeds": inputs_embeds[:, past_length:]}
-        else:
-            model_inputs = {"input_ids": input_ids[:, past_length:].contiguous()}
+        if inputs_embeds is not None:  # Exception 1
+            input_ids = input_ids[:, -cache_position.shape[0] :]
+        elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
+            input_ids = input_ids[:, cache_position]

-        if cache_position is not None:
-            cache_position = cache_position[-position_ids.shape[1] :]
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and cache_position[0] == 0:
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
+        else:
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

         model_inputs.update(
             {
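
The reworked branch logic can be exercised in isolation. Below is a hypothetical helper that mirrors the branch structure of the new code with made-up shapes; it is a sketch, not the library's implementation:

import torch

def select_generation_inputs(input_ids, inputs_embeds, cache_position):
    # Hypothetical helper mirroring the diff's branch structure.
    if inputs_embeds is not None:  # Exception 1: prompt was given as embeddings
        input_ids = input_ids[:, -cache_position.shape[0]:]
    elif input_ids.shape[1] != cache_position.shape[0]:  # Default case
        input_ids = input_ids[:, cache_position]

    # Embeddings are only fed on the very first generation step.
    if inputs_embeds is not None and cache_position[0] == 0:
        return {"inputs_embeds": inputs_embeds, "input_ids": None}
    return {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

# Decoding step: the running sequence is 6 tokens long, but only the
# last token is new, so cache_position points at index 5.
input_ids = torch.arange(6).unsqueeze(0)  # shape (1, 6)
cache_position = torch.tensor([5])
print(select_generation_inputs(input_ids, None, cache_position))
# -> {'input_ids': tensor([[5]]), 'inputs_embeds': None}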
