@@ -4795,6 +4795,25 @@ def _insert_sendrecv_ops_for_boundaries(self, block):
47954795
47964796 device_type = cur_device .split (':' )[0 ] + ':'
47974797
4798+ def _check_stage (cur_id , prev_id ):
4799+ # check send/recv stage valid
4800+ is_forward = self ._is_forward_op (op )
4801+ is_backward = self ._is_backward_op (op )
4802+ assert is_forward or is_backward , \
4803+ 'send/recv in pipeline should only be inserted in forward or backward,' \
4804+ 'please check the op_role of op={}' .format (op )
4805+
4806+ if is_forward :
4807+ assert prev_id < cur_id , \
4808+ "In forward, send/recv can only be passed forward, but now " \
4809+ "prev_stage={} great than cur_stage={}, please check op_device of op={}" .format (
4810+ prev_id , cur_id , op )
4811+ elif is_backward :
4812+ assert prev_id > cur_id , \
4813+ "In backward, send/recv can only be passed backward, but now " \
4814+ "prev_stage={} less than cur_stage={}, please check op_device of op={}" .format (
4815+ prev_id , cur_id , op )
4816+
47984817 def _insert_send_recv (cur_id , prev_id ):
47994818 cur_dev = device_type + str (cur_id )
48004819 prev_dev = device_type + str (prev_id )
@@ -4894,6 +4913,8 @@ def _insert_send_recv(cur_id, prev_id):
48944913 extra_index_info ['index' ] += 1
48954914 return
48964915
4916+ _check_stage (cur_id , prev_id )
4917+
48974918 block ._insert_op_without_sync (
48984919 index = index + extra_index_info ['index' ],
48994920 type = 'c_sync_calc_stream' ,
@@ -4978,25 +4999,9 @@ def _insert_send_recv(cur_id, prev_id):
49784999 "Now only 'F-then-B' and '1F1B' are supported."
49795000 "The given value is {}." .format (self .schedule_mode ))
49805001
4981- cur_stage = int (cur_device .split (':' )[1 ])
4982- prev_stage = int (prev_device .split (':' )[1 ])
4983-
4984- is_forward = self ._is_forward_op (op )
4985- is_backward = self ._is_backward_op (op )
4986- assert is_forward or is_backward , \
4987- 'send/recv in pipeline should only be inserted in forward or backward,' \
4988- 'please check the op_role of op={}' .format (op )
4989-
4990- if is_forward :
4991- assert prev_stage < cur_stage , \
4992- "In forward, send/recv can only be passed forward, but now " \
4993- "prev_stage={} great than cur_stage={}, please check op_device of op={}" .format (prev_stage , cur_stage , op )
4994- elif is_backward :
4995- assert prev_stage > cur_stage , \
4996- "In backward, send/recv can only be passed backward, but now " \
4997- "prev_stage={} less than cur_stage={}, please check op_device of op={}" .format (prev_stage , cur_stage , op )
4998-
4999- _insert_send_recv (cur_stage , prev_stage )
5002+ _insert_send_recv (
5003+ int (cur_device .split (':' )[1 ]),
5004+ int (prev_device .split (':' )[1 ]))
50005005 block ._sync_with_cpp ()
50015006
50025007 def _insert_loss_scale (self , block ):
0 commit comments