
Commit 884011a

Revert XPU AdamW to the combination-of-ops version. (#35286)
1 parent 572bad8 commit 884011a

File tree: 1 file changed (+67 -2)

python/paddle/optimizer/adamw.py

Lines changed: 67 additions & 2 deletions
@@ -162,6 +162,7 @@ def __init__(self,
         self._params_name = set()
         self._apply_decay_param_fun = apply_decay_param_fun
         self._coeff = coeff
+        self._lr_to_coeff = dict()

         super(AdamW, self).__init__(
             learning_rate=learning_rate,
@@ -177,6 +178,9 @@ def __init__(self,

         self.type = "adamw"

+        if core.is_compiled_with_xpu():
+            self.type = "adam"
+
         # Use _auxiliary_vars together with _set_auxiliary_var/_get_auxiliary_var to achieve that.
         self._auxiliary_vars = dict()

@@ -189,7 +193,63 @@ def _get_auxiliary_var(self, key):
         else:
             return None

+    def _append_decoupled_weight_decay(self, block, param_and_grad):
+        """
+        Add decoupled weight decay op.
+            parameter = parameter - parameter * coeff * lr
+        Args:
+            block: block in which variable is to be created
+            param_and_grad: (parameters, gradients) pairs,
+                the parameters need to decay.
+        Raises:
+            Exception: The type of coeff and parameter is not consistent.
+        """
+        if isinstance(param_and_grad, dict):
+            param_and_grad = self._update_param_group(param_and_grad)
+        param, grad = param_and_grad
+
+        if self._apply_decay_param_fun is not None \
+                and not self._apply_decay_param_fun(param.name):
+            return
+
+        if isinstance(self._learning_rate, float):
+            learning_rate = self._learning_rate
+        else:
+            # NOTE. We add this function to the _append_optimize_op(),
+            # for we must make sure _create_param_lr() be called after
+            # optimizer._create_global_learning_rate().
+            learning_rate = self._create_param_lr(param_and_grad)
+
+        with block.program._optimized_guard(
+                [param, grad]), framework.name_scope('weight decay'):
+            self._params_name.add(param.name)
+
+            # If it has been calculated, the result will be reused.
+            # NOTE(wangxi): In dygraph mode, apply_gradient will be executed
+            # every step, so need clear _lr_to_coeff every step,
+            # we do this in _create_optimization_pass
+            decay_coeff = self._lr_to_coeff.get(learning_rate, None)
+            if decay_coeff is None:
+                # NOTE(wangxi): for pipeline to set device:all
+                with paddle.static.device_guard(None):
+                    decay_coeff = 1.0 - learning_rate * self._coeff
+                self._lr_to_coeff[learning_rate] = decay_coeff
+
+            find_master = (self._multi_precision and
+                           param.dtype == core.VarDesc.VarType.FP16)
+            if find_master:
+                master_weight = self._master_weights[param.name]
+                scaled_param = master_weight * decay_coeff
+                paddle.fluid.layers.assign(
+                    input=scaled_param, output=master_weight)
+            else:
+                scaled_param = param * decay_coeff
+                paddle.fluid.layers.assign(input=scaled_param, output=param)
+
     def _append_optimize_op(self, block, param_and_grad):
+        if paddle.is_compiled_with_xpu():
+            self._append_decoupled_weight_decay(block, param_and_grad)
+            return super(AdamW, self)._append_optimize_op(block, param_and_grad)

         assert isinstance(block, framework.Block)
         if isinstance(param_and_grad, dict):
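
Note: the hunk above restores AdamW on XPU as a combination of ops. `_append_decoupled_weight_decay` first scales each decayed parameter by the cached factor `decay_coeff = 1 - lr * coeff`, and the plain "adam" op then performs the usual moment update; on non-XPU builds `self.type` stays "adamw" and the fused op handles both steps. Below is a minimal standalone sketch of that two-step update written with plain Python floats instead of Paddle tensors; the function name and signature are illustrative only, not part of Paddle.

# Illustrative sketch of the update order produced by the XPU path above:
# decoupled weight decay first, then a standard Adam step.
# Scalar floats stand in for tensors; this is not Paddle API.
def xpu_adamw_step(param, grad, m, v, t, lr, coeff,
                   beta1=0.9, beta2=0.999, eps=1e-8):
    # Step 1: decoupled weight decay, i.e. parameter -= parameter * coeff * lr,
    # applied as multiplication by the cached decay_coeff = 1 - lr * coeff.
    param = param * (1.0 - lr * coeff)

    # Step 2: the "adam" op's bias-corrected moment update.
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad * grad
    m_hat = m / (1.0 - beta1 ** t)
    v_hat = v / (1.0 - beta2 ** t)
    param = param - lr * m_hat / (v_hat ** 0.5 + eps)
    return param, m, v
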
@@ -201,8 +261,6 @@ def _append_optimize_op(self, block, param_and_grad):
         if self._apply_decay_param_fun is not None \
                 and not self._apply_decay_param_fun(param.name):
             with_decay = False
-        else:
-            self._params_name.add(param.name)

         moment1 = self._get_accumulator(self._moment1_acc_str,
                                         param_and_grad[0])
@@ -291,6 +349,13 @@ def _append_optimize_op(self, block, param_and_grad):

         return adamw_op

+    def _create_optimization_pass(self, parameters_and_grads):
+        optimize_ops = super(
+            AdamW, self)._create_optimization_pass(parameters_and_grads)
+        # In dygraph mode, clear _lr_to_coeff after applied gradient
+        self._lr_to_coeff = dict()
+        return optimize_ops
+
     def __str__(self):
         return " ".join(["Weight Decay, params:", ",".join(self._params_name)])

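Usage is unchanged by this commit. A minimal dygraph sketch follows, assuming the Paddle 2.x API; the model, hyperparameters, and the name-based decay filter are illustrative. On an XPU build the decayed parameters pass through the decoupled-decay ops added above before the "adam" op runs, while other builds keep the fused "adamw" op.

import paddle

# Tiny model; with Paddle's default naming, bias parameter names contain ".b_".
linear = paddle.nn.Linear(10, 2)
opt = paddle.optimizer.AdamW(
    learning_rate=0.01,
    parameters=linear.parameters(),
    weight_decay=0.01,
    # Apply weight decay only to parameters this callback approves of.
    apply_decay_param_fun=lambda name: ".b_" not in name)

x = paddle.randn([4, 10])
loss = paddle.mean(linear(x))
loss.backward()
opt.step()        # decoupled decay + "adam" op on XPU, fused "adamw" op elsewhere
opt.clear_grad()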