Commit 394ae74

[trainer] chore: Add ground truth data to generation dumps in RayPPOTrainer
1 parent: ee65422

File tree: 1 file changed (+12 lines, -1 line)

verl/trainer/ppo/ray_trainer.py

Lines changed: 12 additions & 1 deletion
@@ -615,7 +615,7 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
         except Exception as e:
             print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}")
 
-    def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, dump_path):
+    def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dict, dump_path):
         """Dump rollout/validation samples as JSONL."""
         os.makedirs(dump_path, exist_ok=True)
         filename = os.path.join(dump_path, f"{self.global_steps}.jsonl")
@@ -624,6 +624,7 @@ def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, du
         base_data = {
             "input": inputs,
             "output": outputs,
+            "gts": gts,
             "score": scores,
             "step": [self.global_steps] * n,
         }
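
For context (not part of the diff): a minimal sketch of the record shape this change produces, assuming _dump_generations zips the per-sample lists in base_data into one JSON object per line; that writing loop sits outside the hunks shown here, and the sample values below are invented.

import json

# Hypothetical single record from {dump_path}/{global_steps}.jsonl after this
# change; the values are invented, only the keys come from base_data above.
record = {
    "input": "What is 2 + 2?",      # decoded prompt text
    "output": "The answer is 4.",   # decoded model response
    "gts": "4",                     # new field: ground truth taken from reward_model
    "score": 1.0,                   # scalar score for this sample
    "step": 42,                     # self.global_steps at dump time
}
print(json.dumps(record, ensure_ascii=False))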
@@ -673,6 +674,7 @@ def _validate(self):
         # Lists to collect samples for the table
         sample_inputs = []
         sample_outputs = []
+        sample_gts = []
         sample_scores = []
         sample_turns = []
 
@@ -694,6 +696,10 @@ def _validate(self):
             input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
             sample_inputs.extend(input_texts)
 
+            ground_truths = [item.non_tensor_batch.get("reward_model", {}).get(
+                "ground_truth", None) for item in test_batch]
+            sample_gts.extend(ground_truths)
+
             batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
             non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
             if "multi_modal_data" in test_batch.non_tensor_batch:
@@ -772,6 +778,7 @@ def _validate(self):
            self._dump_generations(
                inputs=sample_inputs,
                outputs=sample_outputs,
+                gts=sample_gts,
                scores=sample_scores,
                reward_extra_infos_dict=reward_extra_infos_dict,
                dump_path=val_data_dir,
@@ -1289,9 +1296,13 @@ def fit(self):
                         inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
                         outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
                         scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
+
+                        sample_gts = [item.non_tensor_batch.get("reward_model", {}).get(
+                            "ground_truth", None) for item in batch]
                         self._dump_generations(
                             inputs=inputs,
                             outputs=outputs,
+                            gts=sample_gts,
                             scores=scores,
                             reward_extra_infos_dict=reward_extra_infos_dict,
                             dump_path=rollout_data_dir,
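
Also for context: once gts is in the rollout dumps, a downstream script can line up outputs against ground truths. A minimal sketch, assuming the {rollout_data_dir}/{global_steps}.jsonl layout from _dump_generations; exact_match_rate and the file path are hypothetical helpers for illustration, not part of verl.

import json

def exact_match_rate(jsonl_path):
    """Fraction of dumped samples whose ground truth string appears in the output."""
    hits, total = 0, 0
    with open(jsonl_path) as f:
        for line in f:
            rec = json.loads(line)
            if rec["gts"] is None:
                continue  # samples without a ground truth are skipped
            total += 1
            hits += int(str(rec["gts"]) in rec["output"])
    return hits / total if total else 0.0

# Hypothetical usage; point this at an actual dump file.
print(exact_match_rate("rollout_dumps/10.jsonl"))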
