diff --git a/tests/test_protocol_on_cpu.py b/tests/test_protocol_on_cpu.py
index a1b9253bbe7..a4a48f007de 100644
--- a/tests/test_protocol_on_cpu.py
+++ b/tests/test_protocol_on_cpu.py
@@ -215,6 +215,144 @@ def test_chunk_concat():
     assert concat_data.meta_info == data.meta_info
 
 
+def test_concat_metrics_from_multiple_workers():
+    """Test that concat() properly merges metrics from all workers in distributed training."""
+    # Simulate 3 workers, each with its own metrics
+    obs1 = torch.tensor([1, 2])
+    obs2 = torch.tensor([3, 4])
+    obs3 = torch.tensor([5, 6])
+
+    # Each worker has different metrics (in list-of-dict format)
+    worker1_metrics = [{"loss": 0.5, "accuracy": 0.9}]
+    worker2_metrics = [{"loss": 0.6, "accuracy": 0.85}]
+    worker3_metrics = [{"loss": 0.55, "accuracy": 0.88}]
+
+    data1 = DataProto.from_dict(tensors={"obs": obs1}, meta_info={"metrics": worker1_metrics, "config_flag": True})
+    data2 = DataProto.from_dict(tensors={"obs": obs2}, meta_info={"metrics": worker2_metrics, "config_flag": True})
+    data3 = DataProto.from_dict(tensors={"obs": obs3}, meta_info={"metrics": worker3_metrics, "config_flag": True})
+
+    # Concat all workers' data
+    concat_data = DataProto.concat([data1, data2, data3])
+
+    # Verify tensors are concatenated
+    assert torch.all(torch.eq(concat_data.batch["obs"], torch.tensor([1, 2, 3, 4, 5, 6])))
+
+    # Verify ALL workers' metrics are flattened into a dict of lists
+    expected_metrics = {"loss": [0.5, 0.6, 0.55], "accuracy": [0.9, 0.85, 0.88]}
+    assert concat_data.meta_info["metrics"] == expected_metrics
+
+    # Verify config flags are preserved from the first worker
+    assert concat_data.meta_info["config_flag"] is True
+
+
+def test_concat_with_empty_and_non_list_meta_info():
+    """Test that concat() handles edge cases: missing metrics and non-list meta_info values."""
+    obs1 = torch.tensor([1, 2])
+    obs2 = torch.tensor([3, 4])
+
+    # Worker 1 has metrics, worker 2 doesn't
+    data1 = DataProto.from_dict(tensors={"obs": obs1}, meta_info={"metrics": [{"loss": 0.5}], "flag": True})
+    data2 = DataProto.from_dict(tensors={"obs": obs2}, meta_info={"flag": True})
+
+    concat_data = DataProto.concat([data1, data2])
+
+    # Should flatten worker 1's metrics into a dict of lists
+    assert concat_data.meta_info["metrics"] == {"loss": [0.5]}
+    assert concat_data.meta_info["flag"] is True
+
+    # Test with a non-list meta_info value
+    data3 = DataProto.from_dict(tensors={"obs": obs1}, meta_info={"single_value": 42})
+    data4 = DataProto.from_dict(tensors={"obs": obs2}, meta_info={"single_value": 42})
+
+    concat_data2 = DataProto.concat([data3, data4])
+    assert concat_data2.meta_info["single_value"] == 42
+
+
+def test_concat_first_worker_missing_metrics():
+    """Test that metrics from other workers are preserved even when the first worker has no metrics.
+
+    This is a critical edge case: the old, buggy implementation only used data[0].meta_info
+    and would lose all metrics if the first worker didn't have any.
+ """ + obs1 = torch.tensor([1, 2]) + obs2 = torch.tensor([3, 4]) + obs3 = torch.tensor([5, 6]) + + # First worker has NO metrics, but workers 2 and 3 do + data1 = DataProto.from_dict(tensors={"obs": obs1}, meta_info={"config_flag": True}) + data2 = DataProto.from_dict(tensors={"obs": obs2}, meta_info={"metrics": {"loss": 0.6}, "config_flag": True}) + data3 = DataProto.from_dict(tensors={"obs": obs3}, meta_info={"metrics": {"loss": 0.55}, "config_flag": True}) + + concat_data = DataProto.concat([data1, data2, data3]) + + # Should flatten metrics from workers 2 and 3 into dict of lists + expected_metrics = {"loss": [0.6, 0.55]} + assert concat_data.meta_info["metrics"] == expected_metrics + assert concat_data.meta_info["config_flag"] is True + + +def test_concat_non_list_metrics(): + """Test that concat() handles non-list metrics (single dict) correctly. + + In some cases, metrics might be a single dict instead of a list. + The implementation should flatten them into a dict of lists. + """ + obs1 = torch.tensor([1, 2]) + obs2 = torch.tensor([3, 4]) + + # Metrics as single dict (not wrapped in list) + data1 = DataProto.from_dict(tensors={"obs": obs1}, meta_info={"metrics": {"loss": 0.5, "accuracy": 0.9}}) + data2 = DataProto.from_dict(tensors={"obs": obs2}, meta_info={"metrics": {"loss": 0.6, "accuracy": 0.85}}) + + concat_data = DataProto.concat([data1, data2]) + + # Should flatten to dict of lists + expected_metrics = {"loss": [0.5, 0.6], "accuracy": [0.9, 0.85]} + assert concat_data.meta_info["metrics"] == expected_metrics + + +def test_concat_merge_different_non_metric_keys(): + """Test that concat() merges non-metric meta_info keys from all workers. + + When different workers have different non-metric keys, all keys should be preserved. + This prevents silent data loss and aligns with the docstring stating meta_info is "merged". + """ + obs1 = torch.tensor([1, 2]) + obs2 = torch.tensor([3, 4]) + obs3 = torch.tensor([5, 6]) + + # Each worker has some unique non-metric keys + data1 = DataProto.from_dict(tensors={"obs": obs1}, meta_info={"config": "A", "shared_key": "X"}) + data2 = DataProto.from_dict(tensors={"obs": obs2}, meta_info={"extra_key": "B", "shared_key": "X"}) + data3 = DataProto.from_dict(tensors={"obs": obs3}, meta_info={"another_key": "C", "shared_key": "X"}) + + concat_data = DataProto.concat([data1, data2, data3]) + + # All unique keys should be preserved + assert concat_data.meta_info["config"] == "A" + assert concat_data.meta_info["extra_key"] == "B" + assert concat_data.meta_info["another_key"] == "C" + assert concat_data.meta_info["shared_key"] == "X" + + +def test_concat_conflicting_non_metric_keys(): + """Test that concat() raises an assertion error when non-metric keys have conflicting values. + + This ensures data integrity by catching cases where workers have different values + for what should be the same configuration parameter. 
+ """ + obs1 = torch.tensor([1, 2]) + obs2 = torch.tensor([3, 4]) + + # Same key "config" but different values + data1 = DataProto.from_dict(tensors={"obs": obs1}, meta_info={"config": "A"}) + data2 = DataProto.from_dict(tensors={"obs": obs2}, meta_info={"config": "B"}) + + # Should raise an assertion error due to conflicting values + with pytest.raises(AssertionError, match="Conflicting values for meta_info key 'config'"): + DataProto.concat([data1, data2]) + + def test_pop(): obs = torch.randn(100, 10) act = torch.randn(100, 3) diff --git a/verl/protocol.py b/verl/protocol.py index b412dd454e1..0ba386bc894 100644 --- a/verl/protocol.py +++ b/verl/protocol.py @@ -882,7 +882,7 @@ def split(self, split_size: int) -> list["DataProto"]: @staticmethod def concat(data: list["DataProto"]) -> "DataProto": """Concat a list of DataProto. The batch is concatenated among dim=0. - The meta_info is assumed to be identical and will use the first one. + The meta_info is merged, with special handling for metrics from different workers. Args: data (List[DataProto]): list of DataProto @@ -899,8 +899,32 @@ def concat(data: list["DataProto"]) -> "DataProto": for key, val in non_tensor_batch.items(): non_tensor_batch[key] = np.concatenate(val, axis=0) + # Merge meta_info with special handling for metrics + merged_meta_info = {} + if data: + # Merge non-metric meta_info and aggregate metrics from all workers. + all_metrics = [] + for d in data: + for k, v in d.meta_info.items(): + if k == "metrics": + if v is not None: + if isinstance(v, list): + all_metrics.extend(v) + else: + all_metrics.append(v) + else: + if k in merged_meta_info: + # Ensure consistency for overlapping non-metric keys + assert merged_meta_info[k] == v, f"Conflicting values for meta_info key '{k}'" + else: + merged_meta_info[k] = v + + # Flatten list of dicts to dict of lists for consistent metrics structure + if all_metrics: + merged_meta_info["metrics"] = list_of_dict_to_dict_of_list(all_metrics) + cls = type(data[0]) if len(data) > 0 else DataProto - return cls(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info) + return cls(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=merged_meta_info) def reorder(self, indices): """