From 185c1cdbb758bd07af6aa195c9014c0c3ed5665a Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Fri, 9 Aug 2024 09:49:11 +0200
Subject: [PATCH 1/2] add back the position ids

---
 .../models/recurrent_gemma/modeling_recurrent_gemma.py          | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
index d6cb998ba4a5..028550be9e76 100644
--- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
+++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -818,6 +818,7 @@ def get_decoder(self):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
@@ -858,6 +859,7 @@ def forward(
             output_hidden_states = True
         outputs = self.model(
             input_ids=input_ids,
+            position_ids=position_ids,
             cache_position=cache_position,
             attention_mask=attention_mask,
             inputs_embeds=inputs_embeds,

From 209fcccbba41575fd0ece8b3ff20930dbac34a80 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Fri, 9 Aug 2024 11:13:08 +0200
Subject: [PATCH 2/2] fix failing test

---
 .../recurrent_gemma/modeling_recurrent_gemma.py  | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
index 028550be9e76..cd293399b5ac 100644
--- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
+++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -915,13 +915,17 @@ def prepare_inputs_for_generation(
         if past_length > 0:
             position_ids = position_ids[:, past_length:]
-        if inputs_embeds is not None:
-            model_inputs = {"inputs_embeds": inputs_embeds[:, past_length:]}
+        if inputs_embeds is not None:  # Exception 1
+            input_ids = input_ids[:, -cache_position.shape[0] :]
+        elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
+            input_ids = input_ids[:, cache_position]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and cache_position[0] == 0:
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids[:, past_length:].contiguous()}
-
-        if cache_position is not None:
-            cache_position = cache_position[-position_ids.shape[1] :]
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

         model_inputs.update(
             {