1 change: 0 additions & 1 deletion test/auto_parallel/CMakeLists.txt
@@ -10,7 +10,6 @@ add_subdirectory(end_to_end)
if(WITH_DISTRIBUTE AND WITH_GPU)

# NOTE(zyl): unittests WITH multi cards and timeout
py_test_modules(test_co_shard MODULES test_co_shard)
Reviewer comment (Contributor): LGTM for deleting test_co_shard

py_test_modules(test_converter MODULES test_converter)
set_tests_properties(test_converter PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50)
343 changes: 175 additions & 168 deletions test/auto_parallel/end_to_end/reshape_co_shard.py
@@ -11,186 +11,193 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np

import paddle
import paddle.distributed as dist

if TYPE_CHECKING:
from collections.abc import Callable

class TestReshapeCoShard:
def run_test_flatten(self):
a = paddle.rand([2, 12, 8], "float32")
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=['x', 'y'])

placements = [
dist.Shard(0),
dist.Shard(1),
]
idx = dist.get_rank()
input = dist.shard_tensor(a, mesh, placements)
out = paddle.reshape(input, [-1])
np.testing.assert_equal(out.shape, [192])
np.testing.assert_equal(
str(out.placements[0]), 'Shard(dim=0, shard_order=0)'
)
np.testing.assert_equal(str(out.placements[1]), 'Replicate()')
new_slice = (idx // 2,)
np.testing.assert_equal(
out._local_value().numpy().flatten(), a[new_slice].numpy().flatten()
)

a = paddle.rand([4, 6, 8], "float32")
placements = [
dist.Shard(0, shard_order=0),
dist.Shard(1, shard_order=1),
]
input = dist.shard_tensor(a, mesh, placements)
out = paddle.reshape(input, [-1])
np.testing.assert_equal(out.shape, [192])
np.testing.assert_equal(
str(out.placements[0]), 'Shard(dim=0, shard_order=0)'
)
np.testing.assert_equal(
str(out.placements[1]), 'Shard(dim=0, shard_order=1)'
)
new_slice = (idx,)
np.testing.assert_equal(
out._local_value().numpy().flatten(), a[new_slice].numpy().flatten()
)

placements = [
dist.Shard(1),
dist.Shard(2),
]
input = dist.shard_tensor(a, mesh, placements)
out = paddle.reshape(input, [-1])
np.testing.assert_equal(out.shape, [192])
np.testing.assert_equal(str(out.placements[0]), 'Replicate()')
np.testing.assert_equal(str(out.placements[1]), 'Replicate()')
new_idx = slice(None)
np.testing.assert_equal(
out._local_value().numpy().flatten(), a[new_idx].numpy().flatten()
)

def run_test_split(self):
a = paddle.rand([192], dtype='float32')
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=['x', 'y'])
placements = [
dist.Shard(0, shard_order=0),
dist.Shard(0, shard_order=1),
]
idx = dist.get_rank()
input = dist.shard_tensor(a, mesh, placements)

out = paddle.reshape(input, [4, 6, -1])
np.testing.assert_equal(out.shape, [4, 6, 8])
np.testing.assert_equal(
str(out.placements[0]), 'Shard(dim=0, shard_order=0)'
)
np.testing.assert_equal(
str(out.placements[1]), 'Shard(dim=0, shard_order=1)'
)
new_slice = (idx,)
np.testing.assert_equal(
out._local_value().numpy().flatten(), a[new_slice].numpy().flatten()
)

input = dist.shard_tensor(a, mesh, placements)
out = paddle.reshape(input, [6, -1, 8])
np.testing.assert_equal(out.shape, [6, 4, 8])
np.testing.assert_equal(str(out.placements[0]), 'Replicate()')
np.testing.assert_equal(str(out.placements[1]), 'Replicate()')
new_slice = (slice(None),)
np.testing.assert_equal(
out._local_value().numpy().flatten(), a[new_slice].numpy().flatten()
)

def run_test_combination(self):
a = paddle.rand([4, 6, 8], "float32")
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=['x', 'y'])
placements = [
dist.Shard(0),
dist.Shard(1),
]
idx = dist.get_rank()
input = dist.shard_tensor(a, mesh, placements)
out = paddle.reshape(input, [2, 12, 8])
np.testing.assert_equal(out.shape, [2, 12, 8])
np.testing.assert_equal(
str(out.placements[0]), 'Shard(dim=0, shard_order=0)'
)
np.testing.assert_equal(str(out.placements[1]), 'Replicate()')
new_slice = (idx // 2,)
np.testing.assert_equal(
out._local_value().numpy().flatten(), a[new_slice].numpy().flatten()
)

placements = [
dist.Shard(0, shard_order=0),
dist.Shard(1, shard_order=1),
]
input = dist.shard_tensor(a, mesh, placements)
out = paddle.reshape(input, [2, 12, 8])
np.testing.assert_equal(out.shape, [2, 12, 8])
np.testing.assert_equal(str(out.placements[0]), 'Replicate()')
np.testing.assert_equal(str(out.placements[1]), 'Replicate()')
new_slice = (slice(None),)
np.testing.assert_equal(
out._local_value().numpy().flatten(), a[new_slice].numpy().flatten()
)
class ReshapeTestCase:
def __init__(
self,
input_shape: list[int],
input_placements: list[dist.Placement],
target_shape: list[int],
output_placements: list[dist.Placement],
slice_functor: Callable[[int], Any] | None = None,
):
self.input_shape = input_shape
self.input_placements = input_placements
self.target_shape = target_shape
self.output_placements = output_placements
self.slice_functor = slice_functor

input = dist.shard_tensor(a, mesh, placements)
out = paddle.reshape(input, [12, 2, 8])
np.testing.assert_equal(out.shape, [12, 2, 8])
np.testing.assert_equal(
str(out.placements[0]), 'Shard(dim=0, shard_order=0)'
)
np.testing.assert_equal(
str(out.placements[1]), 'Shard(dim=0, shard_order=1)'
)
new_slice = slice(idx % 4 * 3, idx % 4 * 3 + 3)
np.testing.assert_equal(
out._local_value().numpy().flatten(), a[new_slice].numpy().flatten()
)

placements = [
dist.Shard(1),
dist.Shard(2),
]
input = dist.shard_tensor(a, mesh, placements)
out = paddle.reshape(input, [8, 6, 4])
np.testing.assert_equal(out.shape, [8, 6, 4])
np.testing.assert_equal(str(out.placements[0]), 'Replicate()')
np.testing.assert_equal(str(out.placements[1]), 'Replicate()')
new_slice = (slice(None),)
np.testing.assert_equal(
out._local_value().numpy().flatten(), a[new_slice].numpy().flatten()
)

placements = [
dist.Shard(2, shard_order=0),
dist.Shard(2, shard_order=1),
class TestReshapeCoShard:
def setUp(self):
self.mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=['x', 'y'])
self.test_cases = [
# test flatten
ReshapeTestCase(
[4, 6, 8],
[dist.Shard(0), dist.Shard(1)],
[192],
[dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
lambda idx: (idx,),
),
ReshapeTestCase(
[4, 6, 8],
[dist.Shard(1), dist.Shard(2)],
[192],
[dist.Replicate(), dist.Replicate()],
lambda idx: slice(None),
),
ReshapeTestCase(
[4, 6, 8],
[dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
[192],
[dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
lambda idx: (idx,),
),
ReshapeTestCase(
[2, 12, 8],
[dist.Shard(0), dist.Shard(1)],
[192],
[dist.Shard(0), dist.Replicate()],
lambda idx: (idx // 2,),
),
# test split
ReshapeTestCase(
[192],
[dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
[4, 6, 8],
[dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
lambda idx: (idx,),
),
ReshapeTestCase(
[192],
[dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
[6, 4, 8],
[dist.Replicate(), dist.Replicate()],
lambda idx: slice(None),
),
# test combination
ReshapeTestCase(
[4, 6, 8],
[dist.Shard(0), dist.Shard(1)],
[2, 12, 8],
[dist.Shard(0), dist.Replicate()],
lambda idx: (idx // 2,),
),
ReshapeTestCase(
[4, 6, 8],
[dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
[2, 12, 8],
[dist.Replicate(), dist.Replicate()],
lambda idx: slice(None),
),
ReshapeTestCase(
[4, 6, 8],
[dist.Shard(0), dist.Shard(1)],
[12, 2, 8],
[dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
lambda idx: slice(idx % 4 * 3, idx % 4 * 3 + 3),
),
ReshapeTestCase(
[4, 6, 8],
[dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
[12, 2, 8],
[dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
lambda idx: slice(idx % 4 * 3, idx % 4 * 3 + 3),
),
ReshapeTestCase(
[4, 6, 8],
[dist.Shard(0), dist.Shard(1)],
[8, 6, 4],
[dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
lambda idx: slice(idx % 4 * 2, idx % 4 * 2 + 2),
),
ReshapeTestCase(
[4, 6, 8],
[dist.Shard(1), dist.Shard(2)],
[8, 6, 4],
[dist.Replicate(), dist.Replicate()],
lambda idx: slice(None),
),
ReshapeTestCase(
[4, 6, 8],
[dist.Shard(0), dist.Shard(2)],
[8, 6, 4],
[dist.Shard(0), dist.Replicate()],
lambda idx: (idx // 2, idx // 2 + 4),
),
ReshapeTestCase(
[4, 6, 8],
[dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
[8, 6, 4],
[dist.Shard(0, shard_order=0), dist.Shard(0, shard_order=1)],
lambda idx: slice(idx % 4 * 2, idx % 4 * 2 + 2),
),
ReshapeTestCase(
[4, 6, 8],
[dist.Shard(2, shard_order=0), dist.Shard(2, shard_order=1)],
[24, 2, 4],
[dist.Replicate(), dist.Replicate()],
lambda idx: slice(None),
),
ReshapeTestCase(
[4, 6, 8],
[dist.Shard(2, shard_order=0), dist.Shard(1, shard_order=1)],
[24, 4, 2],
[dist.Shard(2, shard_order=0), dist.Shard(1, shard_order=1)],
lambda idx: (slice(None), idx % 4, slice(None)),
),
]
input = dist.shard_tensor(a, mesh, placements)
out = paddle.reshape(input, [24, 4, 2])
np.testing.assert_equal(out.shape, [24, 4, 2])
np.testing.assert_equal(
str(out.placements[0]), 'Shard(dim=1, shard_order=0)'
)
np.testing.assert_equal(
str(out.placements[1]), 'Shard(dim=1, shard_order=1)'
)
new_slice = (slice(None), dist.get_rank() % 4, slice(None))
np.testing.assert_equal(
out._local_value().numpy().flatten(), a[new_slice].numpy().flatten()
)

def run_test_case_main(self):
self.run_test_flatten()
self.run_test_split()
self.run_test_combination()
def run_test_case(self, test_case: ReshapeTestCase):
a = paddle.rand(test_case.input_shape, "float32")
input_placements = test_case.input_placements
input = dist.shard_tensor(a, self.mesh, input_placements)
out = paddle.reshape(input, test_case.target_shape)
case_info = f"input_shape: {test_case.input_shape}, input_placements: {input_placements}, target_shape: {test_case.target_shape}"
# Verify output shape
np.testing.assert_equal(
out.shape,
test_case.target_shape,
err_msg=f"Output shape mismatch when {case_info}. Expected: {test_case.target_shape}, Actual: {out.shape}",
)

# Verify placements
assert out.placements
for actual, expected in zip(
out.placements, test_case.output_placements
):
np.testing.assert_equal(
actual,
expected,
err_msg=f"Output placements mismatch when {case_info}. Expected: {test_case.output_placements}, Actual: {out.placements}",
)
# Verify local_value if given
if test_case.slice_functor:
idx = dist.get_rank()
np.testing.assert_equal(
out._local_value().numpy().flatten(),
a[test_case.slice_functor(idx)].numpy().flatten(),
err_msg=f"Local values mismatch when {case_info}.",
)

def run_all_tests(self):
self.setUp()
for test_case in self.test_cases:
self.run_test_case(test_case)


if __name__ == '__main__':
TestReshapeCoShard().run_test_case_main()
TestReshapeCoShard().run_all_tests()
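
Note on the refactor: the rewrite replaces the three hand-written run_test_* methods with a table-driven harness, so a new reshape/placement combination only needs one more entry in self.test_cases. Below is a minimal sketch of such an entry, assuming the ReshapeTestCase signature introduced above; the case itself is hypothetical and not part of this PR.

import paddle.distributed as dist
# from reshape_co_shard import ReshapeTestCase  # class added in this PR

# Hypothetical case: flatten a [2, 4, 8] tensor sharded only along mesh axis 'x'.
# The flattened result should stay Shard(0) on 'x', and rank idx should hold the
# global slice a[idx // 2] (32 of the 64 elements).
extra_case = ReshapeTestCase(
    [2, 4, 8],                          # input_shape
    [dist.Shard(0), dist.Replicate()],  # input_placements on the 2x2 mesh
    [64],                               # target_shape
    [dist.Shard(0), dist.Replicate()],  # expected output_placements
    lambda idx: (idx // 2,),            # rank -> slice of the global tensor
)
# Appended to self.test_cases in TestReshapeCoShard.setUp().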
5 changes: 4 additions & 1 deletion test/auto_parallel/end_to_end/test_e2e_co_shard.py
@@ -21,7 +21,10 @@ class TestReshardE2E(test_base.CommunicationTestDistBase):
def setUp(self):
super().setUp(num_of_devices=4, timeout=120)

def test_reshard_co_shard(self):
def test_co_shard(self):
self.run_test_case("co_shard.py")

def test_reshape_co_shard(self):
self.run_test_case("reshape_co_shard.py")

