Commit abbae2e: bugs fix for before_idx
Parent: 1b12a43

1 file changed: +52, -37 lines

python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py

Lines changed: 52 additions & 37 deletions
@@ -240,25 +240,8 @@ def _allreduce_fusion_program(self):
                         continue
                     param_grads.append((param, grad))
 
-        # Each item of outputs_name_to_idx is a pair of idx
-        # The first entry of this pair is the idx of the first op generates the grad
-        # which is used to indicate the position to insert coalesce op
-        # The second entry of this pair is the idx of the last op generates the grad
-        # which is used to indicate teh position to insert sync and allreduce op
-        outputs_name_to_idx = {}
-        for idx in range(first_backward_idx, len(block.ops)):
-            op = block.ops[idx]
-            if is_optimizer_op(op):
-                break
-            for name in op.output_arg_names:
-                var = block.var(name)
-                if not outputs_name_to_idx.get(var):
-                    # if the grad only be generated by one op
-                    # the first idx and the last ids are identical
-                    outputs_name_to_idx[var] = (idx, idx)
-                else:
-                    outputs_name_to_idx[var] = (outputs_name_to_idx[var][0],
-                                                idx)
+        outputs_name_to_idx = self.__get_ouputs_name_to_idx(first_backward_idx,
+                                                            block)
 
         # structure of grad_param_segments is
         # [([grad0, grad1], [param0, param1]), ([grad2, grad3], [param2, param3])]
@@ -280,6 +263,7 @@ def _allreduce_fusion_program(self):
         if len(grad_param_segments) == 0:
             return
 
+        fused_vars = [None] * len(grad_param_segments)
         for i in range(len(grad_param_segments) - 1, -1, -1):
             # travers the grad_param_segments in backward
             # not to use reversed since needs the absolute index value
@@ -291,25 +275,10 @@ def _allreduce_fusion_program(self):
                 dtype=grad_segment[0].dtype,
                 persistable=False,
                 stop_gradient=True)
-            before_idx = outputs_name_to_idx[grad_segment[0]][0]
+            fused_vars[i] = fused_var
             after_idx = outputs_name_to_idx[grad_segment[-1]][1]
-            offset = 1
-            for j in range(i + 1, len(grad_param_segments)):
-                # Find the offset of the sync op and allreduce op
-                # Some ops may have multi grad_param pairs, and these grads might be
-                # split into different segments. If the last grad in this segment and
-                # the first grad in next segment are from the same op, it means
-                # a coalesce op has already been inserted before this op.
-                # Therefore, we have to insert the the sync/allreduce op with offset.
-                # The j is to get the ([grad0, grad1], [param0, param1]) tuple
-                # The first 0 is to get [grad0, grad1] list
-                # The second 0 is to get grad0 entry
-                # The 1 is to get the idx of the last op generates the grad
-                if after_idx == outputs_name_to_idx[grad_param_segments[j][0][
-                        0]][1]:
-                    offset += 1
             block._insert_op_without_sync(
-                after_idx + offset,
+                after_idx + 1,
                 type='c_allreduce_sum',
                 inputs={'X': fused_var},
                 outputs={'Out': fused_var},
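
With the coalesce insertion moved out of this loop, the offset scan above becomes unnecessary: the allreduce (and the sync op in the next hunk) is inserted straight at after_idx + 1. Because the loop walks the segments from the last to the first, and later segments sit later in the op list (as the surrounding comments indicate), each insertion lands at or after every index still to be used, so the previously computed after_idx values stay valid. Below is a minimal sketch of that invariant using plain Python lists rather than Paddle blocks; every name and value is invented for the illustration.

# Illustrative only: inserting markers back to front keeps the earlier,
# already-computed insertion indices valid.
ops = ["op0", "op1", "op2", "op3"]          # stand-in for block.ops
after_idx_per_segment = [1, 3]              # last-producer idx of each segment

for seg in range(len(after_idx_per_segment) - 1, -1, -1):
    # insert right after the last producer of this segment
    ops.insert(after_idx_per_segment[seg] + 1, "allreduce_seg%d" % seg)

print(ops)
# ['op0', 'op1', 'allreduce_seg0', 'op2', 'op3', 'allreduce_seg1']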
@@ -320,11 +289,35 @@ def _allreduce_fusion_program(self):
                 })
             if not self.calc_comm_same_stream:
                 block._insert_op_without_sync(
-                    after_idx + offset,
+                    after_idx + 1,
                     type='c_sync_calc_stream',
                     inputs={'X': fused_var},
                     outputs={'Out': fused_var},
                     attrs={OP_ROLE_KEY: OpRole.Backward})
+
+        # update the outputs_name_to_idx after insertion of sync/allreduce ops
+        outputs_name_to_idx = self.__get_ouputs_name_to_idx(first_backward_idx,
+                                                            block)
+        # the before_idx is not guaranteed sorted, therefore we have to find the
+        # topology to insert the coalesce ops
+        pos_for_coalesce = {}
+        for i in range(len(grad_param_segments) - 1, -1, -1):
+            # We separate the insertion of coalesce op and the insertion of sync/allreduce op,
+            # since that the coalesce op's insertion may invalidate the outputs_name_to_idx
+            grad_segment, param_segment = grad_param_segments[i]
+            before_idx = len(block.ops)
+            for grad in outputs_name_to_idx:
+                before_idx = min(before_idx, outputs_name_to_idx[grad][0])
+            pos_for_coalesce[i] = before_idx
+
+        # insert the coalesce op based on the sorted before_idx
+        pos_for_coalesce = sorted(
+            pos_for_coalesce.items(),
+            key=lambda kv: (kv[1], kv[0]),
+            reverse=True)
+        for i, before_idx in pos_for_coalesce:
+            grad_segment, param_segment = grad_param_segments[i]
+            fused_var = fused_vars[i]
             block._insert_op_without_sync(
                 before_idx,
                 type="coalesce_tensor",
@@ -354,3 +347,25 @@ def _allreduce_fusion_program(self):
                            OP_ROLE_KEY: OpRole.Backward})
                 break
         block._sync_with_cpp()
+
+    def __get_ouputs_name_to_idx(self, first_backward_idx, block):
+        # Each item of outputs_name_to_idx is a pair of idx.
+        # The first entry of this pair is the idx of the first op generates the grad,
+        # which is used to indicate the position to insert coalesce op.
+        # The second entry of this pair is the idx of the last op generates the grad,
+        # which is used to indicate the position to insert sync and allreduce op.
+        outputs_name_to_idx = {}
+        for idx in range(first_backward_idx, len(block.ops)):
+            op = block.ops[idx]
+            if is_optimizer_op(op):
+                break
+            for name in op.output_arg_names:
+                var = block.var(name)
+                if not outputs_name_to_idx.get(var):
+                    # if the grad only be generated by one op
+                    # the first idx and the last ids are identical
+                    outputs_name_to_idx[var] = (idx, idx)
+                else:
+                    outputs_name_to_idx[var] = (outputs_name_to_idx[var][0],
+                                                idx)
+        return outputs_name_to_idx
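
The new private helper centralizes the map that both passes rely on, and its comments spell out the pair it stores per produced output. Below is a minimal standalone sketch of those (first_idx, last_idx) semantics with toy op outputs instead of Paddle block/op objects; the function name and gradient names are made up, and the toy keys the map by name where the real helper keys it by the block variable returned from block.var(name).

# Illustrative only: first/last producer index per output name,
# mirroring the pair described in the helper's comments.
def first_last_producer_idx(ops_outputs, start_idx=0):
    # ops_outputs: one list of output names per op
    name_to_idx = {}
    for idx in range(start_idx, len(ops_outputs)):
        for name in ops_outputs[idx]:
            if name not in name_to_idx:
                # produced by a single op so far: first and last idx coincide
                name_to_idx[name] = (idx, idx)
            else:
                # keep the first producer, update the last producer
                name_to_idx[name] = (name_to_idx[name][0], idx)
    return name_to_idx

# e.g. "x@GRAD" is written by ops 1 and 3, so it maps to (1, 3)
print(first_last_producer_idx([["y@GRAD"], ["x@GRAD"], ["z@GRAD"], ["x@GRAD"]]))
# {'y@GRAD': (0, 0), 'x@GRAD': (1, 3), 'z@GRAD': (2, 2)}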
