@@ -1086,13 +1086,6 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
             maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
         ]
 
-        if images is not None:
-            prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
-            prompt_inputs = super()._prepare_inputs(prompt_inputs)
-            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
-        else:
-            forward_kwargs = {}
-
         # Generate completions using either vLLM or regular generation
         if self.use_vllm:
             if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode:
@@ -1307,13 +1300,13 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
             completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())]
             logprobs = None  # not used in this case
 
-        return prompt_ids, completion_ids, logprobs, forward_kwargs
+        return prompt_ids, completion_ids, logprobs
 
     def _generate(self, prompts: list[str], images: Optional[list]):
         device = self.accelerator.device
         mode = "train" if self.model.training else "eval"
 
-        prompt_ids, completion_ids, logprobs, forward_kwargs = self._generate_single_turn(prompts, images)
+        prompt_ids, completion_ids, logprobs = self._generate_single_turn(prompts, images)
 
         # Get completion length per sequence, used for logging
         prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device)
@@ -1345,7 +1338,7 @@ def _generate(self, prompts: list[str], images: Optional[list]):
         self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
         self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())
 
-        return prompt_ids, completion_ids, total_completion_tokens, logprobs, forward_kwargs
+        return prompt_ids, completion_ids, total_completion_tokens, logprobs
 
     def _generate_and_score_completions(
         self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
@@ -1365,13 +1358,9 @@ def _generate_and_score_completions(
         if images is not None and all(img_list == [] for img_list in images):
             images = None
 
-        (
-            prompt_ids_list,
-            completion_ids_list,
-            num_items_in_batch,
-            sampling_per_token_logps_list,
-            forward_kwargs,
-        ) = self._generate(prompts, images)
+        prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list = self._generate(
+            prompts, images
+        )
 
         # Convert lists of token IDs to padded tensors
         prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
@@ -1397,18 +1386,30 @@ def _generate_and_score_completions(
         # Concatenate prompt_mask with completion_mask for logit computation
         prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)  # (B, P+C)
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)
+
+        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
+
+        num_images = [len(img_list) for img_list in images] if images is not None else None
+
+        # Get forward_kwargs for models with multimodal inputs
+        if images is not None:
+            prompts_text = [
+                apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
+            ]
+            prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
+            prompt_inputs = super()._prepare_inputs(prompt_inputs)
+            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
+        else:
+            forward_kwargs = {}
+
         # If token_type_ids are used, extend them with zeros for the completion part
         if "token_type_ids" in forward_kwargs:
             token_type_ids = forward_kwargs["token_type_ids"]
             forward_kwargs["token_type_ids"] = torch.cat(
                 [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
             )
 
-        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
-        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
-
-        num_images = [len(img_list) for img_list in images] if images is not None else None
-
         with torch.no_grad():
             # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
             # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the