@@ -572,7 +572,7 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
         except Exception as e:
             print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}")

-    def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, dump_path):
+    def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dict, dump_path):
         """Dump rollout/validation samples as JSONL."""
         os.makedirs(dump_path, exist_ok=True)
         filename = os.path.join(dump_path, f"{self.global_steps}.jsonl")
@@ -581,6 +581,7 @@ def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, du
         base_data = {
             "input": inputs,
             "output": outputs,
+            "gts": gts,
             "score": scores,
             "step": [self.global_steps] * n,
         }
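The new "gts" entry above joins the same per-sample columns as "input", "output", and "score", so each ground truth lands on the corresponding JSONL row. Below is a minimal sketch of that row-wise dump under stated assumptions: the helper name is hypothetical, all columns are assumed to have length `n`, and the merge of `reward_extra_infos_dict` that the real method also performs is omitted.

```python
import json
import os


def dump_rows_as_jsonl(dump_path, step, inputs, outputs, gts, scores):
    """Sketch: write one JSON object per sample, mirroring the base_data layout above."""
    os.makedirs(dump_path, exist_ok=True)
    filename = os.path.join(dump_path, f"{step}.jsonl")

    n = len(inputs)
    base_data = {
        "input": inputs,
        "output": outputs,
        "gts": gts,          # ground truths; entries may be None when a sample has none
        "score": scores,
        "step": [step] * n,
    }

    with open(filename, "w") as f:
        for i in range(n):
            # Take the i-th element of every column to form one JSONL record.
            row = {key: values[i] for key, values in base_data.items()}
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
```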
@@ -630,6 +631,7 @@ def _validate(self):
         # Lists to collect samples for the table
         sample_inputs = []
         sample_outputs = []
+        sample_gts = []
         sample_scores = []
         sample_turns = []

@@ -651,6 +653,11 @@ def _validate(self):
             input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
             sample_inputs.extend(input_texts)

+            ground_truths = [
+                item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch
+            ]
+            sample_gts.extend(ground_truths)
+
             batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
             non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
             if "multi_modal_data" in test_batch.non_tensor_batch:
@@ -732,6 +739,7 @@ def _validate(self):
             self._dump_generations(
                 inputs=sample_inputs,
                 outputs=sample_outputs,
+                gts=sample_gts,
                 scores=sample_scores,
                 reward_extra_infos_dict=reward_extra_infos_dict,
                 dump_path=val_data_dir,
@@ -1290,14 +1298,21 @@ def fit(self):
                             inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
                             outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
                             scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
+                            sample_gts = [
+                                item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None)
+                                for item in batch
+                            ]
+
                             if "request_id" in batch.non_tensor_batch:
                                 reward_extra_infos_dict.setdefault(
                                     "request_id",
                                     batch.non_tensor_batch["request_id"].tolist(),
                                 )
+
                             self._dump_generations(
                                 inputs=inputs,
                                 outputs=outputs,
+                                gts=sample_gts,
                                 scores=scores,
                                 reward_extra_infos_dict=reward_extra_infos_dict,
                                 dump_path=rollout_data_dir,
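Since both the rollout and validation dumps now carry a "gts" column, the per-step JSONL files can be inspected offline against the model outputs. A small reader sketch under stated assumptions: the dump path is made up, and the substring check is only a quick eyeball heuristic, not the trainer's reward function.

```python
import json


def load_dump(path):
    """Read one <global_steps>.jsonl file produced by _dump_generations."""
    with open(path) as f:
        return [json.loads(line) for line in f]


if __name__ == "__main__":
    rows = load_dump("rollout_dumps/10.jsonl")  # hypothetical dump_path and step
    for row in rows:
        gt = row.get("gts")
        if gt is None:
            continue  # sample had no reward_model.ground_truth
        # Crude check: does the ground-truth string appear verbatim in the generation?
        exact = str(gt).strip() in row["output"]
        print(f"score={row['score']:.3f}  gt_in_output={exact}  input={row['input'][:60]!r}")
```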