@@ -4221,6 +4221,8 @@ def __init__(self, optimizer, num_microbatches=1, start_cpu_core_id=0):
         self._param_device_map = None
         self._pipeline_pair = []
         self._pp_ring_map = dict()
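+        # Filled in by minimize(), which stores the result of
+        # _get_input_output_info() here.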
+        self.output_var_to_op = None
+        self.input_var_to_op = None
 
         # insert allreduce op to sync global information for global
         # gradient clip and amp
@@ -4657,6 +4659,9 @@ def _check_validation(self, block):
             int(self._op_role.Optimize),
             int(self._op_role.Backward) | int(self._op_role.Loss),
         ]
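+        # Track the stage id of the previous non-optimizer op so that the
+        # pipeline stage ordering can be validated in the loop below.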
+        pre_stage_id = None
+        decrease_flag = False
+        in_optimize = False
         for op in block.ops:
             if not op._has_kernel(op.type):
                 assert op.type == "conditional_block" and (
@@ -4666,25 +4671,49 @@ def _check_validation(self, block):
             assert op.has_attr(self._op_role_key), (
                 "op ({}) has no {} attribute.".format(op.type,
                                                       self._op_role_key))
-            assert int(op.attr(self._op_role_key)) in valid_op_role_value, \
+            op_role = op.attr(self._op_role_key)
+            assert int(op_role) in valid_op_role_value, \
                 "op_role {} for op {} must be one of {}".format(
-                    op.attr(self._op_role_key),
+                    op_role,
                     op.type,
                     valid_op_role_value)
+            if int(op_role) == int(self._op_role.Optimize):
+                in_optimize = True
+
             assert op.has_attr(self._op_device_key), (
                 "op ({}) has no {} attribute.".format(op.type,
                                                       self._op_device_key))
 
             device = op.attr(self._op_device_key)
             assert device, ("op_device attribute for op "
                             "{} has not been set.".format(op.type))
-            if device == "gpu:all": continue
+            if device == "gpu:all" or device == "npu:all": continue
+
             dev_type = device.split(':')[0]
+            stage_id = int(device.split(':')[1])
             assert dev_type == "gpu" or dev_type == 'npu', (
                 "Now only gpu and npu devices are supported "
                 "for pipeline parallelism.")
-            if not device in device_list:
+
+            if device not in device_list:
                 device_list.append(device)
+
+            if not in_optimize:
+                if pre_stage_id is not None:
+                    interval = stage_id - pre_stage_id
+                    assert abs(interval) <= 1, \
+                        "The stage interval of two consecutive ops in the pipeline must be <= 1, " \
+                        "but the interval of op={} and prev op is {}".format(op, interval)
+                    # stages must be in order, such as Forward(0 1 2 3 4), Backward(4 3 2 1 0);
+                    # if the stages are unordered, such as Forward(0 1 2 3 4 3 4), report an error
+                    if interval == -1:
+                        decrease_flag = True
+                    if interval == 1:
+                        assert decrease_flag is False, \
+                            "Pipeline stage must be in order, " \
+                            "please check the stage of op={}".format(op)
+                pre_stage_id = stage_id
+
         return device_list
 
     def _insert_sendrecv_ops_for_boundaries(self, block):
@@ -4826,14 +4855,16 @@ def _insert_send_recv(cur_id, prev_id):
                     })
                 extra_index_info['index'] += 1
                 insert_index = None
+
                 if int(op_role) == int(self._op_role.Backward):
                     insert_index = extra_index_info[
                         'first_optimize_index']
                     new_op_role = self._op_role.Optimize
                 else:
                     insert_index = index
                     new_op_role = self._op_role.Backward
-                block._insert_op_without_sync(
+
+                sync_comm_op = block._insert_op_without_sync(
                     index=insert_index + extra_index_info['index'],
                     type='c_sync_comm_stream',
                     inputs={'X': [var]},
@@ -4843,8 +4874,11 @@ def _insert_send_recv(cur_id, prev_id):
                         self._op_role_key: new_op_role,
                         'ring_id': ring_id,
                     })
+
                 if int(op_role) == int(self._op_role.Forward):
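+                    # Mark this forward-send sync op so that
+                    # _optimize_forward_send_sync can find and remove it later.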
+                    sync_comm_op._set_attr('pipeline_flag', '')
                     extra_index_info['index'] += 1
+
                 var_shape = list(var.shape)
                 var_shape[0] = self.micro_batch_size if var_shape[
                     0] < 0 else var_shape[0]
@@ -5153,17 +5187,55 @@ def _get_input_output_info(self, block):
         Get info of op input and output.
         '''
         # A map from output var to op which generate it.
-        self.output_var_to_op = dict()
+        output_var_to_op = defaultdict(list)
         # A map from var to op which takes it as input.
-        self.input_var_to_op = dict()
+        input_var_to_op = defaultdict(list)
 
-        for index, op in enumerate(list(block.ops)):
+        for index, op in enumerate(block.ops):
             for var_name in op.input_arg_names:
-                ops = self.input_var_to_op.setdefault(var_name, [])
-                ops.append([op, index])
+                input_var_to_op[var_name].append([op, index])
             for var_name in op.output_arg_names:
-                ops = self.output_var_to_op.setdefault(var_name, [])
-                ops.append([op, index])
+                output_var_to_op[var_name].append([op, index])
+
+        return output_var_to_op, input_var_to_op
+
+    def _optimize_forward_send_sync(self, program):
+        """
+        optimize forward send's sync_comm_stream schedule
+        """
+        if self.schedule_mode != '1F1B': return
+
+        block = program.block(0)
+
+        backward_recv_index = None
+        for index, op in enumerate(block.ops):
+            if op.type == 'recv_v2' and self._is_backward_op(op):
+                backward_recv_index = index
+                break
+
+        if backward_recv_index is None: return
+
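+        # 'offset' keeps op indices valid while sync ops are removed below.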
+        offset = 0
+        for index, op in enumerate(list(block.ops)):
+            if index >= backward_recv_index: break
+            if op.type == 'c_sync_comm_stream' and op.has_attr('pipeline_flag'):
+                var_name = op.input_arg_names[0]
+                var = block.var(var_name)
+                block._remove_op(index + offset, sync=False)
+                offset -= 1
+                # NOTE:
+                # 1. When the backward recv is completed, it indicates
+                # that the forward send is completed too. So we only need
+                # to use the NOP op to prevent memory release.
+                # 2. Because we removed sync_comm_op,
+                # we will insert NOP after recv_op.
+                block._insert_op_without_sync(
+                    index=backward_recv_index,
+                    type='nop',
+                    inputs={'X': [var]},
+                    outputs={'Out': [var]},
+                    attrs={self._op_role_key: self._op_role.Backward})
+        block._sync_with_cpp()
 
     def minimize(self,
                  loss,
@@ -5200,7 +5272,8 @@ def minimize(self,
             loss, startup_program, parameter_list, no_grad_set)
         self._param_device_map = self._origin_optimizer._param_device_map
 
-        self._get_input_output_info(main_block)
+        self.output_var_to_op, self.input_var_to_op = \
+            self._get_input_output_info(main_block)
         # Step1: add default op_device attribute for ops.
         self._add_op_device_attr(main_block)
         device_list = self._check_validation(main_block)
@@ -5229,6 +5302,10 @@ def device_cmp(device1, device2):
         for p in program_list:
             self._create_vars(p.global_block(), main_block)
 
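+        # Map local_rank onto a pipeline stage id before selecting this rank's
+        # program from program_list.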
+        self.local_rank %= len(device_list)
+        # Step3.5: optimize forward send sync_comm to overlap send and recv
+        self._optimize_forward_send_sync(program_list[self.local_rank])
+
         # Step4: Special Case: process persistable vars that exist in
         # multiple sections
         # FIXME
@@ -5238,7 +5315,6 @@ def device_cmp(device1, device2):
         # Step5: Add sub blocks for section programs
         self._add_sub_blocks(main_block, program_list)
 
-        self.local_rank %= len(device_list)
         place_list = []
         for dev in device_list:
             dev_index = int(dev.split(":")[1])