@@ -14,7 +14,6 @@
 
 import inspect
 import os
-import re
 import textwrap
 from collections import defaultdict, deque
 from contextlib import nullcontext
@@ -71,7 +70,6 @@
     shuffle_sequence_dict,
     split_pixel_values_by_grid,
     split_tensor_dict,
-    truncate_with_protected_tokens,
     unsplit_pixel_values_by_grid,
 )
 
@@ -275,7 +273,7 @@ def __init__(
 
         # Processing class
         if processing_class is None:
-            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path)
+            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left")
 
         # Handle pad token for processors or tokenizers
         if isinstance(processing_class, ProcessorMixin):
@@ -291,10 +289,6 @@ def __init__(
         self.pad_token = tokenizer.pad_token
         self.pad_token_id = tokenizer.pad_token_id
         self.eos_token_id = tokenizer.eos_token_id
-        self.image_token = getattr(processing_class, "image_token", None)
-        self.image_token_id = getattr(processing_class, "image_token_id", None)
-        self.vision_start_token_id = getattr(model.config, "vision_start_token_id", None)
-        self.vision_end_token_id = getattr(model.config, "vision_end_token_id", None)
 
         # Reward functions
         if not isinstance(reward_funcs, list):
@@ -1092,58 +1086,12 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
             maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
         ]
 
-        prompt_inputs = self.processing_class(
-            text=prompts_text,
-            return_tensors="pt",
-            padding=True,
-            padding_side="left",
-            add_special_tokens=False,
-            **kwargs,
-        )
-        prompt_inputs = super()._prepare_inputs(prompt_inputs)
-        forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
-
-        if self.max_prompt_length is not None:
-            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
-            prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
-
-            # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
-            # Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special
-            # tokens are needed for generation.
-            protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
-            protected = [token for token in protected if token is not None]
-            prompt_ids = [truncate_with_protected_tokens(ids, self.max_prompt_length, protected) for ids in prompt_ids]
-
-            prompts_text = self.processing_class.batch_decode(
-                prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
-            )
-
-            # The chat template sometimes inserts a single image token into the prompt text. However, when this text is
-            # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the
-            # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We
-            # collapse them back into a single token string to match the original chat template in case it originally
-            # applies it. Otherwise, it assumes that the chat template uses only vision_start_token_id to indicate images
-            # (e.g. Gemma 3) and removes all image_token instances and vision_end_token_id as well, leaving only
-            # the vision_start_token_id (e.g. <start_of_image>).
-            if self.image_token is not None:
-                escaped_img_token = re.escape(self.image_token)
-                # Search for the image token in the chat template
-                if re.search(escaped_img_token, self.processing_class.chat_template):
-                    prompts_text = [
-                        re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text
-                    ]
-                else:
-                    # If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id
-                    if self.vision_end_token_id is not None:
-                        escaped_eoi_token = re.escape(
-                            self.processing_class.tokenizer.decode([self.vision_end_token_id])
-                        )
-                        prompts_text = [
-                            re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text
-                        ]
-                    else:
-                        # If vision_end_token_id is None, just remove the image tokens
-                        prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text]
+        if images is not None:
+            prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
+            prompt_inputs = super()._prepare_inputs(prompt_inputs)
+            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
+        else:
+            forward_kwargs = {}
 
         # Generate completions using either vLLM or regular generation
         if self.use_vllm:
@@ -1185,6 +1133,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
                     top_k=-1 if self.top_k is None else self.top_k,
                     min_p=0.0 if self.min_p is None else self.min_p,
                     max_tokens=self.max_completion_length,
+                    truncate_prompt_tokens=self.max_prompt_length,
                     guided_decoding_regex=self.guided_decoding_regex,
                     generation_kwargs=self.args.generation_kwargs,
                 )
@@ -1223,6 +1172,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
                     "top_k": -1 if self.top_k is None else self.top_k,
                     "min_p": 0.0 if self.min_p is None else self.min_p,
                     "max_tokens": self.max_completion_length,
+                    "truncate_prompt_tokens": self.max_prompt_length,
                     "guided_decoding": guided_decoding,
                     "logprobs": 0,  # only return the logprob of the generated token
                 }
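
Both vLLM hunks above pass max_prompt_length through as truncate_prompt_tokens, vLLM's sampling parameter that keeps only the last k prompt tokens (left truncation). A minimal standalone sketch of that parameter outside the trainer, assuming a local vLLM install; the model name and prompt are illustrative only:

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")  # any small chat model works here
sampling_params = SamplingParams(
    max_tokens=32,
    truncate_prompt_tokens=64,  # keep only the last 64 prompt tokens, i.e. left truncation
)
# Prompts longer than 64 tokens are truncated from the left before generation,
# so the completion is conditioned on the most recent part of the prompt.
outputs = llm.generate(["a very long prompt ..."], sampling_params)
print(outputs[0].outputs[0].text)
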
@@ -1319,7 +1269,17 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
 
         else:
             # Regular generation path
-            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+            generate_inputs = self.processing_class(
+                text=prompts_text,
+                return_tensors="pt",
+                padding=True,
+                padding_side="left",
+                max_length=self.max_prompt_length,
+                truncation=True,
+                add_special_tokens=False,
+                **kwargs,
+            )
+            generate_inputs = super()._prepare_inputs(generate_inputs)
 
             with (
                 profiling_context(self, "transformers.generate"),
@@ -1330,15 +1290,11 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
                 FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
             ):
                 prompt_completion_ids = unwrapped_model.generate(
-                    input_ids=prompt_ids,
-                    attention_mask=prompt_mask,
-                    **forward_kwargs,
-                    generation_config=self.generation_config,
-                    disable_compile=True,
+                    **generate_inputs, generation_config=self.generation_config, disable_compile=True
                 )
             # Compute prompt length and extract completion ids
+            prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"]
             prompt_length = prompt_ids.size(1)
-            prompt_ids = prompt_completion_ids[:, :prompt_length]
             completion_ids = prompt_completion_ids[:, prompt_length:]
 
             # Mask everything after the first EOS token
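
The regular generation path above now relies on the processing class itself to enforce max_prompt_length: since the processor is loaded with truncation_side="left", passing truncation=True and max_length keeps only the last tokens of an over-long prompt. A minimal sketch of that tokenizer behavior, assuming a Hugging Face tokenizer; the model name and prompt are only examples:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", truncation_side="left")

long_prompt = "word " * 100  # roughly 100 tokens
enc = tokenizer(
    long_prompt,
    truncation=True,
    max_length=16,             # plays the role of max_prompt_length
    add_special_tokens=False,
    return_tensors="pt",
)
# With truncation_side="left", only the last 16 tokens survive, so the context
# closest to the point of generation is preserved.
print(enc["input_ids"].shape)  # torch.Size([1, 16])
print(tokenizer.decode(enc["input_ids"][0]))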