@@ -221,10 +221,11 @@ def _allreduce_fusion_program(self):
221221
222222 # find all grad params
223223 for idx , op in enumerate (block .ops ):
224+ if first_backward_idx == - 1 and \
225+ is_backward_op (op ):
226+ first_backward_idx = idx
224227 if is_backward_op (op ) and \
225228 OP_ROLE_VAR_KEY in op .attr_names :
226- if first_backward_idx == - 1 :
227- first_backward_idx = idx
228229 op_role_var = op .attr (OP_ROLE_VAR_KEY )
229230 if len (op_role_var ) == 0 :
230231 continue
@@ -239,14 +240,25 @@ def _allreduce_fusion_program(self):
239240 continue
240241 param_grads .append ((param , grad ))
241242
242- # find the index of the op which generates the grad
243- grads_to_idx = {}
244- for param , grad in param_grads :
245- for idx in range (first_backward_idx , len (block .ops )):
246- op = block .ops [idx ]
247- if grad .name in op .output_arg_names :
248- grads_to_idx [grad ] = idx
249- break
243+ # Each item of outputs_name_to_idx is a pair of idx
244+ # The first entry of this pair is the idx of the first op that generates the grad
245+ # which is used to indicate the position to insert coalesce op
246+ # The second entry of this pair is the idx of the last op that generates the grad
247+ # which is used to indicate the position to insert sync and allreduce op
248+ outputs_name_to_idx = {}
249+ for idx in range (first_backward_idx , len (block .ops )):
250+ op = block .ops [idx ]
251+ if is_optimizer_op (op ):
252+ break
253+ for name in op .output_arg_names :
254+ var = block .var (name )
255+ if not outputs_name_to_idx .get (var ):
256+ # if the grad is only generated by one op
257+ # the first idx and the last idx are identical
258+ outputs_name_to_idx [var ] = (idx , idx )
259+ else :
260+ outputs_name_to_idx [var ] = (outputs_name_to_idx [var ][0 ],
261+ idx )
250262
251263 # structure of grad_param_segments is
252264 # [([grad0, grad1], [param0, param1]), ([grad2, grad3], [param2, param3])]
@@ -279,8 +291,8 @@ def _allreduce_fusion_program(self):
279291 dtype = grad_segment [0 ].dtype ,
280292 persistable = False ,
281293 stop_gradient = True )
282- before_idx = grads_to_idx [grad_segment [0 ]]
283- after_idx = grads_to_idx [grad_segment [- 1 ]]
294+ before_idx = outputs_name_to_idx [grad_segment [0 ]][ 0 ]
295+ after_idx = outputs_name_to_idx [grad_segment [- 1 ]][ 1 ]
284296 offset = 1
285297 for j in range (i + 1 , len (grad_param_segments )):
286298 # Find the offset of the sync op and allreduce op
@@ -289,7 +301,12 @@ def _allreduce_fusion_program(self):
289301 # the first grad in next segment are from the same op, it means
290302 # a coalesce op has already been inserted before this op.
291303 # Therefore, we have to insert the sync/allreduce op with offset.
292- if after_idx == grads_to_idx [grad_param_segments [j ][0 ][0 ]]:
304+ # The j is to get the ([grad0, grad1], [param0, param1]) tuple
305+ # The first 0 is to get [grad0, grad1] list
306+ # The second 0 is to get grad0 entry
307+ # The 1 is to get the idx of the last op that generates the grad
308+ if after_idx == outputs_name_to_idx [grad_param_segments [j ][0 ][
309+ 0 ]][1 ]:
293310 offset += 1
294311 block ._insert_op_without_sync (
295312 after_idx + offset ,
0 commit comments