Commit c7fa29e

improve append_optimizer_op
1 parent: 688df62 · commit: c7fa29e

2 files changed: +35 −23 lines


python/paddle/fluid/tests/unittests/test_momentum_op.py

Lines changed: 24 additions & 15 deletions
@@ -614,14 +614,14 @@ def test_momentum_static(self):
 
 
 class TestFusedMomentumWithDecayAPI(unittest.TestCase):
-    def get_program(self, weight_attr):
+    def get_program(self, weight_attr, bias_attr=False):
         main_program = paddle.static.Program()
         startup_program = paddle.static.Program()
         with paddle.static.program_guard(
                 main_program=main_program, startup_program=startup_program):
             x = paddle.static.data(name='x', shape=[10, 10])
             linear = paddle.nn.Linear(
-                10, 10, weight_attr=weight_attr, bias_attr=False)
+                10, 10, weight_attr=weight_attr, bias_attr=bias_attr)
             out = linear(x)
             loss = paddle.mean(out)
             optimizer = paddle.optimizer.Momentum(
@@ -637,7 +637,7 @@ def test_param_has_l2decay(self):
             name="weight",
             initializer=paddle.nn.initializer.Constant(value=0.5),
             regularizer=paddle.regularizer.L2Decay(0.1))
-        program = self.get_program(weight_attr)
+        program = self.get_program(weight_attr, bias_attr=False)
         ops = program.global_block().ops
 
         self.assertEqual(ops[-1].attr('regularization_method'), 'l2_decay')
@@ -652,21 +652,30 @@ def test_param_has_l1decay(self):
             name="weight",
             initializer=paddle.nn.initializer.Constant(value=0.5),
             regularizer=paddle.regularizer.L1Decay(0.1))
-        program = self.get_program(weight_attr)
+        bias_attr = paddle.ParamAttr(
+            name="bias",
+            initializer=paddle.nn.initializer.Constant(value=0.),
+            regularizer=None)
+        program = self.get_program(weight_attr, bias_attr)
         ops = program.global_block().ops
-        self.assertEqual(ops[-1].attr('regularization_method'), '')
-        self.assertEqual(ops[-1].attr('regularization_coeff'), 0)
-        self.assertEqual(ops[-2].type, 'sum')
-        self.assertEqual(ops[-3].type, 'scale')
-        self.assertEqual(ops[-4].type, 'sign')
 
-    def test_param_regularizer_is_none(self):
+        self.assertEqual(ops[-1].type, 'momentum')
+        self.assertEqual(ops[-2].type, 'momentum')
+        self.assertEqual(ops[-3].type, 'sum')
+        self.assertEqual(ops[-4].type, 'scale')
+        self.assertEqual(ops[-5].type, 'sign')
+        self.assertEqual(ops[-6].type, 'matmul_grad')
+        if 'weight' in ops[-1].input('Param'):
+            self.assertEqual(ops[-1].attr('regularization_method'), '')
+            self.assertEqual(ops[-1].attr('regularization_coeff'), 0)
+        if 'bias' in ops[-2].input('Param'):
+            self.assertEqual(ops[-2].attr('regularization_method'), 'l2_decay')
+            self.assertEqual(ops[-2].attr('regularization_coeff'),
+                             np.float32(0.5))
+
+    def test_param_has_no_regularizer(self):
         paddle.enable_static()
-        weight_attr = paddle.ParamAttr(
-            name="weight",
-            initializer=paddle.nn.initializer.Constant(value=0.5),
-            regularizer=None)
-        program = self.get_program(weight_attr)
+        program = self.get_program(weight_attr=None)
         ops = program.global_block().ops
         self.assertEqual(ops[-1].attr('regularization_method'), 'l2_decay')
         self.assertEqual(ops[-1].attr('regularization_coeff'), np.float32(0.5))
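
To see what these assertions check outside the unittest harness, here is a minimal, self-contained sketch of the L1Decay-on-weight / plain-bias case. The learning rate, momentum value, and the optimizer's global weight_decay=L2Decay(0.5) are assumptions about how get_program constructs the optimizer (only the 0.5 coefficient is implied by the assertions above); the op inspection at the end mirrors the test.

import paddle

paddle.enable_static()

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data(name='x', shape=[10, 10])
    # The weight carries its own L1Decay; the bias has no regularizer,
    # so only the bias falls back to the optimizer-wide weight_decay.
    weight_attr = paddle.ParamAttr(
        name="weight",
        initializer=paddle.nn.initializer.Constant(value=0.5),
        regularizer=paddle.regularizer.L1Decay(0.1))
    bias_attr = paddle.ParamAttr(
        name="bias",
        initializer=paddle.nn.initializer.Constant(value=0.),
        regularizer=None)
    linear = paddle.nn.Linear(
        10, 10, weight_attr=weight_attr, bias_attr=bias_attr)
    loss = paddle.mean(linear(x))
    optimizer = paddle.optimizer.Momentum(
        learning_rate=0.1,   # assumed value, not shown in this hunk
        momentum=0.9,        # assumed value, not shown in this hunk
        weight_decay=paddle.regularizer.L2Decay(0.5))
    optimizer.minimize(loss)

# The last two ops are one momentum op per parameter: the weight's op has
# fusion disabled (its L1Decay was emitted as sign/scale/sum ops), while the
# bias's op fuses the global l2_decay with coeff 0.5.
for op in main_program.global_block().ops[-2:]:
    print(op.type, op.input('Param'),
          op.attr('regularization_method'), op.attr('regularization_coeff'))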

python/paddle/optimizer/momentum.py

Lines changed: 11 additions & 8 deletions
@@ -274,16 +274,19 @@ def _append_optimize_op(self, block, param_and_grad):
             param_and_grad[0])
         lr = self._create_param_lr(param_and_grad)
 
+        # For fusion of momentum and l2decay
         param = param_and_grad[0]
+        regularization_method = self._regularization_method
+        regularization_coeff = self._regularization_coeff
         if hasattr(param, 'regularizer'):
             # we skip param's l2decay before, so fuse it with momentum here.
             if isinstance(param.regularizer, L2DecayRegularizer):
-                self._regularization_method = "l2_decay"
-                self._regularization_coeff = param.regularizer._regularization_coeff
+                regularization_method = "l2_decay"
+                regularization_coeff = param.regularizer._regularization_coeff
             # the param's regularization has been done before, we avoid do l2decay in momentum.
             elif param.regularizer is not None:
-                self._regularization_method = ""
-                self._regularization_coeff = 0
+                regularization_method = ""
+                regularization_coeff = 0
 
         if framework.in_dygraph_mode():
             if isinstance(param_and_grad, dict):
@@ -292,8 +295,8 @@ def _append_optimize_op(self, block, param_and_grad):
                 param_and_grad[0], param_and_grad[1], velocity_acc, lr,
                 param_and_grad[0], velocity_acc, 'mu', self._momentum,
                 'use_nesterov', self._use_nesterov, 'regularization_method',
-                self._regularization_method, 'regularization_coeff',
-                self._regularization_coeff)
+                regularization_method, 'regularization_coeff',
+                regularization_coeff)
             return None
 
         find_master = self._multi_precision and param_and_grad[
@@ -304,8 +307,8 @@ def _append_optimize_op(self, block, param_and_grad):
         attrs = {
             "mu": self._momentum,
             "use_nesterov": self._use_nesterov,
-            "regularization_method": self._regularization_method,
-            "regularization_coeff": self._regularization_coeff,
+            "regularization_method": regularization_method,
+            "regularization_coeff": regularization_coeff,
             "multi_precision": find_master,
             "rescale_grad": self._rescale_grad
         }
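
The substance of this change: _append_optimize_op used to write the per-parameter override back into self._regularization_method / self._regularization_coeff, so once a parameter with its own regularizer had been handled, every later parameter of the same optimizer inherited the overridden values. Computing the override into locals keeps the optimizer-wide defaults intact, which is what the new test asserts (the bias op keeps 'l2_decay'/0.5 while the weight op gets ''/0). Below is a simplified, hypothetical sketch of the corrected flow; the class and method names are illustrative, not Paddle's API.

class L2DecaySketch:
    # stand-in for paddle.regularizer.L2Decay / L2DecayRegularizer
    def __init__(self, coeff):
        self._regularization_coeff = coeff


class FusedMomentumSketch:
    def __init__(self, weight_decay=None):
        # optimizer-wide defaults derived from the constructor's weight_decay
        if isinstance(weight_decay, L2DecaySketch):
            self._regularization_method = "l2_decay"
            self._regularization_coeff = weight_decay._regularization_coeff
        else:
            self._regularization_method = ""
            self._regularization_coeff = 0.0

    def resolve(self, param_regularizer):
        # Start from the optimizer-wide defaults and override them per
        # parameter in local variables, without mutating self, so one
        # parameter's regularizer cannot leak into the next momentum op.
        method = self._regularization_method
        coeff = self._regularization_coeff
        if isinstance(param_regularizer, L2DecaySketch):
            # the parameter's own L2Decay is fused into its momentum op
            method, coeff = "l2_decay", param_regularizer._regularization_coeff
        elif param_regularizer is not None:
            # already regularized by explicit ops (e.g. L1Decay): no fusion
            method, coeff = "", 0.0
        return method, coeff


opt = FusedMomentumSketch(weight_decay=L2DecaySketch(0.5))
print(opt.resolve(param_regularizer=object()))  # ('', 0.0) for the L1-decayed weight
print(opt.resolve(param_regularizer=None))      # ('l2_decay', 0.5) for the plain bias

With the previous in-place assignment, the first call would have overwritten the defaults and the second call would wrongly have returned ('', 0.0) for the unregularized parameter.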

0 commit comments
