Skip to content

Commit 4243194

Browse files
authored
fix bug of p2p (#33929)
1 parent a74e01a commit 4243194

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,17 @@
1313
# limitations under the License.
1414

1515
import paddle
16-
import paddle.distributed as dist
1716

1817
_groups = None
1918
_hcg = None
2019

2120

2221
def initialize_p2p_groups(hcg):
2322
global _groups, _hcg
24-
_groups = [dist.new_group(ranks=group) for group in hcg.get_p2p_groups()]
23+
_groups = [
24+
paddle.distributed.new_group(ranks=group)
25+
for group in hcg.get_p2p_groups()
26+
]
2527
_hcg = hcg
2628

2729

@@ -33,7 +35,7 @@ def send(tensor, dest_stage):
3335
_is_valid_communciate(src_stage, dest_stage)
3436
group = _get_send_recv_group(src_stage, dest_stage)
3537
dst_rank = _hcg.get_rank_from_stage(stage_id=dest_stage)
36-
return dist.broadcast(tensor, src_rank, group=group)
38+
return paddle.distributed.broadcast(tensor, src_rank, group=group)
3739

3840

3941
def recv(tensor, src_stage):
@@ -43,7 +45,7 @@ def recv(tensor, src_stage):
4345
_is_valid_communciate(src_stage, dest_stage)
4446
group = _get_send_recv_group(src_stage, dest_stage)
4547
src_rank = _hcg.get_rank_from_stage(stage_id=src_stage)
46-
return dist.broadcast(tensor, src_rank, group=group)
48+
return paddle.distributed.broadcast(tensor, src_rank, group=group)
4749

4850

4951
def _is_valid_communciate(src_stage, dest_stage):

0 commit comments

Comments
 (0)