Skip to content

Commit 887b16d

Browse files
committed
[misc] fix: Handle N-D arrays and complex objects in union_numpy_dict (volcengine#2768)
### What does this PR do? This PR fixes a bug in `verl.protocol.union_numpy_dict` where it would crash on NumPy arrays with more than 2 dimensions. It replaces the underlying comparison logic with a robust, recursive function that can handle N-D arrays, nested objects, `NaN` values, and circular references. This resolves issue volcengine#2766. ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test A comprehensive unit test suite has been added to `tests/test_protocol_on_cpu.py`. The new tests cover the following scenarios, all of which now pass: * Merging dictionaries with identical 3D (and higher) dimensional arrays. * Correctly failing when N-D arrays with the same shape but different values are merged. * Handling nested `object`-dtype arrays containing other arrays, strings, and `None`. * Correctly treating `NaN` values at the same position as equal, mimicking pandas' behavior. * Safely handling circular references without causing a `RecursionError`. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
1 parent 2e1a1a6 commit 887b16d

File tree

2 files changed

+157
-16
lines changed

2 files changed

+157
-16
lines changed

tests/test_protocol_on_cpu.py

Lines changed: 90 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,21 +33,99 @@ def test_union_tensor_dict():
3333
{"obs": obs.clone(), "next_obs": torch.randn(100, 10), "rew": torch.randn(100)}, batch_size=[100]
3434
)
3535

36-
data = union_tensor_dict(data1, data2)
36+
union_tensor_dict(data1, data2)
3737
with pytest.raises(AssertionError):
38-
data = union_tensor_dict(data1, data_with_copied_obs)
39-
38+
union_tensor_dict(data1, data_with_copied_obs)
39+
40+
41+
def test_union_numpy_dict():
42+
"""
43+
A comprehensive test suite for union_numpy_dict, covering standard use
44+
cases, N-dimensional arrays, object-dtype arrays, and NaN value handling.
45+
"""
46+
arr_3d = np.arange(8).reshape((2, 2, 2))
47+
union_numpy_dict({"a": arr_3d}, {"a": arr_3d})
48+
arr1 = np.array([1, "hello", np.array([2, 3])], dtype=object)
49+
arr2 = np.array([1, "hello", np.array([2, 3])], dtype=object)
50+
union_numpy_dict({"a": arr1}, {"a": arr2})
51+
# --- Test Case 1: The original test with mixed object/float types ---
52+
# This test case from the original test file is preserved.
4053
data = np.random.random(100)
41-
data2 = [float("nan") for _ in range(99)]
42-
data2.append("nan")
43-
data2 = np.array(data2, dtype=object)
44-
data3 = np.tile(data2, (2, 1))
45-
a = {"a": data, "b": data2, "c": data3}
46-
b = {"a": data, "b": data2, "c": data3}
47-
b_ = {"a": np.random.random(100)}
48-
union_numpy_dict(a, b)
54+
# This array intentionally mixes float('nan') and the string 'nan'
55+
nan_data = [float("nan") for _ in range(99)]
56+
nan_data.append("nan")
57+
nan_data_arr = np.array(nan_data, dtype=object)
58+
59+
dict1 = {"a": data, "b": nan_data_arr}
60+
dict2_same = {"a": data.copy(), "b": nan_data_arr.copy()}
61+
dict3_different = {"a": np.random.random(100)}
62+
63+
union_numpy_dict(dict1, dict2_same) # Should pass
64+
with pytest.raises(AssertionError):
65+
union_numpy_dict(dict1, dict3_different)
66+
67+
# --- Test Case 2: Standard 3D arrays (fixes the core bug) ---
68+
arr_3d = np.arange(24, dtype=np.int32).reshape((2, 3, 4))
69+
dict_3d_1 = {"nd_array": arr_3d}
70+
dict_3d_2_same = {"nd_array": arr_3d.copy()}
71+
dict_3d_3_different = {"nd_array": arr_3d + 1}
72+
73+
union_numpy_dict(dict_3d_1, dict_3d_2_same) # Should pass
74+
with pytest.raises(AssertionError, match="`nd_array` in tensor_dict1 and tensor_dict2 are not the same object."):
75+
union_numpy_dict(dict_3d_1, dict_3d_3_different)
76+
77+
# --- Test Case 3: Nested 2D and 4D object-dtype arrays ---
78+
sub_arr1 = np.array([1, 2])
79+
sub_arr2 = np.array([3.0, 4.0])
80+
# 2D object array
81+
arr_2d_obj = np.array([[sub_arr1, "text"], [sub_arr2, None]], dtype=object)
82+
arr_2d_obj_diff = np.array([[sub_arr1, "text"], [sub_arr2, "other"]], dtype=object)
83+
84+
union_numpy_dict({"data": arr_2d_obj}, {"data": arr_2d_obj.copy()}) # Should pass
85+
with pytest.raises(AssertionError):
86+
union_numpy_dict({"data": arr_2d_obj}, {"data": arr_2d_obj_diff})
87+
88+
# 4D object array to ensure deep recursion is robust
89+
arr_4d_obj = np.array([[[[sub_arr1]]], [[[sub_arr2]]]], dtype=object)
90+
arr_4d_obj_diff = np.array([[[[sub_arr1]]], [[[np.array([9, 9])]]]], dtype=object)
91+
92+
union_numpy_dict({"data": arr_4d_obj}, {"data": arr_4d_obj.copy()}) # Should pass
93+
with pytest.raises(AssertionError):
94+
union_numpy_dict({"data": arr_4d_obj}, {"data": arr_4d_obj_diff})
95+
96+
# --- Test Case 4: Explicit NaN value comparison ---
97+
# This verifies that our new _deep_equal logic correctly handles NaNs.
98+
nan_arr = np.array([1.0, np.nan, 3.0])
99+
dict_nan_1 = {"data": nan_arr}
100+
dict_nan_2_same = {"data": np.array([1.0, np.nan, 3.0])} # A new array with same values
101+
dict_nan_3_different_val = {"data": np.array([1.0, 2.0, 3.0])}
102+
dict_nan_4_different_pos = {"data": np.array([np.nan, 1.0, 3.0])}
103+
104+
# NaNs in the same position should be considered equal for merging.
105+
union_numpy_dict(dict_nan_1, dict_nan_2_same) # Should pass
106+
107+
with pytest.raises(AssertionError):
108+
union_numpy_dict(dict_nan_1, dict_nan_3_different_val)
109+
with pytest.raises(AssertionError):
110+
union_numpy_dict(dict_nan_1, dict_nan_4_different_pos)
111+
112+
# --- Test Case 5: Circular reference handling ---
113+
# Create two separate, but structurally identical, circular references.
114+
# This should pass without a RecursionError.
115+
circ_arr_1 = np.array([None], dtype=object)
116+
circ_arr_1[0] = circ_arr_1
117+
118+
circ_arr_2 = np.array([None], dtype=object)
119+
circ_arr_2[0] = circ_arr_2
120+
121+
union_numpy_dict({"data": circ_arr_1}, {"data": circ_arr_2}) # Should pass
122+
123+
# Create a circular reference and a non-circular one.
124+
# This should fail with an AssertionError because they are different.
125+
non_circ_arr = np.array([None], dtype=object)
126+
49127
with pytest.raises(AssertionError):
50-
union_numpy_dict(a, b_)
128+
union_numpy_dict({"data": circ_arr_1}, {"data": non_circ_arr})
51129

