@@ -609,7 +609,7 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
609609 except Exception as e :
610610 print (f"Warning: Could not set total_training_steps in config. Structure missing? Error: { e } " )
611611
612- def _dump_generations (self , inputs , outputs , scores , reward_extra_infos_dict , dump_path ):
612+ def _dump_generations (self , inputs , outputs , gts , scores , reward_extra_infos_dict , dump_path ):
613613 """Dump rollout/validation samples as JSONL."""
614614 os .makedirs (dump_path , exist_ok = True )
615615 filename = os .path .join (dump_path , f"{ self .global_steps } .jsonl" )
@@ -618,6 +618,7 @@ def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, du
618618 base_data = {
619619 "input" : inputs ,
620620 "output" : outputs ,
621+ "gts" : gts ,
621622 "score" : scores ,
622623 "step" : [self .global_steps ] * n ,
623624 }
@@ -667,6 +668,7 @@ def _validate(self):
667668 # Lists to collect samples for the table
668669 sample_inputs = []
669670 sample_outputs = []
671+ sample_gts = []
670672 sample_scores = []
671673 sample_turns = []
672674
@@ -688,6 +690,10 @@ def _validate(self):
688690 input_texts = [self .tokenizer .decode (ids , skip_special_tokens = True ) for ids in input_ids ]
689691 sample_inputs .extend (input_texts )
690692
693+ ground_truths = [item .non_tensor_batch .get ("reward_model" , {}).get (
694+ "ground_truth" , None ) for item in test_batch ]
695+ sample_gts .extend (ground_truths )
696+
691697 batch_keys_to_pop = ["input_ids" , "attention_mask" , "position_ids" ]
692698 non_tensor_batch_keys_to_pop = ["raw_prompt_ids" ]
693699 if "multi_modal_data" in test_batch .non_tensor_batch :
@@ -766,6 +772,7 @@ def _validate(self):
766772 self ._dump_generations (
767773 inputs = sample_inputs ,
768774 outputs = sample_outputs ,
775+ gts = sample_gts ,
769776 scores = sample_scores ,
770777 reward_extra_infos_dict = reward_extra_infos_dict ,
771778 dump_path = val_data_dir ,
@@ -1303,9 +1310,13 @@ def fit(self):
13031310 inputs = self .tokenizer .batch_decode (batch .batch ["prompts" ], skip_special_tokens = True )
13041311 outputs = self .tokenizer .batch_decode (batch .batch ["responses" ], skip_special_tokens = True )
13051312 scores = batch .batch ["token_level_scores" ].sum (- 1 ).cpu ().tolist ()
1313+
1314+ sample_gts = [item .non_tensor_batch .get ("reward_model" , {}).get (
1315+ "ground_truth" , None ) for item in batch ]
13061316 self ._dump_generations (
13071317 inputs = inputs ,
13081318 outputs = outputs ,
1319+ gts = sample_gts ,
13091320 scores = scores ,
13101321 reward_extra_infos_dict = reward_extra_infos_dict ,
13111322 dump_path = rollout_data_dir ,
0 commit comments