
Commit 98d5956

[AutoParallel]Fix get_group method of processmesh (#73099)
* fix bug: get_group repeatedly created communication groups
* initialize the self._hcg member of the fleet class, used to check whether a hybrid_communicate_group already exists, and add an id comparison to confirm it is the same id
* add a method for checking the hcg
* change how the _hcg attribute is checked
* rerun CI
* rerun CI
* remove the redundant variables
* merge the different_hybrid_configs tests
* fix the code style
1 parent 78b6114 commit 98d5956
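
The core of the fix, per the summary above, is that ProcessMesh.get_group() used to call paddle.distributed.new_group() on every invocation, creating a duplicate communicator each time. A minimal sketch of the reuse-before-create logic (simplified from the actual patch below; the helper name _find_or_create_group is illustrative, not part of the commit):

    import paddle
    from paddle.distributed.collective import _get_group_map


    def _find_or_create_group(process_ids):
        # Reuse any already-registered group that covers exactly these ranks...
        for group in _get_group_map().values():
            if set(group.ranks) == set(process_ids):
                return group
        # ...and only create a new communicator when none exists.
        return paddle.distributed.new_group(process_ids)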

File tree

5 files changed (+304 −4 lines)

python/paddle/distributed/auto_parallel/process_mesh.py

Lines changed: 25 additions & 2 deletions
@@ -21,6 +21,8 @@
 import numpy as np

 import paddle
+from paddle.distributed import fleet
+from paddle.distributed.collective import _get_group_map
 from paddle.distributed.communication.group import is_initialized
 from paddle.framework import core

@@ -442,8 +444,29 @@ def get_group(
                 f"{dim_name} not in the dimension names {self._dim_names}"
             )
         else:
-            pg = paddle.distributed.new_group(self._process_ids)
-            return pg
+            if hasattr(fleet.fleet, "_hcg"):
+                hcg = fleet.get_hybrid_communicate_group()
+                if hcg is not None:
+
+                    parallel_group_map = {
+                        "pp": hcg.get_pipe_parallel_group,
+                        "dp": hcg.get_data_parallel_group,
+                        "mp": hcg.get_model_parallel_group,
+                        "sep": hcg.get_sep_parallel_group,
+                        "sharding": hcg.get_sharding_parallel_group,
+                    }
+
+                    if dim_name not in parallel_group_map:
+                        raise ValueError(
+                            f"{dim_name} is not a valid dim name."
+                        )
+
+                    return parallel_group_map[dim_name]()
+            group_map = _get_group_map()
+            for group in group_map.values():
+                if set(group.ranks) == set(self._process_ids):
+                    return group
+            return paddle.distributed.new_group(self._process_ids)
     else:
         if dim_name not in self._dim_names:
             raise ValueError(
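
With this change, a one-dimensional mesh whose dim name matches a hybrid-parallel axis resolves to the group that fleet's hybrid communicate group (HCG) already created, instead of a fresh one. A usage sketch under assumed conditions (a 2-GPU job launched through paddle.distributed.launch; the degrees below are illustrative, not part of the commit):

    import paddle.distributed as dist
    from paddle.distributed import fleet

    # Assumed setup: fleet initialized with a 2-way data-parallel hybrid config.
    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {"dp_degree": 2, "mp_degree": 1, "pp_degree": 1}
    fleet.init(is_collective=True, strategy=strategy)

    mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])
    group = mesh.get_group(dim_name="dp")

    # After this commit both handles refer to the same communicator,
    # so their ids match and no duplicate group is created.
    hcg_group = fleet.get_hybrid_communicate_group().get_data_parallel_group()
    assert group.id == hcg_group.id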

test/auto_parallel/hybrid_strategy/CMakeLists.txt

Lines changed: 9 additions & 1 deletion
@@ -173,6 +173,14 @@ if((WITH_GPU) AND (LINUX))
   py_test_modules(
     test_process_mesh MODULES test_process_mesh ENVS
     "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
-  set_tests_properties(test_process_mesh PROPERTIES TIMEOUT "60" LABELS
+  set_tests_properties(test_process_mesh PROPERTIES TIMEOUT "150" LABELS
                        "RUN_TYPE=HYBRID")
 endif()
+if((WITH_GPU) AND (LINUX))
+  py_test_modules(
+    test_get_group_in_different_hybrid_configs MODULES
+    test_get_group_in_different_hybrid_configs ENVS
+    "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
+  set_tests_properties(test_get_group_in_different_hybrid_configs
+                       PROPERTIES TIMEOUT "150" LABELS "RUN_TYPE=HYBRID")
+endif()

test/auto_parallel/hybrid_strategy/process_mesh_demo_unittest.py

Lines changed: 1 addition & 1 deletion
@@ -99,7 +99,7 @@ def test_get_group(self):
         assert isinstance(
             group_1d_with_name, dist.communication.group.Group
         )
-
+        assert group_1d_with_name.id == group_1d.id
         # Test case 3: Single dimension mesh with wrong dim_name
         try:
             mesh_1d.get_group(dim_name="wrong_name")
test/auto_parallel/hybrid_strategy/test_get_group_in_different_hybrid_configs.py

Lines changed: 162 additions & 0 deletions (new file)
@@ -0,0 +1,162 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import collective.test_communication_api_base as test_base


class TestProcessMeshDPGroupConsistency(test_base.CommunicationTestDistBase):
    def setUp(self):
        super().setUp(num_of_devices=2, timeout=200, nnode=1)

    def test_dp_parallel(self):
        """Test data parallel group creation and consistency"""
        _default_envs = {
            "dp": "2",
            "mp": "1",
            "pp": "1",
            "parallel_type": "dp",
            "FLAGS_embedding_deterministic": "1",
            "FLAGS_cudnn_deterministic": "1",
        }
        _changeable_envs = {
            "backend": ["gpu"],
        }
        envs_list = test_base.gen_product_envs_list(
            _default_envs, _changeable_envs
        )
        for envs in envs_list:
            self.run_test_case(
                "test_process_mesh_group_consistency.py",
                user_defined_envs=envs,
            )


class TestProcessMeshMPGroupConsistency(test_base.CommunicationTestDistBase):
    def setUp(self):
        super().setUp(num_of_devices=2, timeout=200, nnode=1)

    def test_mp_parallel(self):
        """Test model parallel group creation and consistency"""
        _default_envs = {
            "dp": "1",
            "mp": "2",
            "pp": "1",
            "parallel_type": "mp",
            "FLAGS_embedding_deterministic": "1",
            "FLAGS_cudnn_deterministic": "1",
        }
        _changeable_envs = {
            "backend": ["gpu"],
        }
        envs_list = test_base.gen_product_envs_list(
            _default_envs, _changeable_envs
        )
        for envs in envs_list:
            self.run_test_case(
                "test_process_mesh_group_consistency.py",
                user_defined_envs=envs,
            )


class TestProcessMeshPPGroupConsistency(test_base.CommunicationTestDistBase):
    def setUp(self):
        super().setUp(num_of_devices=2, timeout=200, nnode=1)

    def test_pp_parallel(self):
        """Test pipeline parallel group creation and consistency"""
        _default_envs = {
            "dp": "1",
            "mp": "1",
            "pp": "2",
            "parallel_type": "pp",
            "FLAGS_embedding_deterministic": "1",
            "FLAGS_cudnn_deterministic": "1",
        }
        _changeable_envs = {
            "backend": ["gpu"],
        }
        envs_list = test_base.gen_product_envs_list(
            _default_envs, _changeable_envs
        )
        for envs in envs_list:
            self.run_test_case(
                "test_process_mesh_group_consistency.py",
                user_defined_envs=envs,
            )


class TestProcessMeshSEPGroupConsistency(test_base.CommunicationTestDistBase):
    def setUp(self):
        super().setUp(num_of_devices=2, timeout=200, nnode=1)

    def test_sep_parallel(self):
        """Test sequence parallel group creation and consistency"""
        _default_envs = {
            "dp": "1",
            "mp": "1",
            "pp": "1",
            "sep": "2",
            "sharding": "1",
            "parallel_type": "sep",
            "FLAGS_embedding_deterministic": "1",
            "FLAGS_cudnn_deterministic": "1",
        }
        _changeable_envs = {
            "backend": ["gpu"],
        }
        envs_list = test_base.gen_product_envs_list(
            _default_envs, _changeable_envs
        )
        for envs in envs_list:
            self.run_test_case(
                "test_process_mesh_group_consistency.py",
                user_defined_envs=envs,
            )


class TestProcessMeshShardingGroupConsistency(
    test_base.CommunicationTestDistBase
):
    def setUp(self):
        super().setUp(num_of_devices=2, timeout=200, nnode=1)

    def test_sharding_parallel(self):
        """Test sharding parallel group creation and consistency"""
        _default_envs = {
            "dp": "1",
            "mp": "1",
            "pp": "1",
            "sep": "1",
            "sharding": "2",
            "parallel_type": "sharding",
            "FLAGS_embedding_deterministic": "1",
            "FLAGS_cudnn_deterministic": "1",
        }
        _changeable_envs = {
            "backend": ["gpu"],
        }
        envs_list = test_base.gen_product_envs_list(
            _default_envs, _changeable_envs
        )
        for envs in envs_list:
            self.run_test_case(
                "test_process_mesh_group_consistency.py",
                user_defined_envs=envs,
            )


if __name__ == "__main__":
    unittest.main()
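
The driver above fans each default-env dict out through test_base.gen_product_envs_list before launching the distributed case. For intuition only, a hypothetical re-implementation of that harness helper (the real one lives in collective/test_communication_api_base.py; this sketch assumes it simply takes the Cartesian product of the changeable values over the defaults):

    from itertools import product


    def gen_product_envs_list(default_envs, changeable_envs):
        # Hypothetical sketch, not the harness's actual code.
        keys = list(changeable_envs)
        envs_list = []
        for combo in product(*(changeable_envs[k] for k in keys)):
            envs = dict(default_envs)      # start from the fixed settings
            envs.update(zip(keys, combo))  # overlay one combination
            envs_list.append(envs)
        return envs_list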
test/auto_parallel/hybrid_strategy/test_process_mesh_group_consistency.py

Lines changed: 107 additions & 0 deletions (new file)
@@ -0,0 +1,107 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle.distributed as dist
from paddle.distributed import fleet


class TestProcessMeshGroupConsistency:
    def __init__(self):
        # Get configuration from environment variables
        self.dp = int(os.getenv("dp", "1"))
        self.mp = int(os.getenv("mp", "1"))
        self.pp = int(os.getenv("pp", "1"))
        self.sep = int(os.getenv("sep", "1"))
        self.sharding = int(os.getenv("sharding", "1"))

        # Determine which parallel type to test
        self.parallel_type = os.getenv("parallel_type", "dp")

    def init_dist_env(self):
        """Initialize distributed environment"""
        # Configure distributed strategy
        dist_strategy = fleet.DistributedStrategy()
        dist_strategy.hybrid_configs = {
            "dp_degree": self.dp,
            "mp_degree": self.mp,
            "pp_degree": self.pp,
            "sep_degree": self.sep,
            "sharding_degree": self.sharding,
        }

        # Add corresponding configuration based on parallel type
        if self.sep > 1:
            dist_strategy.hybrid_configs["sep_degree"] = self.sep
        if self.sharding > 1:
            dist_strategy.hybrid_configs["sharding_degree"] = self.sharding

        fleet.init(is_collective=True, strategy=dist_strategy)

    def test_process_mesh_group_consistency(self):
        """Test consistency between ProcessMesh created groups and HCG created groups"""

        # Create the corresponding ProcessMesh and get the matching HCG group
        # based on the parallel type
        if self.parallel_type == "dp":
            mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])
            hcg = fleet.get_hybrid_communicate_group()
            group = mesh.get_group(dim_name="dp")
            hcg_group = hcg.get_data_parallel_group()

        elif self.parallel_type == "mp":
            mesh = dist.ProcessMesh([0, 1], dim_names=["mp"])
            hcg = fleet.get_hybrid_communicate_group()
            group = mesh.get_group(dim_name="mp")
            hcg_group = hcg.get_model_parallel_group()

        elif self.parallel_type == "pp":
            mesh = dist.ProcessMesh([0, 1], dim_names=["pp"])
            hcg = fleet.get_hybrid_communicate_group()
            group = mesh.get_group(dim_name="pp")
            hcg_group = hcg.get_pipe_parallel_group()

        elif self.parallel_type == "sep":
            mesh = dist.ProcessMesh([0, 1], dim_names=["sep"])
            hcg = fleet.get_hybrid_communicate_group()
            group = mesh.get_group(dim_name="sep")
            hcg_group = hcg.get_sep_parallel_group()

        elif self.parallel_type == "sharding":
            mesh = dist.ProcessMesh([0, 1], dim_names=["sharding"])
            hcg = fleet.get_hybrid_communicate_group()
            group = mesh.get_group(dim_name="sharding")
            hcg_group = hcg.get_sharding_parallel_group()

        else:
            raise ValueError(f"Unsupported parallel type: {self.parallel_type}")

        # Verify that group ranks are consistent
        group_ranks = group.ranks
        hcg_group_ranks = hcg_group.ranks
        assert set(group_ranks) == set(hcg_group_ranks)

        # Verify that group IDs are consistent
        group_id = group.id
        hcg_group_id = hcg_group.id
        assert group_id == hcg_group_id

    def run_test_cases(self):
        """Run test cases"""
        self.init_dist_env()
        self.test_process_mesh_group_consistency()


if __name__ == "__main__":
    TestProcessMeshGroupConsistency().run_test_cases()
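
The CTest harness normally sets the dp/mp/pp/sep/sharding variables and launches this script across the devices itself. A hedged sketch of running one configuration by hand (the device ids and launch invocation are assumptions, not part of the commit):

    import os
    import subprocess

    # Hypothetical standalone launcher for the consistency check; assumes 2 GPUs.
    env = dict(os.environ, dp="2", mp="1", pp="1", parallel_type="dp")
    subprocess.run(
        [
            "python", "-m", "paddle.distributed.launch",
            "--devices", "0,1",
            "test_process_mesh_group_consistency.py",
        ],
        env=env,
        check=True,
    )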