52130

53131
def test_tensor_dict_constructor():

verl/protocol.py

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@
1919
import contextlib
2020
import copy
2121
import logging
22+
import math
2223
import os
2324
import pickle
2425
from dataclasses import dataclass, field
25-
from typing import Callable, Optional
26+
from typing import Any, Callable, Optional
2627

2728
import numpy as np
28-
import pandas as pd
2929
import ray
3030
import tensordict
3131
import torch
@@ -118,14 +118,77 @@ def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> Ten
118118
return tensor_dict1
119119

120120

121+
def _array_equal(array1: np.ndarray, array2: np.ndarray, visited: set[int]) -> bool:
122+
"""
123+
Recursively compares two NumPy arrays for strict equality, with special
124+
handling for object-dtype arrays, NaN values, and circular references.
125+
This function assumes that the two arguments provided are NumPy arrays.
126+
127+
Args:
128+
array1: The first NumPy array.
129+
array2: The second NumPy array.
130+
131+
Returns:
132+
True if the arrays' dtypes, shapes, and all elements are equal.
133+
"""
134+
# Check dtype and shape first, as this is the fastest failure path.
135+
if array1.dtype != array2.dtype or array1.shape != array2.shape:
136+
return False
137+
138+
# For non-object dtypes, use NumPy's implementation with equal_nan=True.
139+
if array1.dtype != "object":
140+
return np.array_equal(array1, array2, equal_nan=True)
141+
142+
# For object-dtype arrays, we must recursively compare each element.
143+
# We delegate to _deep_equal to handle elements, as they could be any
144+
# type, including other nested arrays or NaNs.
145+
return all(_deep_equal(x, y, visited) for x, y in zip(array1.flat, array2.flat, strict=False))
146+
147+
148+
def _deep_equal(a: Any, b: Any, visited: set[int]) -> bool:
149+
"""
150+
Recursively performs a deep comparison between two Python objects.
151+
- Handles NaN values correctly (NaN == NaN evaluates to True).
152+
- Handling circular references.
153+
- Dispatches to _array_equal if both objects are NumPy arrays.
154+
- Otherwise, uses standard '==' comparison.
155+
"""
156+
if type(a) is not type(b):
157+
return False
158+
159+
# If we have seen this object ID before on this path, it's a cycle.
160+
# Since we already know the types match, we can safely assume this part
161+
# of the structure is equal.
162+
obj_id = id(a)
163+
if obj_id in visited:
164+
return True
165+
166+
visited.add(obj_id)
167+
168+
# Perform the specific comparison based on type
169+
result = False
170+
if isinstance(a, float) and math.isnan(a) and math.isnan(b):
171+
result = True
172+
elif isinstance(a, np.ndarray):
173+
# We know b is also an ndarray due to the initial type check
174+
result = _array_equal(a, b, visited)
175+
else:
176+
# Standard equality for all other types
177+
result = a == b
178+
179+
# Clean up the visited set on the way out of the recursion
180+
visited.remove(obj_id)
181+
return result
182+
183+
121184
def union_numpy_dict(tensor_dict1: dict[str, np.ndarray], tensor_dict2: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
122185
for key, val in tensor_dict2.items():
123186
if key in tensor_dict1:
124187
assert isinstance(tensor_dict2[key], np.ndarray)
125188
assert isinstance(tensor_dict1[key], np.ndarray)
126189
# to properly deal with nan and object type
127-
assert pd.DataFrame(tensor_dict2[key]).equals(pd.DataFrame(tensor_dict1[key])), (
128-
f"{key} in tensor_dict1 and tensor_dict2 are not the same object"
190+
assert _deep_equal(tensor_dict1[key], tensor_dict2[key], visited=set()), (
191+
f"`{key}` in tensor_dict1 and tensor_dict2 are not the same object."
129192
)
130193
tensor_dict1[key] = val
131194

0 commit comments

Comments
 (0)