
Commit f16c7e9

fix cpu device, test=allcase
1 parent c5a067e commit f16c7e9

File tree

2 files changed (+10 -7 lines)

python/paddle/fluid/optimizer.py

Lines changed: 8 additions & 0 deletions
@@ -4634,6 +4634,9 @@ def _add_op_device_attr_for_op(self, op, idx, block):
               op.type == 'elementwise_div'):
             device = f"{self._device}:all"
             op._set_attr(self._op_device_key, device)
+        elif self._is_weight_decay_op(op) and op.type == 'scale':
+            # set AdamW decay_coeff to device:all
+            op._set_attr(self._op_device_key, f"{self._device}:all")
         elif op.type == "alloc_float_status":
             op._set_attr(self._op_device_key, f"{self._device}:all")
         else:
@@ -5267,6 +5270,11 @@ def _is_regularization_op(self, op):
         return op.desc.has_attr("op_namescope") \
             and op.desc.attr("op_namescope").startswith("/regularization")
 
+    def _is_weight_decay_op(self, op):
+        # in AdamW namescope is /optimizer_*/weight decay/
+        return op.desc.has_attr("op_namescope") \
+            and 'weight decay' in op.desc.attr("op_namescope")
+
     def _get_input_output_info(self, block):
         '''
         Get info of op input and output.
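
For orientation, a minimal sketch (not part of this commit) of how the two additions above work together: _is_weight_decay_op finds ops created under AdamW's "weight decay" namescope, and the matching 'scale' op is tagged to run on every pipeline stage. The standalone function, the mock op interface, and the "op_device" attribute key below are assumptions for illustration.

def tag_adamw_decay_scale_op(op, device, op_device_key="op_device"):
    # AdamW computes decay_coeff = 1 - lr * coeff under a ".../weight decay/"
    # namescope, which shows up in the program as a 'scale' op.
    in_weight_decay_scope = (op.desc.has_attr("op_namescope")
                             and 'weight decay' in op.desc.attr("op_namescope"))
    if in_weight_decay_scope and op.type == 'scale':
        # weight decay is applied on every pipeline stage, so the op must be
        # placed on all devices rather than pinned to a single stage
        op._set_attr(op_device_key, f"{device}:all")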

python/paddle/optimizer/adamw.py

Lines changed: 2 additions & 7 deletions
@@ -161,12 +161,6 @@ def __init__(self,
         self._coeff = coeff
         self._lr_to_coeff = dict()
 
-        self._device = "cpu"
-        if core.is_compiled_with_npu():
-            self._device = "npu"
-        elif core.is_compiled_with_cuda():
-            self._device = "gpu"
-
         super(AdamW, self).__init__(
             learning_rate=learning_rate,
             parameters=parameters,
parameters=parameters,
@@ -218,7 +212,8 @@ def _append_decoupled_weight_decay(self, block, param_and_grad):
             # we do this in _create_optimization_pass
             decay_coeff = self._lr_to_coeff.get(learning_rate, None)
             if decay_coeff is None:
-                with paddle.static.device_guard("{}:all".format(self._device)):
+                # NOTE(wangxi): for pipeline to set device:all
+                with paddle.static.device_guard(None):
                     decay_coeff = 1.0 - learning_rate * self._coeff
                     self._lr_to_coeff[learning_rate] = decay_coeff
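
The adamw.py change relies on paddle.static.device_guard accepting None, which leaves the op's device attribute unset so a later pass (here, the pipeline optimizer) can decide placement. Below is a hedged usage sketch under static graph mode; the program setup and the learning-rate variable name are made up for illustration.

import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    # hypothetical learning-rate variable, mirroring the decay_coeff line above
    lr = paddle.static.create_global_var(
        shape=[1], value=0.001, dtype='float32', name='learning_rate_0')
    coeff = 0.01
    with paddle.static.device_guard(None):
        # no device is pinned here, so the pipeline pass is free to tag the
        # resulting 'scale' op with "<device>:all" later on
        decay_coeff = 1.0 - lr * coeff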
