@@ -424,15 +424,26 @@ def _collate_prompt_completion(self, examples: list[dict[str, Any]]) -> dict[str
424424 input_ids = torch .cat ((prompt_ids , completion_ids ), dim = 1 )
425425 attention_mask = torch .cat ((prompt_mask , completion_mask ), dim = 1 )
426426 completion_mask = torch .cat ((torch .zeros_like (prompt_mask ), completion_mask ), dim = 1 )
427+ if "token_type_ids" in processed_prompts : # special case for Gemma
428+ prompt_token_type_ids = processed_prompts ["token_type_ids" ]
429+ completion_token_type_ids = processed_completions ["token_type_ids" ]
430+ token_type_ids = torch .cat ((prompt_token_type_ids , completion_token_type_ids ), dim = 1 )
427431
428432 # Flush left to reduce padding
429- attention_mask , input_ids , completion_mask = flush_left (attention_mask , input_ids , completion_mask )
433+ if "token_type_ids" in processed_prompts :
434+ attention_mask , input_ids , completion_mask , token_type_ids = flush_left (
435+ attention_mask , input_ids , completion_mask , token_type_ids
436+ )
437+ else :
438+ attention_mask , input_ids , completion_mask = flush_left (attention_mask , input_ids , completion_mask )
430439
431440 # Truncate if necessary
432441 if self .max_length is not None :
433442 input_ids = input_ids [:, : self .max_length ]
434443 attention_mask = attention_mask [:, : self .max_length ]
435444 completion_mask = completion_mask [:, : self .max_length ]
445+ if "token_type_ids" in processed_prompts :
446+ token_type_ids = token_type_ids [:, : self .max_length ]
436447
437448 # Create labels and mask padding tokens
438449 labels = input_ids .clone ()
@@ -445,6 +456,8 @@ def _collate_prompt_completion(self, examples: list[dict[str, Any]]) -> dict[str
445456 output ["input_ids" ] = input_ids
446457 output ["attention_mask" ] = attention_mask
447458 output ["labels" ] = labels
459+ if "token_type_ids" in processed_prompts :
460+ output ["token_type_ids" ] = token_type_ids
448461 return output
449462
450463
0 commit comments