@@ -818,6 +818,7 @@ def get_decoder(self):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
@@ -858,6 +859,7 @@ def forward(
         output_hidden_states = True
         outputs = self.model(
             input_ids=input_ids,
+            position_ids=position_ids,
             cache_position=cache_position,
             attention_mask=attention_mask,
             inputs_embeds=inputs_embeds,
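For context, a hedged sketch of why a caller might now pass `position_ids` explicitly rather than relying on the default `arange` positions: with left-padded batches, positions must skip the padding tokens. The tensors and the cumsum recipe below are illustrative only, not part of this diff.

```python
import torch

# Hypothetical left-padded batch: 0 is the pad token
input_ids = torch.tensor([[0, 0, 15, 21],
                          [42, 7, 15, 21]])
attention_mask = (input_ids != 0).long()

# Derive positions that ignore padding (a common recipe for left padding)
position_ids = (attention_mask.cumsum(-1) - 1).clamp(min=0)
print(position_ids)
# tensor([[0, 0, 0, 1],
#         [0, 1, 2, 3]])

# These can now be forwarded explicitly:
# model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
```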
@@ -913,13 +915,17 @@ def prepare_inputs_for_generation(
         if past_length > 0:
             position_ids = position_ids[:, past_length:]

-        if inputs_embeds is not None:
-            model_inputs = {"inputs_embeds": inputs_embeds[:, past_length:]}
-        else:
-            model_inputs = {"input_ids": input_ids[:, past_length:].contiguous()}
+        if inputs_embeds is not None:  # Exception 1
+            input_ids = input_ids[:, -cache_position.shape[0]:]
+        elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no-op, is Exception 2)
+            input_ids = input_ids[:, cache_position]

-        if cache_position is not None:
-            cache_position = cache_position[-position_ids.shape[1]:]
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and cache_position[0] == 0:
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
+        else:
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

         model_inputs.update(
             {
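A minimal sketch of how the new `cache_position`-based slicing behaves across generation steps; `prepare_inputs` below is a hypothetical stand-in that mirrors only the branching added in this hunk, not the full library method.

```python
import torch

def prepare_inputs(input_ids, cache_position, inputs_embeds=None):
    if inputs_embeds is not None:  # Exception 1: embeddings were passed directly
        input_ids = input_ids[:, -cache_position.shape[0]:]
    elif input_ids.shape[1] != cache_position.shape[0]:  # Default case: keep only uncached positions
        input_ids = input_ids[:, cache_position]

    # Embeddings are only consumed on the very first step (cache_position starts at 0)
    if inputs_embeds is not None and cache_position[0] == 0:
        return {"inputs_embeds": inputs_embeds, "input_ids": None}
    return {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

# Step 1: full prompt, empty cache -> all 4 positions are "new"
prompt = torch.tensor([[5, 6, 7, 8]])
print(prepare_inputs(prompt, cache_position=torch.arange(4))["input_ids"])   # tensor([[5, 6, 7, 8]])

# Step 2: one token appended, 4 tokens already cached -> only position 4 is fed
full = torch.tensor([[5, 6, 7, 8, 9]])
print(prepare_inputs(full, cache_position=torch.tensor([4]))["input_ids"])   # tensor([[9]])
```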