Skip to content

Commit 07b0efc

Browse files
committed
fix: handle list of metric dicts in reduce_metrics after DataProto.concat()
When multiple workers return metrics, DataProto.concat() aggregates them into a list of dicts (introduced in b3c6274). This caused reduce_metrics() to fail with: AttributeError: 'list' object has no attribute 'items'.

Changes:
- Update reduce_metrics() to accept both a dict and a list of dicts
- Merge the list of metric dicts before applying reduction operations
- Maintain backward compatibility with existing dict input
- Add comprehensive tests for the new list input handling

Fixes the error:
  File "verl/trainer/ppo/ray_trainer.py", line 1129, in fit
    critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
  AttributeError: 'list' object has no attribute 'items'
1 parent b3c6274 commit 07b0efc

File tree

2 files changed

+66
-2
lines changed

2 files changed

+66
-2
lines changed

tests/trainer/ppo/test_metric_utils_on_cpu.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,45 @@ def test_reduce_metrics_single_value(self):
6666

6767
self.assertEqual(result["single"], 5.0)
6868

69+
def test_reduce_metrics_list_of_dicts(self):
    """reduce_metrics merges per-worker metric dicts and averages the pooled values."""
    # Two workers each report list-valued metrics, as produced by DataProto.concat.
    worker_a = {"loss": [1.0, 2.0], "accuracy": [0.8, 0.9]}
    worker_b = {"loss": [3.0, 4.0], "accuracy": [0.7, 0.6]}

    reduced = reduce_metrics([worker_a, worker_b])

    # Values from both workers are pooled before the mean is taken.
    self.assertAlmostEqual(reduced["loss"], 2.5)  # mean of [1.0, 2.0, 3.0, 4.0]
    self.assertAlmostEqual(reduced["accuracy"], 0.75)  # mean of [0.8, 0.9, 0.7, 0.6]
81+
82+
def test_reduce_metrics_list_of_dicts_with_scalars(self):
    """reduce_metrics accepts scalar (non-list) values inside per-worker dicts."""
    # Each worker reports a bare number per metric rather than a list.
    per_worker = [
        {"loss": 1.0, "accuracy": 0.8},
        {"loss": 3.0, "accuracy": 0.6},
    ]

    reduced = reduce_metrics(per_worker)

    # Scalars are collected across workers and then averaged.
    self.assertEqual(reduced["loss"], 2.0)  # mean of [1.0, 3.0]
    self.assertEqual(reduced["accuracy"], 0.7)  # mean of [0.8, 0.6]
94+
95+
def test_reduce_metrics_list_with_max_min_keys(self):
    """Keys containing 'max'/'min' are reduced with max/min over the pooled values."""
    per_worker = [
        {"max_reward": [5.0, 8.0], "min_error": [0.1, 0.05]},
        {"max_reward": [6.0, 7.0], "min_error": [0.2, 0.15]},
    ]

    reduced = reduce_metrics(per_worker)

    # A "max" key name selects max-aggregation across all workers' values.
    self.assertEqual(reduced["max_reward"], 8.0)  # max of [5.0, 8.0, 6.0, 7.0]
    # A "min" key name selects min-aggregation across all workers' values.
    self.assertEqual(reduced["min_error"], 0.05)  # min of [0.1, 0.05, 0.2, 0.15]
107+
69108

70109
class TestComputeDataMetrics(unittest.TestCase):
71110
"""Tests for the compute_data_metrics function."""

verl/utils/metric/utils.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import numpy as np
2121

2222

23-
def reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]:
23+
def reduce_metrics(metrics: dict[str, list[Any]] | list[dict[str, Any]]) -> dict[str, Any]:
2424
"""
2525
Reduces a dictionary of metric lists by computing the mean, max, or min of each list.
2626
The reduce operation is determined by the key name:
@@ -29,7 +29,9 @@ def reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]:
2929
- Otherwise, np.mean is used
3030
3131
Args:
32-
metrics: A dictionary mapping metric names to lists of metric values.
32+
metrics: Either:
33+
- A dictionary mapping metric names to lists of metric values, or
34+
- A list of dictionaries from multiple workers (e.g., after DataProto.concat())
3335
3436
Returns:
3537
A dictionary with the same keys but with each list replaced by its reduced value.
@@ -43,7 +45,30 @@ def reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]:
4345
... }
4446
>>> reduce_metrics(metrics)
4547
{"loss": 2.0, "accuracy": 0.8, "max_reward": 8.0, "min_error": 0.05}
48+
49+
>>> metrics_list = [
50+
... {"loss": [1.0, 2.0], "accuracy": [0.8, 0.9]},
51+
... {"loss": [3.0, 4.0], "accuracy": [0.7, 0.6]}
52+
... ]
53+
>>> reduce_metrics(metrics_list)
54+
{"loss": 2.5, "accuracy": 0.75}
4655
"""
56+
# Handle list of dicts (from multiple workers after DataProto.concat)
57+
if isinstance(metrics, list):
58+
# Merge all metric dicts into a single dict with lists
59+
merged_metrics = {}
60+
for worker_metrics in metrics:
61+
for key, val in worker_metrics.items():
62+
if key not in merged_metrics:
63+
merged_metrics[key] = []
64+
# val could be a single value or a list
65+
if isinstance(val, list):
66+
merged_metrics[key].extend(val)
67+
else:
68+
merged_metrics[key].append(val)
69+
metrics = merged_metrics
70+
71+
# Now reduce the dict of lists
4772
for key, val in metrics.items():
4873
if "max" in key:
4974
metrics[key] = np.max(val)

0 commit comments

Comments
 (0)