
Commit cb6510f

[hybrid fix] fix pp+dp hang (#34142)
1 parent 7f26453 commit cb6510f

File tree

2 files changed: +25 -26 lines changed


python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py

Lines changed: 21 additions & 25 deletions
@@ -434,35 +434,31 @@ def _init_comm(self):
 
         # pp ring
         if self.pp_degree > 1:
+            # TODO (JZ-LIANG) to unify this shit
+            assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format(
+                self.pp_rank_, self.pp_rank)
+
             for pair in self.pipeline_pair:
                 pair_key = pair[0] * 1000 + pair[1]
                 ring_id = self.pp_ring_map[pair_key]
                 print("pp pair:{}, ring_id: {}".format(pair, ring_id))
-                if self.pp_rank not in pair: continue
-                pp_group_endpoints = [
-                    self.pp_group_endpoints[pair[0]],
-                    self.pp_group_endpoints[pair[1]],
-                ]
-                if pair[0] < pair[1]:
-                    start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1
-                else:
-                    start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[1] - 1
-                pp_rank = 0 if self.pp_rank == pair[0] else 1
-                self._collective_helper._init_communicator(
-                    self._startup_program,
-                    self.current_endpoint,
-                    pp_group_endpoints,
-                    pp_rank,
-                    ring_id,
-                    False,
-                    global_ring_id=self.global_ring_id,
-                    sync=False)
-                # append_naive_sync(startup_block, self.startup_prog_sync_var,
-                #                   self.global_ring_id)
-
-            # TODO (JZ-LIANG) to unify this shit
-            assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format(
-                self.pp_rank_, self.pp_rank)
+                if self.pp_rank in pair:
+                    pp_group_endpoints = [
+                        self.pp_group_endpoints[pair[0]],
+                        self.pp_group_endpoints[pair[1]],
+                    ]
+                    pp_rank = 0 if self.pp_rank == pair[0] else 1
+                    self._collective_helper._init_communicator(
+                        self._startup_program,
+                        self.current_endpoint,
+                        pp_group_endpoints,
+                        pp_rank,
+                        ring_id,
+                        False,
+                        global_ring_id=self.global_ring_id,
+                        sync=False)
+                append_naive_sync(startup_block, self.startup_prog_sync_var,
+                                  self.global_ring_id)
 
         # pure dp ring
         if self.dp_degree > 1:
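
Why this avoids the hang (as read from the diff): the old loop body bailed out early with "if self.pp_rank not in pair: continue", the per-pair append_naive_sync on the global ring was commented out, and the unused start_ring_id computation is dropped outright. The new code keeps the pair-membership check as an "if self.pp_rank in pair:" block and calls append_naive_sync(startup_block, self.startup_prog_sync_var, self.global_ring_id) on every iteration, so every rank in the global ring takes part in the sync after each pair's communicator init, not just the two ranks inside the pair. The sketch below is illustrative only, not Paddle code: it uses threading.Barrier as a rough stand-in for that global-ring collective, with hypothetical world size and pairs, to show how letting some ranks skip the sync leaves the rest stuck.

# Illustrative sketch only (not Paddle code). threading.Barrier stands in for
# the collective naive sync on the global ring; PIPELINE_PAIRS mirrors the
# role of self.pipeline_pair in the diff above.
import threading

WORLD_SIZE = 4
PIPELINE_PAIRS = [(0, 1), (1, 2), (2, 3)]

def rank_loop(rank, barrier, skip_non_members):
    for pair in PIPELINE_PAIRS:
        if skip_non_members and rank not in pair:
            continue  # old pattern: ranks outside the pair never reach the sync
        if rank in pair:
            pass  # stand-in for _collective_helper._init_communicator(...)
        try:
            # stand-in for append_naive_sync on the global ring
            barrier.wait(timeout=2)
        except threading.BrokenBarrierError:
            return  # some rank skipped the sync, so it can never complete

def demo(skip_non_members):
    barrier = threading.Barrier(WORLD_SIZE)
    threads = [
        threading.Thread(target=rank_loop, args=(r, barrier, skip_non_members))
        for r in range(WORLD_SIZE)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print("sync broken (would hang):", barrier.broken)

demo(skip_non_members=True)   # old control flow:   sync broken (would hang): True
demo(skip_non_members=False)  # fixed control flow: sync broken (would hang): False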

python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py

Lines changed: 4 additions & 1 deletion
@@ -525,14 +525,17 @@ def test_sharding_with_pp(self):
         startup_prog_op_types = [op.type for op in startup_prog_ops]
         main_prog_op_types = [op.type for op in main_prog_ops]
         print(startup_prog_op_types)
+        # global, sharding, pp_send, pp_recv
         self.assertEqual(startup_prog_op_types, [
             'fill_constant', 'uniform_random', 'fill_constant',
             'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
             'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
             'c_gen_nccl_id', 'c_comm_init', 'fill_constant', 'c_allreduce_sum',
             'c_sync_calc_stream', 'c_gen_nccl_id', 'c_comm_init',
             'fill_constant', 'c_allreduce_sum', 'c_sync_calc_stream',
-            'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init'
+            'c_gen_nccl_id', 'c_comm_init', 'fill_constant', 'c_allreduce_sum',
+            'c_sync_calc_stream', 'c_gen_nccl_id', 'c_comm_init',
+            'fill_constant', 'c_allreduce_sum', 'c_sync_calc_stream'
         ])
 
         self.assertEqual(main_prog_op_types, [
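
Note on the updated expectation: each c_gen_nccl_id, c_comm_init group for a pipeline pair in the startup program is now followed by fill_constant, c_allreduce_sum, c_sync_calc_stream, which appears to be the naive sync on the global ring re-enabled in sharding_optimizer.py above; the old list ended with two bare c_gen_nccl_id, c_comm_init groups and no trailing sync ops.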
