Skip to content

Commit 1b12a43

Browse files
committed
resolve the new logic problem, test=allcase
1 parent bbfc314 commit 1b12a43

File tree

1 file changed

+30
-13
lines changed

1 file changed

+30
-13
lines changed

python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -221,10 +221,11 @@ def _allreduce_fusion_program(self):
221221

222222
# find all grad params
223223
for idx, op in enumerate(block.ops):
224+
if first_backward_idx == -1 and \
225+
is_backward_op(op):
226+
first_backward_idx = idx
224227
if is_backward_op(op) and \
225228
OP_ROLE_VAR_KEY in op.attr_names:
226-
if first_backward_idx == -1:
227-
first_backward_idx = idx
228229
op_role_var = op.attr(OP_ROLE_VAR_KEY)
229230
if len(op_role_var) == 0:
230231
continue
@@ -239,14 +240,25 @@ def _allreduce_fusion_program(self):
239240
continue
240241
param_grads.append((param, grad))
241242

242-
# find the index of the op which generates the grad
243-
grads_to_idx = {}
244-
for param, grad in param_grads:
245-
for idx in range(first_backward_idx, len(block.ops)):
246-
op = block.ops[idx]
247-
if grad.name in op.output_arg_names:
248-
grads_to_idx[grad] = idx
249-
break
243+
# Each item of outputs_name_to_idx is a pair of idx
244+
# The first entry of this pair is the idx of the first op that generates the grad,
245+
# which is used to indicate the position to insert coalesce op
246+
# The second entry of this pair is the idx of the last op that generates the grad,
247+
# which is used to indicate the position to insert sync and allreduce op
248+
outputs_name_to_idx = {}
249+
for idx in range(first_backward_idx, len(block.ops)):
250+
op = block.ops[idx]
251+
if is_optimizer_op(op):
252+
break
253+
for name in op.output_arg_names:
254+
var = block.var(name)
255+
if not outputs_name_to_idx.get(var):
256+
# if the grad is generated by only one op
257+
# the first idx and the last idx are identical
258+
outputs_name_to_idx[var] = (idx, idx)
259+
else:
260+
outputs_name_to_idx[var] = (outputs_name_to_idx[var][0],
261+
idx)
250262

251263
# structure of grad_param_segments is
252264
# [([grad0, grad1], [param0, param1]), ([grad2, grad3], [param2, param3])]
@@ -279,8 +291,8 @@ def _allreduce_fusion_program(self):
279291
dtype=grad_segment[0].dtype,
280292
persistable=False,
281293
stop_gradient=True)
282-
before_idx = grads_to_idx[grad_segment[0]]
283-
after_idx = grads_to_idx[grad_segment[-1]]
294+
before_idx = outputs_name_to_idx[grad_segment[0]][0]
295+
after_idx = outputs_name_to_idx[grad_segment[-1]][1]
284296
offset = 1
285297
for j in range(i + 1, len(grad_param_segments)):
286298
# Find the offset of the sync op and allreduce op
@@ -289,7 +301,12 @@ def _allreduce_fusion_program(self):
289301
# the first grad in next segment are from the same op, it means
290302
# a coalesce op has already been inserted before this op.
291303
# Therefore, we have to insert the sync/allreduce op with offset.
292-
if after_idx == grads_to_idx[grad_param_segments[j][0][0]]:
304+
# The j is to get the ([grad0, grad1], [param0, param1]) tuple
305+
# The first 0 is to get [grad0, grad1] list
306+
# The second 0 is to get grad0 entry
307+
# The 1 is to get the idx of the last op that generates the grad
308+
if after_idx == outputs_name_to_idx[grad_param_segments[j][0][
309+
0]][1]:
293310
offset += 1
294311
block._insert_op_without_sync(
295312
after_idx + offset,

0 commit comments

Comments
 (0)