Skip to content

Commit 483cd55

Browse files
authored
[trainer] chore: Add ground truth data to generation dumps in RayPPOTrainer (#2353)
1 parent 6017c9e commit 483cd55

File tree

1 file changed: +16 additions, -1 deletion

verl/trainer/ppo/ray_trainer.py

Lines changed: 16 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -572,7 +572,7 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
572572
except Exception as e:
573573
print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}")
574574

575-
def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, dump_path):
575+
def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dict, dump_path):
576576
"""Dump rollout/validation samples as JSONL."""
577577
os.makedirs(dump_path, exist_ok=True)
578578
filename = os.path.join(dump_path, f"{self.global_steps}.jsonl")
@@ -581,6 +581,7 @@ def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, du
581581
base_data = {
582582
"input": inputs,
583583
"output": outputs,
584+
"gts": gts,
584585
"score": scores,
585586
"step": [self.global_steps] * n,
586587
}
@@ -630,6 +631,7 @@ def _validate(self):
630631
# Lists to collect samples for the table
631632
sample_inputs = []
632633
sample_outputs = []
634+
sample_gts = []
633635
sample_scores = []
634636
sample_turns = []
635637

@@ -651,6 +653,11 @@ def _validate(self):
651653
input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
652654
sample_inputs.extend(input_texts)
653655

656+
ground_truths = [
657+
item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch
658+
]
659+
sample_gts.extend(ground_truths)
660+
654661
batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
655662
non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
656663
if "multi_modal_data" in test_batch.non_tensor_batch:
@@ -732,6 +739,7 @@ def _validate(self):
732739
self._dump_generations(
733740
inputs=sample_inputs,
734741
outputs=sample_outputs,
742+
gts=sample_gts,
735743
scores=sample_scores,
736744
reward_extra_infos_dict=reward_extra_infos_dict,
737745
dump_path=val_data_dir,
@@ -1290,14 +1298,21 @@ def fit(self):
12901298
inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
12911299
outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
12921300
scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
1301+
sample_gts = [
1302+
item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None)
1303+
for item in batch
1304+
]
1305+
12931306
if "request_id" in batch.non_tensor_batch:
12941307
reward_extra_infos_dict.setdefault(
12951308
"request_id",
12961309
batch.non_tensor_batch["request_id"].tolist(),
12971310
)
1311+
12981312
self._dump_generations(
12991313
inputs=inputs,
13001314
outputs=outputs,
1315+
gts=sample_gts,
13011316
scores=scores,
13021317
reward_extra_infos_dict=reward_extra_infos_dict,
13031318
dump_path=rollout_data_dir,

0 commit comments

Comments (0)