Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions python/paddle/fluid/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4663,6 +4663,7 @@ def _check_validation(self, block):
pre_stage_id = None
decrease_flag = False
in_optimize = False
in_forward = True
for op in block.ops:
if not op._has_kernel(op.type):
assert op.type == "conditional_block" and (
Expand All @@ -4680,6 +4681,8 @@ def _check_validation(self, block):
valid_op_role_value)
if int(op_role) == int(self._op_role.Optimize):
in_optimize = True
if int(op_role) == int(self._op_role.Backward):
in_forward = False

assert op.has_attr(self._op_device_key), (
"op ({}) has no {} attribute.".format(op.type,
Expand Down Expand Up @@ -4707,14 +4710,16 @@ def _check_validation(self, block):
"but the interval of op={} and prev op is {}".format(op, interval)
# stage must be in order, such as Forward(0 1 2 3 4), Backward(4 3 2 1 0)
# if stage is unordered, such as Forward(0 1 2 3 4 3 4), will report error
if interval == -1:
decrease_flag = True
if interval == 1:
# FIXME(wangxi): recompute failed
if in_forward:
assert interval >= 0, \
"Pipeline stage must be sequential increment in Forward, prev_stage={}, " \
"please check the stage of op={}".format(pre_stage_id, op)
else:
# FIXME(wangxi): recompute check failed
pass
#assert decrease_flag is False, \
# "Pipeline stage must be in order, " \
# "please check the stage of op={}".format(op)
#assert interval <=0, \
# "Pipeline stage must be sequential decrement in Backward, prev_stage={}, " \
# "please check the stage of op={}".format(pre_stage_id, op)
pre_stage_id = stage_id

return device_list
Expand Down