diff --git a/verl/trainer/ppo/metric_utils.py b/verl/trainer/ppo/metric_utils.py
index 3be65ba978c..876d6907de4 100644
--- a/verl/trainer/ppo/metric_utils.py
+++ b/verl/trainer/ppo/metric_utils.py
@@ -380,7 +380,7 @@ def calc_maj_val(data: list[dict[str, Any]], vote_key: str, val_key: str) -> flo
 
 
 def process_validation_metrics(
-    data_sources: list[str], sample_inputs: list[str], infos_dict: dict[str, list[Any]], seed: int = 42
+    data_sources: list[str], sample_uids: list[str], infos_dict: dict[str, list[Any]], seed: int = 42
 ) -> dict[str, dict[str, dict[str, float]]]:
     """
     Process validation metrics into a structured format with statistical analysis.
@@ -392,7 +392,7 @@ def process_validation_metrics(
 
     Args:
         data_sources: List of data source identifiers for each sample.
-        sample_inputs: List of input prompts corresponding to each sample.
+        sample_uids: List of sample uids corresponding to each sample.
         infos_dict: Dictionary mapping variable names to lists of values for each sample.
         seed: Random seed for bootstrap sampling. Defaults to 42.
 
@@ -418,23 +418,23 @@
 
     Example:
         >>> data_sources = ["source1", "source1", "source2"]
-        >>> sample_inputs = ["prompt1", "prompt1", "prompt2"]
+        >>> sample_uids = ["uid1", "uid1", "uid2"]
         >>> infos_dict = {"score": [0.8, 0.9, 0.7], "pred": ["A", "A", "B"]}
-        >>> result = process_validation_metrics(data_sources, sample_inputs, infos_dict)
+        >>> result = process_validation_metrics(data_sources, sample_uids, infos_dict)
         >>> # result will contain statistics for each data source and variable
     """
     # Group metrics by data source, prompt and variable
-    data_src2prompt2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
+    data_src2uid2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
     for sample_idx, data_source in enumerate(data_sources):
-        prompt = sample_inputs[sample_idx]
-        var2vals = data_src2prompt2var2vals[data_source][prompt]
+        uid = sample_uids[sample_idx]
+        var2vals = data_src2uid2var2vals[data_source][uid]
         for var_name, var_vals in infos_dict.items():
             var2vals[var_name].append(var_vals[sample_idx])
 
     # Calculate metrics for each group
-    data_src2prompt2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
-    for data_source, prompt2var2vals in data_src2prompt2var2vals.items():
-        for prompt, var2vals in prompt2var2vals.items():
+    data_src2uid2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
+    for data_source, uid2var2vals in data_src2uid2var2vals.items():
+        for uid, var2vals in uid2var2vals.items():
             for var_name, var_vals in var2vals.items():
                 if isinstance(var_vals[0], str):
                     continue
@@ -471,20 +471,20 @@
                     )
                     metric[f"maj@{n}/mean"], metric[f"maj@{n}/std"] = maj_n_mean, maj_n_std
 
-            data_src2prompt2var2metric[data_source][prompt][var_name] = metric
+            data_src2uid2var2metric[data_source][uid][var_name] = metric
 
-    # Aggregate metrics across prompts
-    data_src2var2metric2prompt_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
-    for data_source, prompt2var2metric in data_src2prompt2var2metric.items():
-        for prompt, var2metric in prompt2var2metric.items():
+    # Aggregate metrics across uids
+    data_src2var2metric2uid_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
+    for data_source, uid2var2metric in data_src2uid2var2metric.items():
+        for uid, var2metric in uid2var2metric.items():
             for var_name, metric in var2metric.items():
                 for metric_name, metric_val in metric.items():
-                    data_src2var2metric2prompt_vals[data_source][var_name][metric_name].append(metric_val)
+                    data_src2var2metric2uid_vals[data_source][var_name][metric_name].append(metric_val)
 
     data_src2var2metric2val = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
-    for data_source, var2metric2prompt_vals in data_src2var2metric2prompt_vals.items():
-        for var_name, metric2prompt_vals in var2metric2prompt_vals.items():
-            for metric_name, prompt_vals in metric2prompt_vals.items():
-                data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(prompt_vals)
+    for data_source, var2metric2uid_vals in data_src2var2metric2uid_vals.items():
+        for var_name, metric2uid_vals in var2metric2uid_vals.items():
+            for metric_name, uid_vals in metric2uid_vals.items():
+                data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(uid_vals)
 
     return data_src2var2metric2val
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index d2508a1259c..8291a0e2b1e 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -507,10 +507,16 @@ def _validate(self):
         sample_gts = []
         sample_scores = []
         sample_turns = []
+        sample_uids = []
 
         for test_data in self.val_dataloader:
             test_batch = DataProto.from_single_dict(test_data)
 
+            if "uid" not in test_batch.non_tensor_batch:
+                test_batch.non_tensor_batch["uid"] = np.array(
+                    [str(uuid.uuid4()) for _ in range(len(test_batch.batch))], dtype=object
+                )
+
             # repeat test batch
             test_batch = test_batch.repeat(
                 repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True
@@ -525,6 +531,7 @@
             # TODO: Can we keep special tokens except for padding tokens?
             input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
             sample_inputs.extend(input_texts)
+            sample_uids.extend(test_batch.non_tensor_batch["uid"])
 
             ground_truths = [
                 item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch
@@ -607,7 +614,7 @@
 
         data_sources = np.concatenate(data_source_lst, axis=0)
 
-        data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict)
+        data_src2var2metric2val = process_validation_metrics(data_sources, sample_uids, reward_extra_infos_dict)
         metric_dict = {}
         for data_source, var2metric2val in data_src2var2metric2val.items():
             core_var = "acc" if "acc" in var2metric2val else "reward"
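For reviewers, a minimal usage sketch of the renamed signature, adapted from the docstring example in the patch. It assumes a verl checkout with this change applied; the input values are illustrative only and are not part of the diff.

```python
# Minimal sketch, assuming this patch is applied to verl.
# Repeated rollouts of the same validation sample share a uid, so metrics are
# grouped per uid rather than per decoded prompt text.
from verl.trainer.ppo.metric_utils import process_validation_metrics

data_sources = ["source1", "source1", "source2"]
sample_uids = ["uid1", "uid1", "uid2"]  # first two entries: two rollouts of one sample
infos_dict = {"score": [0.8, 0.9, 0.7], "pred": ["A", "A", "B"]}

result = process_validation_metrics(data_sources, sample_uids, infos_dict)
# result maps data_source -> variable -> metric -> value; each metric is
# computed per uid and then averaged across uids (the np.mean step above),
# e.g. "uid1" contributes the grouped scores [0.8, 0.9] for "source1".
print(result)
```

The `uuid.uuid4()` fallback added in `_validate` ensures every validation sample carries a uid even when the dataset does not provide one, so the grouping above always has a key to use.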