verl/trainer/ppo/metric_utils.py (40 changes: 20 additions & 20 deletions)
@@ -380,7 +380,7 @@ def calc_maj_val(data: list[dict[str, Any]], vote_key: str, val_key: str) -> float:


 def process_validation_metrics(
-    data_sources: list[str], sample_inputs: list[str], infos_dict: dict[str, list[Any]], seed: int = 42
+    data_sources: list[str], sample_uids: list[str], infos_dict: dict[str, list[Any]], seed: int = 42
 ) -> dict[str, dict[str, dict[str, float]]]:
     """
     Process validation metrics into a structured format with statistical analysis.
@@ -392,7 +392,7 @@ def process_validation_metrics(

     Args:
         data_sources: List of data source identifiers for each sample.
-        sample_inputs: List of input prompts corresponding to each sample.
+        sample_uids: List of sample uids corresponding to each sample.
         infos_dict: Dictionary mapping variable names to lists of values for each sample.
         seed: Random seed for bootstrap sampling. Defaults to 42.

@@ -418,23 +418,23 @@

     Example:
         >>> data_sources = ["source1", "source1", "source2"]
-        >>> sample_inputs = ["prompt1", "prompt1", "prompt2"]
+        >>> sample_uids = ["uid1", "uid1", "uid2"]
         >>> infos_dict = {"score": [0.8, 0.9, 0.7], "pred": ["A", "A", "B"]}
-        >>> result = process_validation_metrics(data_sources, sample_inputs, infos_dict)
+        >>> result = process_validation_metrics(data_sources, sample_uids, infos_dict)
         >>> # result will contain statistics for each data source and variable
     """
     # Group metrics by data source, prompt and variable
-    data_src2prompt2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
+    data_src2uid2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
     for sample_idx, data_source in enumerate(data_sources):
-        prompt = sample_inputs[sample_idx]
-        var2vals = data_src2prompt2var2vals[data_source][prompt]
+        uid = sample_uids[sample_idx]
+        var2vals = data_src2uid2var2vals[data_source][uid]
         for var_name, var_vals in infos_dict.items():
             var2vals[var_name].append(var_vals[sample_idx])

     # Calculate metrics for each group
-    data_src2prompt2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
-    for data_source, prompt2var2vals in data_src2prompt2var2vals.items():
-        for prompt, var2vals in prompt2var2vals.items():
+    data_src2uid2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
+    for data_source, uid2var2vals in data_src2uid2var2vals.items():
+        for uid, var2vals in uid2var2vals.items():
             for var_name, var_vals in var2vals.items():
                 if isinstance(var_vals[0], str):
                     continue
@@ -471,20 +471,20 @@
                             )
                             metric[f"maj@{n}/mean"], metric[f"maj@{n}/std"] = maj_n_mean, maj_n_std

-                data_src2prompt2var2metric[data_source][prompt][var_name] = metric
+                data_src2uid2var2metric[data_source][uid][var_name] = metric

-    # Aggregate metrics across prompts
-    data_src2var2metric2prompt_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
-    for data_source, prompt2var2metric in data_src2prompt2var2metric.items():
-        for prompt, var2metric in prompt2var2metric.items():
+    # Aggregate metrics across uids
+    data_src2var2metric2uid_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
+    for data_source, uid2var2metric in data_src2uid2var2metric.items():
+        for uid, var2metric in uid2var2metric.items():
             for var_name, metric in var2metric.items():
                 for metric_name, metric_val in metric.items():
-                    data_src2var2metric2prompt_vals[data_source][var_name][metric_name].append(metric_val)
+                    data_src2var2metric2uid_vals[data_source][var_name][metric_name].append(metric_val)

     data_src2var2metric2val = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
-    for data_source, var2metric2prompt_vals in data_src2var2metric2prompt_vals.items():
-        for var_name, metric2prompt_vals in var2metric2prompt_vals.items():
-            for metric_name, prompt_vals in metric2prompt_vals.items():
-                data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(prompt_vals)
+    for data_source, var2metric2uid_vals in data_src2var2metric2uid_vals.items():
+        for var_name, metric2uid_vals in var2metric2uid_vals.items():
+            for metric_name, uid_vals in metric2uid_vals.items():
+                data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(uid_vals)

     return data_src2var2metric2val
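
The practical effect of the change above is that validation responses are grouped by a per-sample uid instead of by the decoded prompt string, so two distinct samples that happen to share identical prompt text are no longer merged into one group. Below is a minimal standalone sketch of a call with the new signature; the data-source name, uid strings, and score values are made up for illustration, and the exact metric keys in the result (such as the maj@{n}/mean key visible in this diff) are produced by the bootstrap logic inside metric_utils.

from verl.trainer.ppo.metric_utils import process_validation_metrics

# Illustrative inputs, not taken from the PR: one data source, two samples
# ("uid-a" and "uid-b"), each with two rollouts scored on an "acc" variable.
data_sources = ["math", "math", "math", "math"]
sample_uids = ["uid-a", "uid-a", "uid-b", "uid-b"]
infos_dict = {"acc": [1.0, 0.0, 1.0, 1.0]}

# Statistics are first computed per uid, then averaged across uids per data source.
metrics = process_validation_metrics(data_sources, sample_uids, infos_dict)
print(metrics["math"]["acc"])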
verl/trainer/ppo/ray_trainer.py (9 changes: 8 additions & 1 deletion)
@@ -507,10 +507,16 @@ def _validate(self):
         sample_gts = []
         sample_scores = []
         sample_turns = []
+        sample_uids = []

         for test_data in self.val_dataloader:
             test_batch = DataProto.from_single_dict(test_data)

+            if "uid" not in test_batch.non_tensor_batch:
+                test_batch.non_tensor_batch["uid"] = np.array(
+                    [str(uuid.uuid4()) for _ in range(len(test_batch.batch))], dtype=object
+                )
+
             # repeat test batch
             test_batch = test_batch.repeat(
                 repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True
@@ -525,6 +531,7 @@ def _validate(self):
             # TODO: Can we keep special tokens except for padding tokens?
             input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
             sample_inputs.extend(input_texts)
+            sample_uids.extend(test_batch.non_tensor_batch["uid"])

             ground_truths = [
                 item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch
@@ -607,7 +614,7 @@ def _validate(self):

         data_sources = np.concatenate(data_source_lst, axis=0)

-        data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict)
+        data_src2var2metric2val = process_validation_metrics(data_sources, sample_uids, reward_extra_infos_dict)
         metric_dict = {}
         for data_source, var2metric2val in data_src2var2metric2val.items():
             core_var = "acc" if "acc" in var2metric2val else "reward"
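
Note that the fallback uid assignment happens before the batch is repeated with val_kwargs.n, so all n rollouts of one validation sample carry the same uid, which is exactly the key process_validation_metrics now groups on. A rough standalone sketch of that interaction, assuming DataProto.repeat(..., interleave=True) duplicates each row n times in place (the batch size, n, and uuid values below are illustrative):

import uuid

import numpy as np

# Illustrative stand-ins for the trainer's validation batch, not from the PR.
batch_size, n = 2, 3

# One uid per original validation sample, mirroring the fallback added above.
uids = np.array([str(uuid.uuid4()) for _ in range(batch_size)], dtype=object)

# Interleaved repetition tags the n rollouts of each sample with the same uid,
# so they land in the same group when validation metrics are aggregated.
repeated_uids = np.repeat(uids, n)
print(repeated_uids)  # e.g. [uid0 uid0 uid0 uid1 uid1 uid1]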