@@ -142,32 +142,103 @@ def prune_gradient_clip(self, block, shard, ring_ids):
         return
 
     # TODO (JZ-LIANG) revise this for uniform mixed parallelism
-    def sync_global_norm(self, block, ring_ids):
+    def sync_global_norm(self, block, ring_ids, mp_rank):
         """
         prune gradient_clip related ops for params that not belong to cur shard
         prune: square, reduce_sum, elementwise_mul
         keep: sum, sqrt, elementwise_max, elementwise_div
         """
-        # FIXME(wangxi): mp should prune duplicated param_grads
+        is_clip_grad_by_global_norm = False
+        for idx, op in list(enumerate(block.ops)):
+            if not self._is_gradient_clip_op(op):
+                continue
+            if op.type == 'sum':
+                is_clip_grad_by_global_norm = True
+                break
+        if not is_clip_grad_by_global_norm:
+            # TODO(Yuang Liu): need some extra handling when clip_grad_norm is used with mp
+            return
+
+        removed_op_idx = set()
+        removed_tmp_var = set()
+        for idx, op in list(enumerate(block.ops)):
+            if not self._is_gradient_clip_op(op):
+                continue
+            if op.type == 'sum':
+                break
+            for input_name in op.input_arg_names:
+                input_var = block.var(input_name)
+                # NOTE: when mp_degree > 1, some vars are split across the mp ranks,
+                # but vars such as Scale and Bias are not. Those unsplit vars should
+                # only be counted once when clipping grads by global norm; they
+                # either do not have the is_distributed attr or have is_distributed
+                # set to False.
+                # Therefore, we prune those duplicated vars for grad clip.
+                if mp_rank >= 1 and (not (hasattr(input_var, 'is_distributed')
+                                          and input_var.is_distributed)):
+                    removed_op_idx.add(idx)
+                    for output_name in op.output_arg_names:
+                        removed_tmp_var.add(output_name)
+
         for idx, op in reversed(list(enumerate(block.ops))):
             if not self._is_gradient_clip_op(op):
                 continue
+            if idx in removed_op_idx:
+                block._remove_op(idx, sync=False)
 
-            if op.type == "sum":
-                sum_res = op.desc.output_arg_names()[0]
-                for ring_id in ring_ids:
-                    if ring_id == -1: continue
+        for var_name in removed_tmp_var:
+            block._remove_var(var_name, sync=False)
 
-                    idx = idx + 1
-                    block._insert_op_without_sync(
-                        idx,
-                        type='c_allreduce_sum',
-                        inputs={'X': sum_res},
-                        outputs={'Out': sum_res},
-                        attrs={
-                            'ring_id': ring_id,
-                            'op_namescope': "/gradient_clip_model_parallelism",
-                            'use_calc_stream': True,
-                            OP_ROLE_KEY: OpRole.Optimize,
-                        })
-                return
+        for idx, op in list(enumerate(block.ops)):
+            if not self._is_gradient_clip_op(op):
+                continue
+            if op.type == 'sum':
+                # If mp_rank == 0, no extra handling is needed, just allreduce.
+                # If mp_rank >= 1, some extra handling is needed.
+                sum_rst_var = block.var(op.output_arg_names[0])
+                if mp_rank >= 1:
+                    reserved_vars = []
+                    for input_name in op.input_arg_names:
+                        if input_name not in removed_tmp_var:
+                            reserved_vars.append(input_name)
+
+                    if len(reserved_vars) > 0:
+                        op.desc.set_input("X", reserved_vars)
+                    else:
+                        # If all inputs of the sum op should be removed, remove the
+                        # sum op itself and set its output value to 0.
+                        namescope = op.attr("op_namescope")
+                        block._remove_op(idx, sync=False)
+                        fill_constant_op = block._insert_op_without_sync(
+                            idx,
+                            type='fill_constant',
+                            inputs={},
+                            outputs={'Out': sum_rst_var},
+                            attrs={
+                                'shape': sum_rst_var.shape,
+                                'dtype': sum_rst_var.dtype,
+                                'value': 0.0,
+                                OP_ROLE_KEY: OpRole.Optimize
+                            })
+                        fill_constant_op._set_attr('op_namescope', namescope)
+                self._insert_allreduce(block, ring_ids, idx, sum_rst_var)
+                break
+
+    @staticmethod
+    def _insert_allreduce(block, ring_ids, idx, var):
+        for ring_id in ring_ids:
+            if ring_id == -1:
+                continue
+
+            idx = idx + 1
+            block._insert_op_without_sync(
+                idx,
+                type='c_allreduce_sum',
+                inputs={'X': var},
+                outputs={'Out': var},
+                attrs={
+                    'ring_id': ring_id,
+                    'op_namescope': "/gradient_clip_model_parallelism",
+                    'use_calc_stream': True,
+                    OP_ROLE_KEY: OpRole.Optimize,
+                })
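
For reviewers, the overall effect of the pass: on mp_rank >= 1, the square/reduce_sum ops feeding duplicated (non-distributed) vars such as Scale and Bias are pruned so those vars are counted exactly once, and the surviving per-rank sum of squares is then allreduced over the given rings before the sqrt. Below is a minimal, self-contained sketch of that accounting in plain Python; all names are hypothetical and there is no Paddle dependency, so treat it as an illustration of the idea rather than the pass itself.

import math


def local_sum_of_squares(grads, mp_rank):
    """grads: list of (values, is_distributed) pairs held by this mp rank."""
    total = 0.0
    for values, is_distributed in grads:
        # Duplicated (non-distributed) vars are kept only on mp rank 0,
        # mirroring the removed_op_idx / removed_tmp_var pruning above.
        if mp_rank >= 1 and not is_distributed:
            continue
        total += sum(v * v for v in values)
    return total


def global_norm(per_rank_grads):
    # Stand-in for the c_allreduce_sum over the mp/sharding rings.
    total = sum(
        local_sum_of_squares(grads, rank)
        for rank, grads in enumerate(per_rank_grads))
    return math.sqrt(total)


# Two mp ranks: each holds its split of a distributed grad plus the same
# replicated bias grad; the bias contributes to the global norm only once.
per_rank = [
    [([1.0, 2.0], True), ([0.5], False)],  # mp rank 0
    [([3.0], True), ([0.5], False)],       # mp rank 1, replicated bias pruned
]
print(global_norm(per_rank))  # sqrt(1 + 4 + 9 + 0.25) ~= 3.775

Without the mp_rank >= 1 pruning, the replicated bias would be summed mp_degree times and the computed global norm would be inflated, which is what the removed FIXME ("mp should prune duplicated param_grads") pointed at.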
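One corner case the new code handles explicitly: if every input of a rank's sum op was pruned (that rank holds only duplicated, already-counted vars), the sum op is removed and its output is filled with 0.0 via fill_constant. Since c_allreduce_sum is a collective, every rank in the ring still has to contribute a tensor. A tiny sketch of that contract, with a plain list of per-rank contributions standing in for the ring (hypothetical names again):

def rank_contribution(local_sq_sum):
    # A fully pruned sum is replaced by an explicit 0.0 (the fill_constant op),
    # so the collective sees a well-defined value on every rank.
    return 0.0 if local_sq_sum is None else local_sq_sum


# After pruning, rank 1 owns nothing that contributes to the global norm.
contributions = [14.0, None, 0.25]
global_sq_norm = sum(rank_contribution(c) for c in contributions)  # 14.25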