@@ -240,25 +240,8 @@ def _allreduce_fusion_program(self):
                         continue
                     param_grads.append((param, grad))

-        # Each item of outputs_name_to_idx is a pair of idx
-        # The first entry of this pair is the idx of the first op that generates the grad,
-        # which is used to indicate the position to insert the coalesce op
-        # The second entry of this pair is the idx of the last op that generates the grad,
-        # which is used to indicate the position to insert the sync and allreduce ops
-        outputs_name_to_idx = {}
-        for idx in range(first_backward_idx, len(block.ops)):
-            op = block.ops[idx]
-            if is_optimizer_op(op):
-                break
-            for name in op.output_arg_names:
-                var = block.var(name)
-                if not outputs_name_to_idx.get(var):
-                    # if the grad is generated by only one op,
-                    # the first idx and the last idx are identical
-                    outputs_name_to_idx[var] = (idx, idx)
-                else:
-                    outputs_name_to_idx[var] = (outputs_name_to_idx[var][0],
-                                                idx)
+        outputs_name_to_idx = self.__get_ouputs_name_to_idx(first_backward_idx,
+                                                            block)

         # structure of grad_param_segments is
         # [([grad0, grad1], [param0, param1]), ([grad2, grad3], [param2, param3])]
@@ -280,6 +263,7 @@ def _allreduce_fusion_program(self):
         if len(grad_param_segments) == 0:
             return

+        fused_vars = [None] * len(grad_param_segments)
         for i in range(len(grad_param_segments) - 1, -1, -1):
             # traverse the grad_param_segments in reverse order
             # (do not use reversed() here since the absolute index value is needed)
@@ -291,25 +275,10 @@ def _allreduce_fusion_program(self):
                 dtype=grad_segment[0].dtype,
                 persistable=False,
                 stop_gradient=True)
-            before_idx = outputs_name_to_idx[grad_segment[0]][0]
+            fused_vars[i] = fused_var
             after_idx = outputs_name_to_idx[grad_segment[-1]][1]
-            offset = 1
-            for j in range(i + 1, len(grad_param_segments)):
-                # Find the offset for the sync op and the allreduce op.
-                # Some ops may have multiple grad_param pairs, and these grads might
-                # be split into different segments. If the last grad in this segment
-                # and the first grad in the next segment are from the same op, it
-                # means a coalesce op has already been inserted before this op.
-                # Therefore, we have to insert the sync/allreduce op with an offset.
-                # The j is to get the ([grad0, grad1], [param0, param1]) tuple.
-                # The first 0 is to get the [grad0, grad1] list.
-                # The second 0 is to get the grad0 entry.
-                # The 1 is to get the idx of the last op that generates the grad.
-                if after_idx == outputs_name_to_idx[grad_param_segments[j][0][
-                        0]][1]:
-                    offset += 1
             block._insert_op_without_sync(
-                after_idx + offset,
+                after_idx + 1,
                 type='c_allreduce_sum',
                 inputs={'X': fused_var},
                 outputs={'Out': fused_var},
@@ -320,11 +289,35 @@ def _allreduce_fusion_program(self):
                 })
             if not self.calc_comm_same_stream:
                 block._insert_op_without_sync(
-                    after_idx + offset,
+                    after_idx + 1,
                     type='c_sync_calc_stream',
                     inputs={'X': fused_var},
                     outputs={'Out': fused_var},
                     attrs={OP_ROLE_KEY: OpRole.Backward})
+
+        # update the outputs_name_to_idx after the insertion of the sync/allreduce ops
+        outputs_name_to_idx = self.__get_ouputs_name_to_idx(first_backward_idx,
+                                                            block)
+        # the before_idx values are not guaranteed to be sorted, so we have to
+        # sort them to find the order in which to insert the coalesce ops
+        pos_for_coalesce = {}
+        for i in range(len(grad_param_segments) - 1, -1, -1):
+            # We separate the insertion of the coalesce op from the insertion of
+            # the sync/allreduce ops, since the coalesce op's insertion may
+            # invalidate the outputs_name_to_idx
+            grad_segment, param_segment = grad_param_segments[i]
+            before_idx = len(block.ops)
+            for grad in grad_segment:
+                before_idx = min(before_idx, outputs_name_to_idx[grad][0])
+            pos_for_coalesce[i] = before_idx
+
+        # insert the coalesce ops at descending before_idx, so that an earlier
+        # insertion never shifts a position that is still pending
+        pos_for_coalesce = sorted(
+            pos_for_coalesce.items(), key=lambda kv: (kv[1], kv[0]), reverse=True)
+        for i, before_idx in pos_for_coalesce:
+            grad_segment, param_segment = grad_param_segments[i]
+            fused_var = fused_vars[i]
             block._insert_op_without_sync(
                 before_idx,
                 type="coalesce_tensor",
@@ -354,3 +347,25 @@ def _allreduce_fusion_program(self):
                            OP_ROLE_KEY: OpRole.Backward})
                 break
         block._sync_with_cpp()
+
+    def __get_ouputs_name_to_idx(self, first_backward_idx, block):
+        # Each item of outputs_name_to_idx is a pair of idx.
+        # The first entry of this pair is the idx of the first op that generates
+        # the grad, which is used to indicate the position to insert the coalesce op.
+        # The second entry of this pair is the idx of the last op that generates
+        # the grad, which is used to indicate the position to insert the sync and
+        # allreduce ops.
+        outputs_name_to_idx = {}
+        for idx in range(first_backward_idx, len(block.ops)):
+            op = block.ops[idx]
+            if is_optimizer_op(op):
+                break
+            for name in op.output_arg_names:
+                var = block.var(name)
+                if not outputs_name_to_idx.get(var):
+                    # if the grad is generated by only one op,
+                    # the first idx and the last idx are identical
+                    outputs_name_to_idx[var] = (idx, idx)
+                else:
+                    outputs_name_to_idx[var] = (outputs_name_to_idx[var][0],
+                                                idx)
+        return outputs_name_to_idx
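
For readers tracing the bookkeeping above, the patch rests on two ideas: each grad is mapped to the pair (idx of its first producing op, idx of its last producing op), and the coalesce ops are then inserted at descending positions so that an earlier insertion never shifts a position that is still pending. Below is a minimal standalone sketch of that logic under toy assumptions; `Op`, `ops`, and the segment layout are made up for illustration and are not PaddlePaddle APIs.

```python
from collections import namedtuple

# Toy stand-in for a block's op list; not a PaddlePaddle type.
Op = namedtuple('Op', ['name', 'outputs'])

ops = [
    Op('backward_a', ['grad0']),
    Op('backward_b', ['grad0', 'grad1']),  # grad0 has two producers
    Op('backward_c', ['grad2']),
]

# Mirror of __get_ouputs_name_to_idx: map each output name to
# (idx of first producer, idx of last producer).
outputs_name_to_idx = {}
for idx, op in enumerate(ops):
    for name in op.outputs:
        first, _ = outputs_name_to_idx.get(name, (idx, idx))
        outputs_name_to_idx[name] = (first, idx)

assert outputs_name_to_idx['grad0'] == (0, 1)

# Each coalesce op must land before the first producer of any grad in
# its segment (the hypothetical segments below are for illustration).
segments = {0: ['grad0', 'grad1'], 1: ['grad2']}
pos_for_coalesce = {
    i: min(outputs_name_to_idx[g][0] for g in grads)
    for i, grads in segments.items()
}

# Insert from the largest position down, like the sorted(..., reverse=True)
# in the patch, so the positions not yet used remain valid.
for i, before_idx in sorted(
        pos_for_coalesce.items(), key=lambda kv: (kv[1], kv[0]),
        reverse=True):
    ops.insert(before_idx, Op('coalesce_segment_%d' % i, []))

print([op.name for op in ops])
# ['coalesce_segment_0', 'backward_a', 'backward_b',
#  'coalesce_segment_1', 'backward_c']
```

Inserting in ascending order instead would require re-offsetting every later position after each insertion, which is essentially the `offset` dance the old code performed and this change removes.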