
Commit ae9b52f

Authored by Maxwell-Jia and gemini-code-assist[bot], committed by Maxwell-Jia
[misc] fix: use uid for grouping in validation to avoid prompt confusion in multimodal tasks (volcengine#3280)
### What does this PR do?

Fix volcengine#3238. Follow volcengine#2815, which appears to have had no follow-up. This PR switches validation-metric grouping from the text prompt to the sample uid, so that distinct multimodal samples sharing an identical text prompt are no longer merged into one group.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: volcengine#2815.
- [x] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Checklist Before Submitting

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---

Co-authored-by: Maxwell-Jia <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
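The failure mode this PR fixes is easiest to see in a small sketch. The snippet below is illustrative only (the sample data and dictionary names are made up); it shows how two multimodal samples with the same text prompt but different images collapse into one group under prompt-based grouping, while uid-based grouping keeps them apart:

```python
from collections import defaultdict

# Two validation samples: same text prompt, different images (hypothetical data).
samples = [
    {"prompt": "Describe the image.", "uid": "uid-0001", "score": 1.0},  # image A
    {"prompt": "Describe the image.", "uid": "uid-0002", "score": 0.0},  # image B
]

# Old behavior: grouping by text prompt merges the two distinct samples.
by_prompt = defaultdict(list)
for s in samples:
    by_prompt[s["prompt"]].append(s["score"])
print(by_prompt)  # {'Describe the image.': [1.0, 0.0]} -- one confused group

# New behavior: grouping by uid keeps each sample's rollouts separate.
by_uid = defaultdict(list)
for s in samples:
    by_uid[s["uid"]].append(s["score"])
print(by_uid)  # {'uid-0001': [1.0], 'uid-0002': [0.0]} -- two correct groups
```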
1 parent d88d4c5 · commit ae9b52f

File tree: 2 files changed, +28 additions, −21 deletions

verl/trainer/ppo/metric_utils.py

Lines changed: 20 additions & 20 deletions
```diff
@@ -380,7 +380,7 @@ def calc_maj_val(data: list[dict[str, Any]], vote_key: str, val_key: str) -> flo
 
 
 def process_validation_metrics(
-    data_sources: list[str], sample_inputs: list[str], infos_dict: dict[str, list[Any]], seed: int = 42
+    data_sources: list[str], sample_uids: list[str], infos_dict: dict[str, list[Any]], seed: int = 42
 ) -> dict[str, dict[str, dict[str, float]]]:
     """
     Process validation metrics into a structured format with statistical analysis.
@@ -392,7 +392,7 @@ def process_validation_metrics(
 
     Args:
         data_sources: List of data source identifiers for each sample.
-        sample_inputs: List of input prompts corresponding to each sample.
+        sample_uids: List of sample uids corresponding to each sample.
         infos_dict: Dictionary mapping variable names to lists of values for each sample.
         seed: Random seed for bootstrap sampling. Defaults to 42.
 
@@ -418,23 +418,23 @@ def process_validation_metrics(
 
     Example:
         >>> data_sources = ["source1", "source1", "source2"]
-        >>> sample_inputs = ["prompt1", "prompt1", "prompt2"]
+        >>> sample_uids = ["uid1", "uid1", "uid2"]
        >>> infos_dict = {"score": [0.8, 0.9, 0.7], "pred": ["A", "A", "B"]}
-        >>> result = process_validation_metrics(data_sources, sample_inputs, infos_dict)
+        >>> result = process_validation_metrics(data_sources, sample_uids, infos_dict)
         >>> # result will contain statistics for each data source and variable
     """
     # Group metrics by data source, prompt and variable
-    data_src2prompt2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
+    data_src2uid2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
     for sample_idx, data_source in enumerate(data_sources):
-        prompt = sample_inputs[sample_idx]
-        var2vals = data_src2prompt2var2vals[data_source][prompt]
+        uid = sample_uids[sample_idx]
+        var2vals = data_src2uid2var2vals[data_source][uid]
         for var_name, var_vals in infos_dict.items():
             var2vals[var_name].append(var_vals[sample_idx])
 
     # Calculate metrics for each group
-    data_src2prompt2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
-    for data_source, prompt2var2vals in data_src2prompt2var2vals.items():
-        for prompt, var2vals in prompt2var2vals.items():
+    data_src2uid2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
+    for data_source, uid2var2vals in data_src2uid2var2vals.items():
+        for uid, var2vals in uid2var2vals.items():
             for var_name, var_vals in var2vals.items():
                 if isinstance(var_vals[0], str):
                     continue
@@ -471,20 +471,20 @@ def process_validation_metrics(
                             )
                            metric[f"maj@{n}/mean"], metric[f"maj@{n}/std"] = maj_n_mean, maj_n_std
 
-                data_src2prompt2var2metric[data_source][prompt][var_name] = metric
+                data_src2uid2var2metric[data_source][uid][var_name] = metric
 
-    # Aggregate metrics across prompts
-    data_src2var2metric2prompt_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
-    for data_source, prompt2var2metric in data_src2prompt2var2metric.items():
-        for prompt, var2metric in prompt2var2metric.items():
+    # Aggregate metrics across uids
+    data_src2var2metric2uid_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
+    for data_source, uid2var2metric in data_src2uid2var2metric.items():
+        for uid, var2metric in uid2var2metric.items():
             for var_name, metric in var2metric.items():
                 for metric_name, metric_val in metric.items():
-                    data_src2var2metric2prompt_vals[data_source][var_name][metric_name].append(metric_val)
+                    data_src2var2metric2uid_vals[data_source][var_name][metric_name].append(metric_val)
 
     data_src2var2metric2val = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
-    for data_source, var2metric2prompt_vals in data_src2var2metric2prompt_vals.items():
-        for var_name, metric2prompt_vals in var2metric2prompt_vals.items():
-            for metric_name, prompt_vals in metric2prompt_vals.items():
-                data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(prompt_vals)
+    for data_source, var2metric2uid_vals in data_src2var2metric2uid_vals.items():
+        for var_name, metric2uid_vals in var2metric2uid_vals.items():
+            for metric_name, uid_vals in metric2uid_vals.items():
+                data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(uid_vals)
 
     return data_src2var2metric2val
```
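For reference, here is a minimal usage sketch of the updated function, adapted from the docstring example in the diff above (the data values are illustrative, and it assumes `verl` is installed):

```python
from verl.trainer.ppo.metric_utils import process_validation_metrics

data_sources = ["source1", "source1", "source2"]
sample_uids = ["uid1", "uid1", "uid2"]  # two responses for uid1, one for uid2
infos_dict = {"score": [0.8, 0.9, 0.7], "pred": ["A", "A", "B"]}

# Metrics are first computed per (data_source, uid) group, then averaged
# across uids within each data source.
result = process_validation_metrics(data_sources, sample_uids, infos_dict)
# result maps data_source -> variable -> metric name -> value; the metric
# names include the bootstrapped maj@{n}/mean and maj@{n}/std seen above.
```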

verl/trainer/ppo/ray_trainer.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -507,10 +507,16 @@ def _validate(self):
         sample_gts = []
         sample_scores = []
         sample_turns = []
+        sample_uids = []
 
         for test_data in self.val_dataloader:
             test_batch = DataProto.from_single_dict(test_data)
 
+            if "uid" not in test_batch.non_tensor_batch:
+                test_batch.non_tensor_batch["uid"] = np.array(
+                    [str(uuid.uuid4()) for _ in range(len(test_batch.batch))], dtype=object
+                )
+
             # repeat test batch
             test_batch = test_batch.repeat(
                 repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True
@@ -525,6 +531,7 @@ def _validate(self):
             # TODO: Can we keep special tokens except for padding tokens?
             input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
             sample_inputs.extend(input_texts)
+            sample_uids.extend(test_batch.non_tensor_batch["uid"])
 
             ground_truths = [
                 item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch
@@ -607,7 +614,7 @@ def _validate(self):
 
         data_sources = np.concatenate(data_source_lst, axis=0)
 
-        data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict)
+        data_src2var2metric2val = process_validation_metrics(data_sources, sample_uids, reward_extra_infos_dict)
         metric_dict = {}
         for data_source, var2metric2val in data_src2var2metric2val.items():
             core_var = "acc" if "acc" in var2metric2val else "reward"
```
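Outside the trainer, the back-fill logic above reduces to a few lines. This standalone sketch (a plain dict stands in for `test_batch.non_tensor_batch`, and `np.repeat` stands in for `DataProto.repeat(..., interleave=True)`) shows why all `n` rollouts of one validation sample end up sharing a single uid:

```python
import uuid
import numpy as np

batch_size = 2  # stand-in for len(test_batch.batch)
non_tensor_batch = {}  # stand-in for test_batch.non_tensor_batch

# Back-fill: datasets that don't provide a uid get one random uid per sample.
if "uid" not in non_tensor_batch:
    non_tensor_batch["uid"] = np.array(
        [str(uuid.uuid4()) for _ in range(batch_size)], dtype=object
    )

# repeat(repeat_times=n, interleave=True) duplicates each row n times in a
# row, so every rollout of a sample carries the same uid; np.repeat models
# that interleaved duplication for the uid column alone.
n = 3
repeated_uids = np.repeat(non_tensor_batch["uid"], n)
assert len(repeated_uids) == batch_size * n
assert len(set(repeated_uids)) == batch_size  # n rollouts per uid, 2 uids
```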
