138 changes: 138 additions & 0 deletions tests/test_protocol_on_cpu.py
@@ -215,6 +215,144 @@ def test_chunk_concat():
    assert concat_data.meta_info == data.meta_info


def test_concat_metrics_from_multiple_workers():
    """Test that concat() properly merges metrics from all workers in distributed training."""
    # Simulate 3 workers, each with their own metrics
    obs1 = torch.tensor([1, 2])
    obs2 = torch.tensor([3, 4])
    obs3 = torch.tensor([5, 6])

    # Each worker has different metrics (as a list of dicts)
    worker1_metrics = [{"loss": 0.5, "accuracy": 0.9}]
    worker2_metrics = [{"loss": 0.6, "accuracy": 0.85}]
    worker3_metrics = [{"loss": 0.55, "accuracy": 0.88}]

    data1 = DataProto.from_dict(tensors={"obs": obs1}, meta_info={"metrics": worker1_metrics, "config_flag": True})
    data2 = DataProto.from_dict(tensors={"obs": obs2}, meta_info={"metrics": worker2_metrics, "config_flag": True})
    data3 = DataProto.from_dict(tensors={"obs": obs3}, meta_info={"metrics": worker3_metrics, "config_flag": True})

    # Concat all workers' data
    concat_data = DataProto.concat([data1, data2, data3])

    # Verify tensors are concatenated
    assert torch.all(torch.eq(concat_data.batch["obs"], torch.tensor([1, 2, 3, 4, 5, 6])))

    # Verify ALL workers' metrics are flattened to a dict of lists
    expected_metrics = {"loss": [0.5, 0.6, 0.55], "accuracy": [0.9, 0.85, 0.88]}
    assert concat_data.meta_info["metrics"] == expected_metrics

    # Verify config flags are preserved from the first worker
    assert concat_data.meta_info["config_flag"] is True


def test_concat_with_empty_and_non_list_meta_info():
    """Test concat() handles edge cases: empty meta_info, non-list values, and None."""
    obs1 = torch.tensor([1, 2])
    obs2 = torch.tensor([3, 4])

    # Worker 1 has metrics, worker 2 doesn't
    data1 = DataProto.from_dict(tensors={"obs": obs1}, meta_info={"metrics": [{"loss": 0.5}], "flag": True})
    data2 = DataProto.from_dict(tensors={"obs": obs2}, meta_info={"flag": True})

    concat_data = DataProto.concat([data1, data2])

    # Should flatten worker1's metrics to a dict of lists
    assert concat_data.meta_info["metrics"] == {"loss": [0.5]}
    assert concat_data.meta_info["flag"] is True

    # Test with a non-list meta_info value
    data3 = DataProto.from_dict(tensors={"obs": obs1}, meta_info={"single_value": 42})
    data4 = DataProto.from_dict(tensors={"obs": obs2}, meta_info={"single_value": 42})

    concat_data2 = DataProto.concat([data3, data4])
    assert concat_data2.meta_info["single_value"] == 42


def test_concat_first_worker_missing_metrics():
    """Test that metrics from other workers are preserved even when the first worker has no metrics.

    This is a critical edge case - the old buggy implementation only checked data[0].meta_info
    and would lose all metrics if the first worker didn't have any.
    """
    obs1 = torch.tensor([1, 2])
    obs2 = torch.tensor([3, 4])
    obs3 = torch.tensor([5, 6])

    # First worker has NO metrics, but workers 2 and 3 do
    data1 = DataProto.from_dict(tensors={"obs": obs1}, meta_info={"config_flag": True})
    data2 = DataProto.from_dict(tensors={"obs": obs2}, meta_info={"metrics": {"loss": 0.6}, "config_flag": True})
    data3 = DataProto.from_dict(tensors={"obs": obs3}, meta_info={"metrics": {"loss": 0.55}, "config_flag": True})

    concat_data = DataProto.concat([data1, data2, data3])

    # Should flatten metrics from workers 2 and 3 into a dict of lists
    expected_metrics = {"loss": [0.6, 0.55]}
    assert concat_data.meta_info["metrics"] == expected_metrics
    assert concat_data.meta_info["config_flag"] is True


def test_concat_non_list_metrics():
    """Test that concat() handles non-list metrics (a single dict) correctly.

    In some cases, metrics might be a single dict instead of a list.
    The implementation should flatten them into a dict of lists.
    """
    obs1 = torch.tensor([1, 2])
    obs2 = torch.tensor([3, 4])

    # Metrics as a single dict (not wrapped in a list)
    data1 = DataProto.from_dict(tensors={"obs": obs1}, meta_info={"metrics": {"loss": 0.5, "accuracy": 0.9}})
    data2 = DataProto.from_dict(tensors={"obs": obs2}, meta_info={"metrics": {"loss": 0.6, "accuracy": 0.85}})

    concat_data = DataProto.concat([data1, data2])

    # Should flatten to a dict of lists
    expected_metrics = {"loss": [0.5, 0.6], "accuracy": [0.9, 0.85]}
    assert concat_data.meta_info["metrics"] == expected_metrics


def test_concat_merge_different_non_metric_keys():
    """Test that concat() merges non-metric meta_info keys from all workers.

    When different workers have different non-metric keys, all keys should be preserved.
    This prevents silent data loss and aligns with the docstring stating meta_info is "merged".
    """
    obs1 = torch.tensor([1, 2])
    obs2 = torch.tensor([3, 4])
    obs3 = torch.tensor([5, 6])

    # Each worker has some unique non-metric keys
    data1 = DataProto.from_dict(tensors={"obs": obs1}, meta_info={"config": "A", "shared_key": "X"})
    data2 = DataProto.from_dict(tensors={"obs": obs2}, meta_info={"extra_key": "B", "shared_key": "X"})
    data3 = DataProto.from_dict(tensors={"obs": obs3}, meta_info={"another_key": "C", "shared_key": "X"})

    concat_data = DataProto.concat([data1, data2, data3])

    # All unique keys should be preserved
    assert concat_data.meta_info["config"] == "A"
    assert concat_data.meta_info["extra_key"] == "B"
    assert concat_data.meta_info["another_key"] == "C"
    assert concat_data.meta_info["shared_key"] == "X"


def test_concat_conflicting_non_metric_keys():
    """Test that concat() raises an assertion error when non-metric keys have conflicting values.

    This ensures data integrity by catching cases where workers have different values
    for what should be the same configuration parameter.
    """
    obs1 = torch.tensor([1, 2])
    obs2 = torch.tensor([3, 4])

    # Same key "config" but different values
    data1 = DataProto.from_dict(tensors={"obs": obs1}, meta_info={"config": "A"})
    data2 = DataProto.from_dict(tensors={"obs": obs2}, meta_info={"config": "B"})

    # Should raise an assertion error due to conflicting values
    with pytest.raises(AssertionError, match="Conflicting values for meta_info key 'config'"):
        DataProto.concat([data1, data2])


def test_pop():
    obs = torch.randn(100, 10)
    act = torch.randn(100, 3)
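
The metrics shape these tests expect, a list of per-worker dicts flattened into a dict of lists, can be illustrated on its own. The following is a minimal sketch using only the standard library; flatten_metrics and worker_metrics are illustrative names introduced here, not part of verl, and the actual implementation delegates to the list_of_dict_to_dict_of_list helper visible in the protocol.py diff below.

from collections import defaultdict


def flatten_metrics(worker_metrics: list[dict]) -> dict:
    """Collect one list entry per worker that reported a given metric key."""
    flat = defaultdict(list)
    for metrics in worker_metrics:
        for key, value in metrics.items():
            flat[key].append(value)
    return dict(flat)


assert flatten_metrics([{"loss": 0.5, "accuracy": 0.9}, {"loss": 0.6, "accuracy": 0.85}]) == {
    "loss": [0.5, 0.6],
    "accuracy": [0.9, 0.85],
}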
28 changes: 26 additions & 2 deletions verl/protocol.py
@@ -882,7 +882,7 @@ def split(self, split_size: int) -> list["DataProto"]:
    @staticmethod
    def concat(data: list["DataProto"]) -> "DataProto":
        """Concat a list of DataProto. The batch is concatenated among dim=0.
        The meta_info is assumed to be identical and will use the first one.
        The meta_info is merged, with special handling for metrics from different workers.

        Args:
            data (List[DataProto]): list of DataProto
@@ -899,8 +899,32 @@ def concat(data: list["DataProto"]) -> "DataProto":
        for key, val in non_tensor_batch.items():
            non_tensor_batch[key] = np.concatenate(val, axis=0)

        # Merge meta_info with special handling for metrics
        merged_meta_info = {}
        if data:
            # Merge non-metric meta_info and aggregate metrics from all workers.
            all_metrics = []
            for d in data:
                for k, v in d.meta_info.items():
                    if k == "metrics":
                        if v is not None:
                            if isinstance(v, list):
                                all_metrics.extend(v)
                            else:
                                all_metrics.append(v)
                    else:
                        if k in merged_meta_info:
                            # Ensure consistency for overlapping non-metric keys
                            assert merged_meta_info[k] == v, f"Conflicting values for meta_info key '{k}'"
                        else:
                            merged_meta_info[k] = v

            # Flatten list of dicts to dict of lists for a consistent metrics structure
            if all_metrics:
                merged_meta_info["metrics"] = list_of_dict_to_dict_of_list(all_metrics)

        cls = type(data[0]) if len(data) > 0 else DataProto
        return cls(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info)
        return cls(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=merged_meta_info)

    def reorder(self, indices):
        """
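
With this change, a caller receives the merged meta_info rather than only the first part's copy, so metrics from every worker survive the gather. A minimal usage sketch, assuming DataProto is importable from verl.protocol as in the file shown above; the key names below are illustrative only:

import torch

from verl.protocol import DataProto

part_a = DataProto.from_dict(
    tensors={"obs": torch.tensor([1, 2])},
    meta_info={"metrics": [{"loss": 0.5}], "global_step": 10},
)
part_b = DataProto.from_dict(
    tensors={"obs": torch.tensor([3, 4])},
    meta_info={"metrics": [{"loss": 0.6}], "global_step": 10},
)

merged = DataProto.concat([part_a, part_b])
# merged.meta_info == {"metrics": {"loss": [0.5, 0.6]}, "global_step": 10}
# If the two parts disagreed on "global_step", the new assertion would raise
# instead of silently keeping the first worker's value.

Non-metric keys are still expected to agree across parts; the difference is that a mismatch now fails loudly instead of being dropped.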