diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
index d6cb998ba4a5..cd293399b5ac 100644
--- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
+++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -818,6 +818,7 @@ def get_decoder(self):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
@@ -858,6 +859,7 @@ def forward(
         output_hidden_states = True
         outputs = self.model(
             input_ids=input_ids,
+            position_ids=position_ids,
             cache_position=cache_position,
             attention_mask=attention_mask,
             inputs_embeds=inputs_embeds,
@@ -913,13 +915,17 @@ def prepare_inputs_for_generation(
         if past_length > 0:
             position_ids = position_ids[:, past_length:]

-        if inputs_embeds is not None:
-            model_inputs = {"inputs_embeds": inputs_embeds[:, past_length:]}
-        else:
-            model_inputs = {"input_ids": input_ids[:, past_length:].contiguous()}
+        if inputs_embeds is not None:  # Exception 1
+            input_ids = input_ids[:, -cache_position.shape[0] :]
+        elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
+            input_ids = input_ids[:, cache_position]

-        if cache_position is not None:
-            cache_position = cache_position[-position_ids.shape[1] :]
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and cache_position[0] == 0:
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
+        else:
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

         model_inputs.update(
             {
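The last hunk swaps the old `past_length`-based slicing in `prepare_inputs_for_generation` for `cache_position`-based slicing. Below is a minimal, self-contained sketch (not the library code) of how that selection behaves across the prefill and decode steps; the helper name `slice_generation_inputs` and the toy tensors are assumptions for illustration only.

import torch


def slice_generation_inputs(input_ids, inputs_embeds, cache_position):
    # Exception 1: when `inputs_embeds` are passed, `input_ids` may only hold
    # the tokens generated so far, so keep the last `len(cache_position)` ids.
    if inputs_embeds is not None:
        input_ids = input_ids[:, -cache_position.shape[0] :]
    # Default case: more ids than positions to process (e.g. after prefill),
    # so gather exactly the positions named by `cache_position`.
    elif input_ids.shape[1] != cache_position.shape[0]:
        input_ids = input_ids[:, cache_position]

    # `inputs_embeds` are only consumed on the very first generation step.
    if inputs_embeds is not None and cache_position[0] == 0:
        return {"inputs_embeds": inputs_embeds, "input_ids": None}
    return {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}


# Prefill step: the whole prompt is processed and the embeddings are used.
ids = torch.tensor([[5, 6, 7]])
embeds = torch.randn(1, 3, 8)
print(slice_generation_inputs(ids, embeds, torch.arange(3))["input_ids"])  # None

# Decode step: only the newly generated token id is forwarded.
ids = torch.tensor([[5, 6, 7, 9]])
print(slice_generation_inputs(ids, None, torch.tensor([3]))["input_ids"])  # tensor([[9]])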