@@ -162,6 +162,7 @@ def __init__(self,
         self._params_name = set()
         self._apply_decay_param_fun = apply_decay_param_fun
         self._coeff = coeff
+        self._lr_to_coeff = dict()
 
         super(AdamW, self).__init__(
             learning_rate=learning_rate,
@@ -177,6 +178,9 @@ def __init__(self,
 
         self.type = "adamw"
 
+        if core.is_compiled_with_xpu():
+            self.type = "adam"
+
         # Use _auxiliary_vars together with _set_auxiliary_var/_get_auxiliary_var to achieve that.
         self._auxiliary_vars = dict()
 
@@ -189,7 +193,63 @@ def _get_auxiliary_var(self, key):
         else:
             return None
 
+    def _append_decoupled_weight_decay(self, block, param_and_grad):
+        """
+        Add the decoupled weight decay op:
+            parameter = parameter - parameter * coeff * lr
+        Args:
+            block: block in which the variable is to be created.
+            param_and_grad: (parameter, gradient) pair of the
+                parameter to decay.
+        Raises:
+            Exception: The types of coeff and parameter are not consistent.
+        """
+        if isinstance(param_and_grad, dict):
+            param_and_grad = self._update_param_group(param_and_grad)
+        param, grad = param_and_grad
+
+        if self._apply_decay_param_fun is not None \
+                and not self._apply_decay_param_fun(param.name):
+            return
+
+        if isinstance(self._learning_rate, float):
+            learning_rate = self._learning_rate
+        else:
+            # NOTE: this is invoked from _append_optimize_op() because
+            # _create_param_lr() must be called after
+            # optimizer._create_global_learning_rate().
+            learning_rate = self._create_param_lr(param_and_grad)
+
+        with block.program._optimized_guard(
+                [param, grad]), framework.name_scope('weight decay'):
+            self._params_name.add(param.name)
+
+            # If the coefficient has already been computed, reuse it.
+            # NOTE(wangxi): In dygraph mode, apply_gradient is executed every
+            # step, so _lr_to_coeff must be cleared every step; this is done
+            # in _create_optimization_pass.
+            decay_coeff = self._lr_to_coeff.get(learning_rate, None)
+            if decay_coeff is None:
+                # NOTE(wangxi): use device:all so this also works under pipeline parallelism
+                with paddle.static.device_guard(None):
+                    decay_coeff = 1.0 - learning_rate * self._coeff
+                self._lr_to_coeff[learning_rate] = decay_coeff
+
+            find_master = (self._multi_precision and
+                           param.dtype == core.VarDesc.VarType.FP16)
+            if find_master:
+                master_weight = self._master_weights[param.name]
+                scaled_param = master_weight * decay_coeff
+                paddle.fluid.layers.assign(
+                    input=scaled_param, output=master_weight)
+            else:
+                scaled_param = param * decay_coeff
+                paddle.fluid.layers.assign(input=scaled_param, output=param)
+
     def _append_optimize_op(self, block, param_and_grad):
+        if paddle.is_compiled_with_xpu():
+            self._append_decoupled_weight_decay(block, param_and_grad)
+            return super(AdamW, self)._append_optimize_op(block, param_and_grad)
 
         assert isinstance(block, framework.Block)
         if isinstance(param_and_grad, dict):
@@ -201,8 +261,6 @@ def _append_optimize_op(self, block, param_and_grad):
         if self._apply_decay_param_fun is not None \
                 and not self._apply_decay_param_fun(param.name):
             with_decay = False
-        else:
-            self._params_name.add(param.name)
 
         moment1 = self._get_accumulator(self._moment1_acc_str,
                                         param_and_grad[0])
@@ -291,6 +349,13 @@ def _append_optimize_op(self, block, param_and_grad):
 
         return adamw_op
 
+    def _create_optimization_pass(self, parameters_and_grads):
+        optimize_ops = super(
+            AdamW, self)._create_optimization_pass(parameters_and_grads)
+        # In dygraph mode, clear _lr_to_coeff after the gradients have been applied.
+        self._lr_to_coeff = dict()
+        return optimize_ops
+
     def __str__(self):
         return " ".join(["Weight Decay, params:", ",".join(self._params_name)])
 
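
The decoupled update added in _append_decoupled_weight_decay relies on the identity parameter - parameter * coeff * lr == parameter * (1.0 - lr * coeff), so the scaling factor depends only on the (lr, coeff) pair and can be cached. Below is a minimal sketch in plain Python (not Paddle code; the function and variable names are illustrative) of that caching scheme mirrored by self._lr_to_coeff.

def decoupled_weight_decay(param, lr, coeff, lr_to_coeff):
    """Return the decayed parameter, reusing a cached (1 - lr * coeff)."""
    decay_coeff = lr_to_coeff.get(lr)
    if decay_coeff is None:
        decay_coeff = 1.0 - lr * coeff
        lr_to_coeff[lr] = decay_coeff  # reused until the cache is cleared
    return param * decay_coeff

cache = {}
print(decoupled_weight_decay(2.0, lr=0.1, coeff=0.01, lr_to_coeff=cache))  # 1.998
# The cache is cleared after every step (cf. _create_optimization_pass),
# because in dygraph mode the learning-rate value can change between steps.
cache.clear()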
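
For context, a hedged usage sketch (not part of this commit) of how apply_decay_param_fun, the hook checked by _append_decoupled_weight_decay before decaying a parameter, is typically supplied. Filtering out parameters whose structured name contains "bias" is an illustrative assumption about the model, not something the commit prescribes.

import paddle

model = paddle.nn.Linear(10, 10)

# Decay every parameter except those whose structured name contains "bias".
decay_params = [
    p.name for n, p in model.named_parameters() if "bias" not in n
]

opt = paddle.optimizer.AdamW(
    learning_rate=0.001,
    parameters=model.parameters(),
    weight_decay=0.01,  # stored as self._coeff
    apply_decay_param_fun=lambda name: name in decay_params,
)

x = paddle.rand([4, 10])
loss = paddle.mean(model(x))
loss.backward()
opt.step()        # on an XPU build this takes the decoupled decay path added above
opt.clear_grad()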