@@ -239,31 +239,37 @@ def new_group(ranks=None, backend=None):
239239 if global_rank not in ranks :
240240 gp = Group (- 1 , - 1 , ring_id , ranks )
241241 _group_map [ring_id ] = gp
242- return gp
243-
244- ranks = sorted (ranks )
245- group_rank = ranks .index (global_rank )
246- group_size = len (ranks )
247- gp = Group (group_rank , group_size , ring_id , ranks )
248- _group_map [ring_id ] = gp
249-
250- if group_size < 2 :
251- return gp
252-
253- strategy = core .ParallelStrategy ()
254- strategy .nranks = group_size
255- strategy .local_rank = group_rank
256- strategy .trainer_endpoints = [genv .trainer_endpoints [i ] for i in ranks ]
257- strategy .current_endpoint = genv .current_endpoint
258- strategy .nrings = 1
259-
260- if core .is_compiled_with_cuda ():
261- place = core .CUDAPlace (genv .device_id )
262- core .NCCLParallelContext (strategy , place ).init_with_ring_id (ring_id )
263242 else :
264- assert False , ("no cuda device found" )
265- # need to barrier to construct group
266- barrier (gp )
243+ ranks = sorted (ranks )
244+ group_rank = ranks .index (global_rank )
245+ group_size = len (ranks )
246+ gp = Group (group_rank , group_size , ring_id , ranks )
247+ _group_map [ring_id ] = gp
248+
249+ if group_size >= 2 :
250+ strategy = core .ParallelStrategy ()
251+ strategy .nranks = group_size
252+ strategy .local_rank = group_rank
253+ strategy .trainer_endpoints = [
254+ genv .trainer_endpoints [i ] for i in ranks
255+ ]
256+ strategy .current_endpoint = genv .current_endpoint
257+ strategy .nrings = 1
258+
259+ if core .is_compiled_with_cuda ():
260+ place = core .CUDAPlace (genv .device_id )
261+ core .NCCLParallelContext (strategy ,
262+ place ).init_with_ring_id (ring_id )
263+ else :
264+ assert False , ("no cuda device found" )
265+ else :
266+ return gp
267+
268+ # TODO(shenliang03): This is a temporary solution to solve the problem of
269+ # hang caused by cross-creation of new_group
270+ tmp = fill_constant ([0 ], dtype = "int32" , value = "1" )
271+ paddle .distributed .all_reduce (tmp , use_calc_stream = True )
272+ paddle .distributed .wait (tmp )
267273 return gp
268274
269275
0 commit comments