Skip to content

Commit 1e9299a

Browse files
authored
Fix hang of hybrid parallel in new_group (#33141)
* fix hang of hybrid parallel
* fix new_group for hang problem
1 parent d523dff commit 1e9299a

1 file changed

Lines changed: 30 additions & 24 deletions

File tree

python/paddle/distributed/collective.py

Lines changed: 30 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -239,31 +239,37 @@ def new_group(ranks=None, backend=None):
239239
if global_rank not in ranks:
240240
gp = Group(-1, -1, ring_id, ranks)
241241
_group_map[ring_id] = gp
242-
return gp
243-
244-
ranks = sorted(ranks)
245-
group_rank = ranks.index(global_rank)
246-
group_size = len(ranks)
247-
gp = Group(group_rank, group_size, ring_id, ranks)
248-
_group_map[ring_id] = gp
249-
250-
if group_size < 2:
251-
return gp
252-
253-
strategy = core.ParallelStrategy()
254-
strategy.nranks = group_size
255-
strategy.local_rank = group_rank
256-
strategy.trainer_endpoints = [genv.trainer_endpoints[i] for i in ranks]
257-
strategy.current_endpoint = genv.current_endpoint
258-
strategy.nrings = 1
259-
260-
if core.is_compiled_with_cuda():
261-
place = core.CUDAPlace(genv.device_id)
262-
core.NCCLParallelContext(strategy, place).init_with_ring_id(ring_id)
263242
else:
264-
assert False, ("no cuda device found")
265-
# need to barrier to construct group
266-
barrier(gp)
243+
ranks = sorted(ranks)
244+
group_rank = ranks.index(global_rank)
245+
group_size = len(ranks)
246+
gp = Group(group_rank, group_size, ring_id, ranks)
247+
_group_map[ring_id] = gp
248+
249+
if group_size >= 2:
250+
strategy = core.ParallelStrategy()
251+
strategy.nranks = group_size
252+
strategy.local_rank = group_rank
253+
strategy.trainer_endpoints = [
254+
genv.trainer_endpoints[i] for i in ranks
255+
]
256+
strategy.current_endpoint = genv.current_endpoint
257+
strategy.nrings = 1
258+
259+
if core.is_compiled_with_cuda():
260+
place = core.CUDAPlace(genv.device_id)
261+
core.NCCLParallelContext(strategy,
262+
place).init_with_ring_id(ring_id)
263+
else:
264+
assert False, ("no cuda device found")
265+
else:
266+
return gp
267+
268+
# TODO(shenliang03): This is a temporary solution to solve the problem of
269+
# hang caused by cross-creation of new_group
270+
tmp = fill_constant([0], dtype="int32", value="1")
271+
paddle.distributed.all_reduce(tmp, use_calc_stream=True)
272+
paddle.distributed.wait(tmp)
267273
return gp
268274

269275

0 commit comments

Comments
 (0)