Commit e5c11eb
Refine reshape e2e tests (#74717)
1 parent a4db267 commit e5c11eb

5 files changed: +179, -199 lines

test/auto_parallel/CMakeLists.txt

Lines changed: 0 additions & 1 deletion

@@ -10,7 +10,6 @@ add_subdirectory(end_to_end)
 if(WITH_DISTRIBUTE AND WITH_GPU)

   # NOTE(zyl): unittests WITH multi cards and timeout
-  py_test_modules(test_co_shard MODULES test_co_shard)
   py_test_modules(test_converter MODULES test_converter)
   set_tests_properties(test_converter PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
                                                  TIMEOUT 50)

File renamed without changes.

test/auto_parallel/end_to_end/reshape_co_shard.py

Lines changed: 175 additions & 168 deletions
@@ -11,186 +11,193 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any

 import numpy as np

 import paddle
 import paddle.distributed as dist

+if TYPE_CHECKING:
+    from collections.abc import Callable

-class TestReshapeCoShard:
-    def run_test_flatten(self):
-        a = paddle.rand([2, 12, 8], "float32")
-        mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=['x', 'y'])
-
-        placements = [
-            dist.Shard(0),
-            dist.Shard(1),
-        ]
-        idx = dist.get_rank()
-        input = dist.shard_tensor(a, mesh, placements)
-        out = paddle.reshape(input, [-1])
-        np.testing.assert_equal(out.shape, [192])
-        np.testing.assert_equal(
-            str(out.placements[0]), 'Shard(dim=0, shard_order=0)'
-        )
-        np.testing.assert_equal(str(out.placements[1]), 'Replicate()')
-        new_slice = (idx // 2,)
-        np.testing.assert_equal(
-            out._local_value().numpy().flatten(), a[new_slice].numpy().flatten()
-        )
-
-        a = paddle.rand([4, 6, 8], "float32")
-        placements = [
-            dist.Shard(0, shard_order=0),
-            dist.Shard(1, shard_order=1),
-        ]
-        input = dist.shard_tensor(a, mesh, placements)
-        out = paddle.reshape(input, [-1])
-        np.testing.assert_equal(out.shape, [192])
-        np.testing.assert_equal(
-            str(out.placements[0]), 'Shard(dim=0, shard_order=0)'
-        )
-        np.testing.assert_equal(
-            str(out.placements[1]), 'Shard(dim=0, shard_order=1)'
-        )
-        new_slice = (idx,)
-        np.testing.assert_equal(
-            out._local_value().numpy().flatten(), a[new_slice].numpy().flatten()
-        )
-
-        placements = [
-            dist.Shard(1),
-            dist.Shard(2),
-        ]
-        input = dist.shard_tensor(a, mesh, placements)
-        out = paddle.reshape(input, [-1])
-        np.testing.assert_equal(out.shape, [192])
-        np.testing.assert_equal(str(out.placements[0]), 'Replicate()')
-        np.testing.assert_equal(str(out.placements[1]), 'Replicate()')
-        new_idx = slice(None)
-        np.testing.assert_equal(
-            out._local_value().numpy().flatten(), a[new_idx].numpy().flatten()
-        )
-
-    def run_test_split(self):
-        a = paddle.rand([192], dtype='float32')
-        mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=['x', 'y'])
-        placements = [
-            dist.Shard(0, shard_order=0),
-            dist.Shard(0, shard_order=1),
-        ]
-        idx = dist.get_rank()
-        input = dist.shard_tensor(a, mesh, placements)
-
-        out = paddle.reshape(input, [4, 6, -1])
-        np.testing.assert_equal(out.shape, [4, 6, 8])
-        np.testing.assert_equal(
-            str(out.placements[0]), 'Shard(dim=0, shard_order=0)'
-        )
-        np.testing.assert_equal(
-            str(out.placements[1]), 'Shard(dim=0, shard_order=1)'
-        )
-        new_slice = (idx,)
-        np.testing.assert_equal(
-            out._local_value().numpy().flatten(), a[new_slice].numpy().flatten()
-        )
-
-        input = dist.shard_tensor(a, mesh, placements)
-        out = paddle.reshape(input, [6, -1, 8])
-        np.testing.assert_equal(out.shape, [6, 4, 8])
-        np.testing.assert_equal(str(out.placements[0]), 'Replicate()')
-        np.testing.assert_equal(str(out.placements[1]), 'Replicate()')
-        new_slice = (slice(None),)
-        np.testing.assert_equal(
-            out._local_value().numpy().flatten(), a[new_slice].numpy().flatten()
-        )
-
-    def run_test_combination(self):
-        a = paddle.rand([4, 6, 8], "float32")
-        mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=['x', 'y'])
-        placements = [
-            dist.Shard(0),
-            dist.Shard(1),
-        ]
-        idx = dist.get_rank()
-        input = dist.shard_tensor(a, mesh, placements)
-        out = paddle.reshape(input, [2, 12, 8])
-        np.testing.assert_equal(out.shape, [2, 12, 8])
-        np.testing.assert_equal(
-            str(out.placements[0]), 'Shard(dim=0, shard_order=0)'
-        )
-        np.testing.assert_equal(str(out.placements[1]), 'Replicate()')
-        new_slice = (idx // 2,)
-        np.testing.assert_equal(
-            out._local_value().numpy().flatten(), a[new_slice].numpy().flatten()
-        )

-        placements = [
-            dist.Shard(0, shard_order=0),
-            dist.Shard(1, shard_order=1),
-        ]
-        input = dist.shard_tensor(a, mesh, placements)
-        out = paddle.reshape(input, [2, 12, 8])
-        np.testing.assert_equal(out.shape, [2, 12, 8])
-        np.testing.assert_equal(str(out.placements[0]), 'Replicate()')
-        np.testing.assert_equal(str(out.placements[1]), 'Replicate()')
-        new_slice = (slice(None),)
-        np.testing.assert_equal(
-            out._local_value().numpy().flatten(), a[new_slice].numpy().flatten()
-        )
+class ReshapeTestCase:
+    def __init__(
+        self,
+        input_shape: list[int],
+        input_placements: list[dist.Placement],
+        target_shape: list[int],
+        output_placements: list[dist.Placement],
+        slice_functor: Callable[[int], Any] | None = None,
+    ):
+        self.input_shape = input_shape
+        self.input_placements = input_placements
+        self.target_shape = target_shape
+        self.output_placements = output_placements
+        self.slice_functor = slice_functor

-        input = dist.shard_tensor(a, mesh, placements)
-        out = paddle.reshape(input, [12, 2, 8])
-        np.testing.assert_equal(out.shape, [12, 2, 8])
-        np.testing.assert_equal(
-            str(out.placements[0]), 'Shard(dim=0, shard_order=0)'
-        )
-        np.testing.assert_equal(
-            str(out.placements[1]), 'Shard(dim=0, shard_order=1)'
-        )
-        new_slice = slice(idx % 4 * 3, idx % 4 * 3 + 3)
-        np.testing.assert_equal(
-            out._local_value().numpy().flatten(), a[new_slice].numpy().flatten()
-        )
-
-        placements = [
-            dist.Shard(1),
-            dist.Shard(2),
-        ]
-        input = dist.shard_tensor(a, mesh, placements)
-        out = paddle.reshape(input, [8, 6, 4])
-        np.testing.assert_equal(out.shape, [8, 6, 4])
-        np.testing.assert_equal(str(out.placements[0]), 'Replicate()')
-        np.testing.assert_equal(str(out.placements[1]), 'Replicate()')
-        new_slice = (slice(None),)
-        np.testing.assert_equal(
-            out._local_value().numpy().flatten(), a[new_slice].numpy().flatten()
-        )

-        placements = [
-            dist.Shard(2, shard_order=0),
-            dist.Shard(2, shard_order=1),
+class TestReshapeCoShard:
+    def setUp(self):
+        self.mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=['x', 'y'])
+        self.test_cases = [
+            # test flatten
+            ReshapeTestCase(
+                [4, 6, 8],
+                [dist.Shard(0), dist.Shard(1)],
+                [192],
+                [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
+                lambda idx: (idx,),
+            ),
+            ReshapeTestCase(
+                [4, 6, 8],
+                [dist.Shard(1), dist.Shard(2)],
+                [192],
+                [dist.Replicate(), dist.Replicate()],
+                lambda idx: slice(None),
+            ),
+            ReshapeTestCase(
+                [4, 6, 8],
+                [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
+                [192],
+                [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
+                lambda idx: (idx,),
+            ),
+            ReshapeTestCase(
+                [2, 12, 8],
+                [dist.Shard(0), dist.Shard(1)],
+                [192],
+                [dist.Shard(0), dist.Replicate()],
+                lambda idx: (idx // 2,),
+            ),
+            # test split
+            ReshapeTestCase(
+                [192],
+                [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
+                [4, 6, 8],
+                [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
+                lambda idx: (idx,),
+            ),
+            ReshapeTestCase(
+                [192],
+                [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
+                [6, 4, 8],
+                [dist.Replicate(), dist.Replicate()],
+                lambda idx: slice(None),
+            ),
+            # test combination
+            ReshapeTestCase(
+                [4, 6, 8],
+                [dist.Shard(0), dist.Shard(1)],
+                [2, 12, 8],
+                [dist.Shard(0), dist.Replicate()],
+                lambda idx: (idx // 2,),
+            ),
+            ReshapeTestCase(
+                [4, 6, 8],
+                [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
+                [2, 12, 8],
+                [dist.Replicate(), dist.Replicate()],
+                lambda idx: slice(None),
+            ),
+            ReshapeTestCase(
+                [4, 6, 8],
+                [dist.Shard(0), dist.Shard(1)],
+                [12, 2, 8],
+                [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
+                lambda idx: slice(idx % 4 * 3, idx % 4 * 3 + 3),
+            ),
+            ReshapeTestCase(
+                [4, 6, 8],
+                [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
+                [12, 2, 8],
+                [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
+                lambda idx: slice(idx % 4 * 3, idx % 4 * 3 + 3),
+            ),
+            ReshapeTestCase(
+                [4, 6, 8],
+                [dist.Shard(0), dist.Shard(1)],
+                [8, 6, 4],
+                [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
+                lambda idx: slice(idx % 4 * 2, idx % 4 * 2 + 2),
+            ),
+            ReshapeTestCase(
+                [4, 6, 8],
+                [dist.Shard(1), dist.Shard(2)],
+                [8, 6, 4],
+                [dist.Replicate(), dist.Replicate()],
+                lambda idx: slice(None),
+            ),
+            ReshapeTestCase(
+                [4, 6, 8],
+                [dist.Shard(0), dist.Shard(2)],
+                [8, 6, 4],
+                [dist.Shard(0), dist.Replicate()],
+                lambda idx: (idx // 2, idx // 2 + 4),
+            ),
+            ReshapeTestCase(
+                [4, 6, 8],
+                [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
+                [8, 6, 4],
+                [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
+                lambda idx: slice(idx % 4 * 2, idx % 4 * 2 + 2),
+            ),
+            ReshapeTestCase(
+                [4, 6, 8],
+                [dist.Shard(2, shard_order=0), dist.Shard(2, shard_order=1)],
+                [24, 2, 4],
+                [dist.Replicate(), dist.Replicate()],
+                lambda idx: slice(None),
+            ),
+            ReshapeTestCase(
+                [4, 6, 8],
+                [dist.Shard(2, shard_order=0), dist.Shard(1, shard_order=1)],
+                [24, 4, 2],
+                [dist.Shard(2, shard_order=0), dist.Shard(1, shard_order=1)],
+                lambda idx: (slice(None), idx % 4, slice(None)),
+            ),
         ]
-        input = dist.shard_tensor(a, mesh, placements)
-        out = paddle.reshape(input, [24, 4, 2])
-        np.testing.assert_equal(out.shape, [24, 4, 2])
-        np.testing.assert_equal(
-            str(out.placements[0]), 'Shard(dim=1, shard_order=0)'
-        )
-        np.testing.assert_equal(
-            str(out.placements[1]), 'Shard(dim=1, shard_order=1)'
-        )
-        new_slice = (slice(None), dist.get_rank() % 4, slice(None))
-        np.testing.assert_equal(
-            out._local_value().numpy().flatten(), a[new_slice].numpy().flatten()
-        )

-    def run_test_case_main(self):
-        self.run_test_flatten()
-        self.run_test_split()
-        self.run_test_combination()
+    def run_test_case(self, test_case: ReshapeTestCase):
+        a = paddle.rand(test_case.input_shape, "float32")
+        input_placements = test_case.input_placements
+        input = dist.shard_tensor(a, self.mesh, input_placements)
+        out = paddle.reshape(input, test_case.target_shape)
+        case_info = f"input_shape: {test_case.input_shape}, input_placements: {input_placements}, target_shape: {test_case.target_shape}"
+        # Verify output shape
+        np.testing.assert_equal(
+            out.shape,
+            test_case.target_shape,
+            err_msg=f"Output shape mismatch when {case_info}. Expected: {test_case.target_shape}, Actual: {out.shape}",
+        )
+
+        # Verify placements
+        assert out.placements
+        for actual, expected in zip(
+            out.placements, test_case.output_placements
+        ):
+            np.testing.assert_equal(
+                actual,
+                expected,
+                err_msg=f"Output placements mismatch when {case_info}. Expected: {test_case.output_placements}, Actual: {out.placements}",
+            )
+        # Verify local_value if given
+        if test_case.slice_functor:
+            idx = dist.get_rank()
+            np.testing.assert_equal(
+                out._local_value().numpy().flatten(),
+                a[test_case.slice_functor(idx)].numpy().flatten(),
+                err_msg=f"Local values mismatch when {case_info}.",
+            )
+
+    def run_all_tests(self):
+        self.setUp()
+        for test_case in self.test_cases:
+            self.run_test_case(test_case)


 if __name__ == '__main__':
-    TestReshapeCoShard().run_test_case_main()
+    TestReshapeCoShard().run_all_tests()
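
A side note on the new case table (an illustrative sketch, not part of the commit): co-sharding dim 0 of a [4, 6, 8] tensor over the 2x2 mesh with shard_order 0/1 splits the four rows first across mesh axis 'x' and then across 'y', so rank idx ends up holding exactly a[idx], which is what the lambda idx: (idx,) slice functors assert. In plain NumPy, with no distributed runtime:

    import numpy as np

    # Dim 0 (size 4) is split across mesh axis 'x' into two 2-row halves,
    # then each half across axis 'y' into single rows: ranks 0..3 hold rows 0..3.
    a = np.arange(4 * 6 * 8, dtype=np.float32).reshape(4, 6, 8)
    for rank in range(4):
        local = a[(rank,)]  # the (idx,) slice used by the test's functor
        assert local.shape == (6, 8)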

test/auto_parallel/end_to_end/test_e2e_co_shard.py

Lines changed: 4 additions & 1 deletion

@@ -21,7 +21,10 @@ class TestReshardE2E(test_base.CommunicationTestDistBase):
     def setUp(self):
         super().setUp(num_of_devices=4, timeout=120)

-    def test_reshard_co_shard(self):
+    def test_co_shard(self):
+        self.run_test_case("co_shard.py")
+
+    def test_reshape_co_shard(self):
         self.run_test_case("reshape_co_shard.py")