Skip to content

Commit 00404fb

Browse files
committed
fix recompute check
1 parent 101f2f9 commit 00404fb

File tree

1 file changed

+24
-19
lines changed

1 file changed

+24
-19
lines changed

python/paddle/fluid/optimizer.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4795,6 +4795,25 @@ def _insert_sendrecv_ops_for_boundaries(self, block):
47954795

47964796
device_type = cur_device.split(':')[0] + ':'
47974797

4798+
def _check_stage(cur_id, prev_id):
4799+
# check send/recv stage valid
4800+
is_forward = self._is_forward_op(op)
4801+
is_backward = self._is_backward_op(op)
4802+
assert is_forward or is_backward, \
4803+
'send/recv in pipeline should only be inserted in forward or backward,' \
4804+
'please check the op_role of op={}'.format(op)
4805+
4806+
if is_forward:
4807+
assert prev_id < cur_id, \
4808+
"In forward, send/recv can only be passed forward, but now " \
4809+
"prev_stage={} great than cur_stage={}, please check op_device of op={}".format(
4810+
prev_id, cur_id, op)
4811+
elif is_backward:
4812+
assert prev_id > cur_id, \
4813+
"In backward, send/recv can only be passed backward, but now " \
4814+
"prev_stage={} less than cur_stage={}, please check op_device of op={}".format(
4815+
prev_id, cur_id, op)
4816+
47984817
def _insert_send_recv(cur_id, prev_id):
47994818
cur_dev = device_type + str(cur_id)
48004819
prev_dev = device_type + str(prev_id)
@@ -4894,6 +4913,8 @@ def _insert_send_recv(cur_id, prev_id):
48944913
extra_index_info['index'] += 1
48954914
return
48964915

4916+
_check_stage(cur_id, prev_id)
4917+
48974918
block._insert_op_without_sync(
48984919
index=index + extra_index_info['index'],
48994920
type='c_sync_calc_stream',
@@ -4978,25 +4999,9 @@ def _insert_send_recv(cur_id, prev_id):
49784999
"Now only 'F-then-B' and '1F1B' are supported."
49795000
"The given value is {}.".format(self.schedule_mode))
49805001

4981-
cur_stage = int(cur_device.split(':')[1])
4982-
prev_stage = int(prev_device.split(':')[1])
4983-
4984-
is_forward = self._is_forward_op(op)
4985-
is_backward = self._is_backward_op(op)
4986-
assert is_forward or is_backward, \
4987-
'send/recv in pipeline should only be inserted in forward or backward,' \
4988-
'please check the op_role of op={}'.format(op)
4989-
4990-
if is_forward:
4991-
assert prev_stage < cur_stage, \
4992-
"In forward, send/recv can only be passed forward, but now " \
4993-
"prev_stage={} great than cur_stage={}, please check op_device of op={}".format(prev_stage, cur_stage, op)
4994-
elif is_backward:
4995-
assert prev_stage > cur_stage, \
4996-
"In backward, send/recv can only be passed backward, but now " \
4997-
"prev_stage={} less than cur_stage={}, please check op_device of op={}".format(prev_stage, cur_stage, op)
4998-
4999-
_insert_send_recv(cur_stage, prev_stage)
5002+
_insert_send_recv(
5003+
int(cur_device.split(':')[1]),
5004+
int(prev_device.split(':')[1]))
50005005
block._sync_with_cpp()
50015006

50025007
def _insert_loss_scale(self, block):

0 commit comments

Comments
 (0)