Commit 1ac5a21

[trainer] chore: Add ground truth data to generation dumps in RayPPOTrainer
1 parent bc2cc6b commit 1ac5a21

File tree

1 file changed (+12, -1 lines)

verl/trainer/ppo/ray_trainer.py

Lines changed: 12 additions & 1 deletion
@@ -609,7 +609,7 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
         except Exception as e:
             print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}")
 
-    def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, dump_path):
+    def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dict, dump_path):
         """Dump rollout/validation samples as JSONL."""
         os.makedirs(dump_path, exist_ok=True)
         filename = os.path.join(dump_path, f"{self.global_steps}.jsonl")
@@ -618,6 +618,7 @@ def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, du
         base_data = {
             "input": inputs,
             "output": outputs,
+            "gts": gts,
             "score": scores,
             "step": [self.global_steps] * n,
         }
@@ -667,6 +668,7 @@ def _validate(self):
         # Lists to collect samples for the table
         sample_inputs = []
         sample_outputs = []
+        sample_gts = []
         sample_scores = []
         sample_turns = []
 
@@ -688,6 +690,10 @@ def _validate(self):
             input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
             sample_inputs.extend(input_texts)
 
+            ground_truths = [item.non_tensor_batch.get("reward_model", {}).get(
+                "ground_truth", None) for item in test_batch]
+            sample_gts.extend(ground_truths)
+
             batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
             non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
             if "multi_modal_data" in test_batch.non_tensor_batch:
@@ -766,6 +772,7 @@ def _validate(self):
             self._dump_generations(
                 inputs=sample_inputs,
                 outputs=sample_outputs,
+                gts=sample_gts,
                 scores=sample_scores,
                 reward_extra_infos_dict=reward_extra_infos_dict,
                 dump_path=val_data_dir,
@@ -1303,9 +1310,13 @@ def fit(self):
                         inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
                         outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
                         scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
+
+                        sample_gts = [item.non_tensor_batch.get("reward_model", {}).get(
+                            "ground_truth", None) for item in batch]
                         self._dump_generations(
                             inputs=inputs,
                             outputs=outputs,
+                            gts=sample_gts,
                             scores=scores,
                             reward_extra_infos_dict=reward_extra_infos_dict,
                             dump_path=rollout_data_dir,

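For context, a minimal self-contained sketch of what the dump amounts to after this change. The keyword signature, the base_data keys ("input", "output", "gts", "score", "step"), and the f"{global_steps}.jsonl" naming come from the diff above; the standalone function name, the global_steps parameter, the merging of reward_extra_infos_dict, and the line-by-line json.dumps writing are assumptions added for illustration and may differ from the actual method body, which the diff does not show in full.

import json
import os


def dump_generations(inputs, outputs, gts, scores, reward_extra_infos_dict, dump_path, global_steps):
    """Sketch: write one JSON object per sample, now including its ground truth under "gts"."""
    os.makedirs(dump_path, exist_ok=True)
    filename = os.path.join(dump_path, f"{global_steps}.jsonl")

    n = len(inputs)
    base_data = {
        "input": inputs,
        "output": outputs,
        "gts": gts,  # new per-sample ground-truth column added by this commit
        "score": scores,
        "step": [global_steps] * n,
    }
    # Assumption: extra per-sample reward info is merged in as additional columns
    # when it lines up with the batch size.
    for key, values in (reward_extra_infos_dict or {}).items():
        if len(values) == n:
            base_data[key] = values

    with open(filename, "w") as f:
        for i in range(n):
            record = {key: values[i] for key, values in base_data.items()}
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

The ground truths themselves come from each sample's non-tensor batch via item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None), so samples without a reward_model entry are dumped with a null "gts" value rather than raising.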