diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt index f5e4bbaceef2d8..ed8712609ef730 100644 --- a/test/auto_parallel/CMakeLists.txt +++ b/test/auto_parallel/CMakeLists.txt @@ -10,7 +10,6 @@ add_subdirectory(end_to_end) if(WITH_DISTRIBUTE AND WITH_GPU) # NOTE(zyl): unittests WITH multi cards and timeout - py_test_modules(test_co_shard MODULES test_co_shard) py_test_modules(test_converter MODULES test_converter) set_tests_properties(test_converter PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) diff --git a/test/auto_parallel/co_shard.py b/test/auto_parallel/end_to_end/co_shard.py similarity index 100% rename from test/auto_parallel/co_shard.py rename to test/auto_parallel/end_to_end/co_shard.py diff --git a/test/auto_parallel/end_to_end/reshape_co_shard.py b/test/auto_parallel/end_to_end/reshape_co_shard.py index 69e91b5f6db1b5..0e04f0ed0d6531 100644 --- a/test/auto_parallel/end_to_end/reshape_co_shard.py +++ b/test/auto_parallel/end_to_end/reshape_co_shard.py @@ -11,186 +11,193 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any import numpy as np import paddle import paddle.distributed as dist +if TYPE_CHECKING: + from collections.abc import Callable -class TestReshapeCoShard: - def run_test_flatten(self): - a = paddle.rand([2, 12, 8], "float32") - mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=['x', 'y']) - - placements = [ - dist.Shard(0), - dist.Shard(1), - ] - idx = dist.get_rank() - input = dist.shard_tensor(a, mesh, placements) - out = paddle.reshape(input, [-1]) - np.testing.assert_equal(out.shape, [192]) - np.testing.assert_equal( - str(out.placements[0]), 'Shard(dim=0, shard_order=0)' - ) - np.testing.assert_equal(str(out.placements[1]), 'Replicate()') - new_slice = (idx // 2,) - np.testing.assert_equal( - out._local_value().numpy().flatten(), a[new_slice].numpy().flatten() - ) - - a = paddle.rand([4, 6, 8], "float32") - placements = [ - dist.Shard(0, shard_order=0), - dist.Shard(1, shard_order=1), - ] - input = dist.shard_tensor(a, mesh, placements) - out = paddle.reshape(input, [-1]) - np.testing.assert_equal(out.shape, [192]) - np.testing.assert_equal( - str(out.placements[0]), 'Shard(dim=0, shard_order=0)' - ) - np.testing.assert_equal( - str(out.placements[1]), 'Shard(dim=0, shard_order=1)' - ) - new_slice = (idx,) - np.testing.assert_equal( - out._local_value().numpy().flatten(), a[new_slice].numpy().flatten() - ) - - placements = [ - dist.Shard(1), - dist.Shard(2), - ] - input = dist.shard_tensor(a, mesh, placements) - out = paddle.reshape(input, [-1]) - np.testing.assert_equal(out.shape, [192]) - np.testing.assert_equal(str(out.placements[0]), 'Replicate()') - np.testing.assert_equal(str(out.placements[1]), 'Replicate()') - new_idx = slice(None) - np.testing.assert_equal( - out._local_value().numpy().flatten(), a[new_idx].numpy().flatten() - ) - - def run_test_split(self): - a = paddle.rand([192], dtype='float32') - mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=['x', 'y']) - placements = [ - dist.Shard(0, shard_order=0), - dist.Shard(0, shard_order=1), - ] - idx = dist.get_rank() - input = dist.shard_tensor(a, mesh, placements) - - out = paddle.reshape(input, [4, 6, -1]) - np.testing.assert_equal(out.shape, [4, 6, 8]) - np.testing.assert_equal( - str(out.placements[0]), 'Shard(dim=0, shard_order=0)' - ) - 
np.testing.assert_equal( - str(out.placements[1]), 'Shard(dim=0, shard_order=1)' - ) - new_slice = (idx,) - np.testing.assert_equal( - out._local_value().numpy().flatten(), a[new_slice].numpy().flatten() - ) - - input = dist.shard_tensor(a, mesh, placements) - out = paddle.reshape(input, [6, -1, 8]) - np.testing.assert_equal(out.shape, [6, 4, 8]) - np.testing.assert_equal(str(out.placements[0]), 'Replicate()') - np.testing.assert_equal(str(out.placements[1]), 'Replicate()') - new_slice = (slice(None),) - np.testing.assert_equal( - out._local_value().numpy().flatten(), a[new_slice].numpy().flatten() - ) - - def run_test_combination(self): - a = paddle.rand([4, 6, 8], "float32") - mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=['x', 'y']) - placements = [ - dist.Shard(0), - dist.Shard(1), - ] - idx = dist.get_rank() - input = dist.shard_tensor(a, mesh, placements) - out = paddle.reshape(input, [2, 12, 8]) - np.testing.assert_equal(out.shape, [2, 12, 8]) - np.testing.assert_equal( - str(out.placements[0]), 'Shard(dim=0, shard_order=0)' - ) - np.testing.assert_equal(str(out.placements[1]), 'Replicate()') - new_slice = (idx // 2,) - np.testing.assert_equal( - out._local_value().numpy().flatten(), a[new_slice].numpy().flatten() - ) - placements = [ - dist.Shard(0, shard_order=0), - dist.Shard(1, shard_order=1), - ] - input = dist.shard_tensor(a, mesh, placements) - out = paddle.reshape(input, [2, 12, 8]) - np.testing.assert_equal(out.shape, [2, 12, 8]) - np.testing.assert_equal(str(out.placements[0]), 'Replicate()') - np.testing.assert_equal(str(out.placements[1]), 'Replicate()') - new_slice = (slice(None),) - np.testing.assert_equal( - out._local_value().numpy().flatten(), a[new_slice].numpy().flatten() - ) +class ReshapeTestCase: + def __init__( + self, + input_shape: list[int], + input_placements: list[dist.Placement], + target_shape: list[int], + output_placements: list[dist.Placement], + slice_funtor: Callable[[int], Any] | None = None, + ): + self.input_shape = input_shape + self.input_placements = input_placements + self.target_shape = target_shape + self.output_placements = output_placements + self.slice_funtor = slice_funtor - input = dist.shard_tensor(a, mesh, placements) - out = paddle.reshape(input, [12, 2, 8]) - np.testing.assert_equal(out.shape, [12, 2, 8]) - np.testing.assert_equal( - str(out.placements[0]), 'Shard(dim=0, shard_order=0)' - ) - np.testing.assert_equal( - str(out.placements[1]), 'Shard(dim=0, shard_order=1)' - ) - new_slice = slice(idx % 4 * 3, idx % 4 * 3 + 3) - np.testing.assert_equal( - out._local_value().numpy().flatten(), a[new_slice].numpy().flatten() - ) - - placements = [ - dist.Shard(1), - dist.Shard(2), - ] - input = dist.shard_tensor(a, mesh, placements) - out = paddle.reshape(input, [8, 6, 4]) - np.testing.assert_equal(out.shape, [8, 6, 4]) - np.testing.assert_equal(str(out.placements[0]), 'Replicate()') - np.testing.assert_equal(str(out.placements[1]), 'Replicate()') - new_slice = (slice(None),) - np.testing.assert_equal( - out._local_value().numpy().flatten(), a[new_slice].numpy().flatten() - ) - placements = [ - dist.Shard(2, shard_order=0), - dist.Shard(2, shard_order=1), +class TestReshapeCoShard: + def setUp(self): + self.mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=['x', 'y']) + self.test_cases = [ + # test flatten + ReshapeTestCase( + [4, 6, 8], + [dist.Shard(0), dist.Shard(1)], + [192], + [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)], + lambda idx: (idx,), + ), + ReshapeTestCase( + [4, 6, 8], + [dist.Shard(1), 
dist.Shard(2)], + [192], + [dist.Replicate(), dist.Replicate()], + lambda idx: slice(None), + ), + ReshapeTestCase( + [4, 6, 8], + [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)], + [192], + [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)], + lambda idx: (idx,), + ), + ReshapeTestCase( + [2, 12, 8], + [dist.Shard(0), dist.Shard(1)], + [192], + [dist.Shard(0), dist.Replicate()], + lambda idx: (idx // 2,), + ), + # test split + ReshapeTestCase( + [192], + [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)], + [4, 6, 8], + [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)], + lambda idx: (idx,), + ), + ReshapeTestCase( + [192], + [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)], + [6, 4, 8], + [dist.Replicate(), dist.Replicate()], + lambda idx: slice(None), + ), + # test combination + ReshapeTestCase( + [4, 6, 8], + [dist.Shard(0), dist.Shard(1)], + [2, 12, 8], + [dist.Shard(0), dist.Replicate()], + lambda idx: (idx // 2,), + ), + ReshapeTestCase( + [4, 6, 8], + [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)], + [2, 12, 8], + [dist.Replicate(), dist.Replicate()], + lambda idx: slice(None), + ), + ReshapeTestCase( + [4, 6, 8], + [dist.Shard(0), dist.Shard(1)], + [12, 2, 8], + [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)], + lambda idx: slice(idx % 4 * 3, idx % 4 * 3 + 3), + ), + ReshapeTestCase( + [4, 6, 8], + [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)], + [12, 2, 8], + [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)], + lambda idx: slice(idx % 4 * 3, idx % 4 * 3 + 3), + ), + ReshapeTestCase( + [4, 6, 8], + [dist.Shard(0), dist.Shard(1)], + [8, 6, 4], + [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)], + lambda idx: slice(idx % 4 * 2, idx % 4 * 2 + 2), + ), + ReshapeTestCase( + [4, 6, 8], + [dist.Shard(1), dist.Shard(2)], + [8, 6, 4], + [dist.Replicate(), dist.Replicate()], + lambda idx: slice(None), + ), + ReshapeTestCase( + [4, 6, 8], + [dist.Shard(0), dist.Shard(2)], + [8, 6, 4], + [dist.Shard(0), dist.Replicate()], + lambda idx: (idx // 2, idx // 2 + 4), + ), + ReshapeTestCase( + [4, 6, 8], + [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)], + [8, 6, 4], + [dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)], + lambda idx: slice(idx % 4 * 2, idx % 4 * 2 + 2), + ), + ReshapeTestCase( + [4, 6, 8], + [dist.Shard(2, shard_order=0), dist.Shard(2, shard_order=1)], + [24, 2, 4], + [dist.Replicate(), dist.Replicate()], + lambda idx: slice(None), + ), + ReshapeTestCase( + [4, 6, 8], + [dist.Shard(2, shard_order=0), dist.Shard(1, shard_order=1)], + [24, 4, 2], + [dist.Shard(2, shard_order=0), dist.Shard(1, shard_order=1)], + lambda idx: (slice(None), idx % 4, slice(None)), + ), ] - input = dist.shard_tensor(a, mesh, placements) - out = paddle.reshape(input, [24, 4, 2]) - np.testing.assert_equal(out.shape, [24, 4, 2]) - np.testing.assert_equal( - str(out.placements[0]), 'Shard(dim=1, shard_order=0)' - ) - np.testing.assert_equal( - str(out.placements[1]), 'Shard(dim=1, shard_order=1)' - ) - new_slice = (slice(None), dist.get_rank() % 4, slice(None)) - np.testing.assert_equal( - out._local_value().numpy().flatten(), a[new_slice].numpy().flatten() - ) - def run_test_case_main(self): - self.run_test_flatten() - self.run_test_split() - self.run_test_combination() + def run_test_case(self, test_case: ReshapeTestCase): + a = paddle.rand(test_case.input_shape, "float32") + input_placements = test_case.input_placements + input = dist.shard_tensor(a, self.mesh, 
input_placements) + out = paddle.reshape(input, test_case.target_shape) + case_info = f"input_shape: {test_case.input_shape}, input_placements: {input_placements}, target_shape: {test_case.target_shape}" + # Verify output shape + np.testing.assert_equal( + out.shape, + test_case.target_shape, + err_msg=f"Output shape mismatch when {case_info}. Expected: {test_case.target_shape}, Actual: {out.shape}", + ) + + # Verify placements + assert out.placements + for actual, expected in zip( + out.placements, test_case.output_placements + ): + np.testing.assert_equal( + actual, + expected, + err_msg=f"Output placements mismatch when {case_info}. Expected: {test_case.output_placements}, Actual: {out.placements}", + ) + # Verify local_value if given + if test_case.slice_funtor: + idx = dist.get_rank() + np.testing.assert_equal( + out._local_value().numpy().flatten(), + a[test_case.slice_funtor(idx)].numpy().flatten(), + err_msg=f"Local values mismatch when {case_info}.", + ) + + def run_all_tests(self): + self.setUp() + for test_case in self.test_cases: + self.run_test_case(test_case) if __name__ == '__main__': - TestReshapeCoShard().run_test_case_main() + TestReshapeCoShard().run_all_tests() diff --git a/test/auto_parallel/end_to_end/test_e2e_co_shard.py b/test/auto_parallel/end_to_end/test_e2e_co_shard.py index 605349da91e35d..a90e5194d15f70 100644 --- a/test/auto_parallel/end_to_end/test_e2e_co_shard.py +++ b/test/auto_parallel/end_to_end/test_e2e_co_shard.py @@ -21,7 +21,10 @@ class TestReshardE2E(test_base.CommunicationTestDistBase): def setUp(self): super().setUp(num_of_devices=4, timeout=120) - def test_reshard_co_shard(self): + def test_co_shard(self): + self.run_test_case("co_shard.py") + + def test_reshape_co_shard(self): self.run_test_case("reshape_co_shard.py") diff --git a/test/auto_parallel/test_co_shard.py b/test/auto_parallel/test_co_shard.py deleted file mode 100644 index c7bece78dcc2a7..00000000000000 --- a/test/auto_parallel/test_co_shard.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import collective.test_communication_api_base as test_base - - -class TestReshardRToS(test_base.CommunicationTestDistBase): - def setUp(self): - super().setUp(num_of_devices=4, timeout=120) - - def test_reshard_r_to_s(self): - self.run_test_case("co_shard.py") - - -if __name__ == "__main__": - unittest.main()
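
For reviewers unfamiliar with the new table-driven harness, here is a minimal, standalone sketch of what a single entry in `TestReshapeCoShard.setUp` exercises. It is not part of the diff; the script name and launch command are illustrative (it assumes a 4-card run, e.g. via `python -m paddle.distributed.launch`), and the shapes, placements, expected output placements, and per-rank slice are copied verbatim from the flatten case `[2, 12, 8] -> [192]` that already appears in this change.

```python
# reshape_co_shard_sketch.py (hypothetical file name)
# Run with 4 devices, e.g.:
#   python -m paddle.distributed.launch --devices 0,1,2,3 reshape_co_shard_sketch.py
import numpy as np

import paddle
import paddle.distributed as dist


def check_flatten_case():
    # 2x2 mesh, matching the harness: axis 'x' = {0, 1} vs {2, 3}, axis 'y' = even vs odd ranks.
    mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=['x', 'y'])
    a = paddle.rand([2, 12, 8], "float32")

    # Co-shard the input: tensor dim 0 over mesh axis 'x', tensor dim 1 over mesh axis 'y'.
    inp = dist.shard_tensor(a, mesh, [dist.Shard(0), dist.Shard(1)])
    out = paddle.reshape(inp, [-1])

    # Flattening merges the sharded dims; the expected result keeps dim 0
    # sharded over 'x' and becomes replicated over 'y'.
    np.testing.assert_equal(out.shape, [192])
    np.testing.assert_equal(str(out.placements[0]), 'Shard(dim=0, shard_order=0)')
    np.testing.assert_equal(str(out.placements[1]), 'Replicate()')

    # Each rank should hold the rows owned by its 'x' coordinate (rank // 2),
    # which is exactly what the slice functor `lambda idx: (idx // 2,)` encodes.
    idx = dist.get_rank()
    np.testing.assert_equal(
        out._local_value().numpy().flatten(),
        a[idx // 2].numpy().flatten(),
    )


if __name__ == '__main__':
    check_flatten_case()
```

In the refactored test, the same check is expressed declaratively: input shape and placements, target shape, expected output placements, and a rank-to-slice functor that maps each rank to the portion of the global tensor its local shard should equal.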