
Commit cb9dfe1

Commit message: refine
1 parent 8bd8185 · commit cb9dfe1

File tree: 6 files changed, +289 −200 lines


python/paddle/fluid/optimizer.py

Lines changed: 93 additions & 7 deletions
@@ -33,7 +33,6 @@
 from .initializer import Constant
 from .layer_helper import LayerHelper
 from .layers import ops
-from .regularizer import append_regularization_ops
 from .dygraph import base as imperative_base
 from .dygraph import no_grad
 from .dygraph.learning_rate_scheduler import LearningRateDecay, _LearningRateEpochDecay
@@ -884,6 +883,93 @@ def backward(self,
                                      act_no_grad_set, callbacks)
         return params_grads

+    def _create_regularization_of_grad(self, param, grad, regularization=None):
+        """Create and add backward regularization Operators.
+
+        Function helper of append_regularization_ops.
+        """
+        # If there is no gradient or no regularization is specified, there is nothing to do.
+        if grad is None or ((not hasattr(param, 'regularizer') or
+                             (hasattr(param, 'regularizer') and
+                              param.regularizer is None)) and
+                            regularization is None):
+            return grad
+        regularization_term = None
+        if hasattr(param, 'regularizer') and param.regularizer is not None:
+            # Add a variable for the regularization term in the grad block.
+            regularization_term = param.regularizer(param, grad, grad.block)
+        elif regularization is not None:
+            regularization_term = regularization(param, grad, grad.block)
+
+        assert regularization_term is not None
+
+        new_grad = grad
+        if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
+            # FIXME(zcd): If the grad is SELECTED_ROWS, its type and name will
+            # change after regularization. Because the gradient's name is used
+            # in ParallelExecutor Reduce mode, a new variable is created for
+            # new_grad here.
+            new_grad = grad.block.create_var(
+                name=grad.name + core.kNewGradSuffix(),
+                dtype=param.dtype,
+                shape=param.shape,
+                lod_level=param.lod_level,
+                type=core.VarDesc.VarType.LOD_TENSOR)
+
+        inputs = {"X": [grad, regularization_term]}
+        outputs = {"Out": [new_grad]}
+        if framework.in_dygraph_mode():
+            new_grad = core.ops.sum([grad, regularization_term])
+        else:
+            grad.block.append_op(type='sum', inputs=inputs, outputs=outputs)
+
+        return new_grad
+
+    def append_regularization_ops(self,
+                                  parameters_and_grads,
+                                  regularization=None):
+        r"""Create and add backward regularization Operators.
+
+        Creates and adds backward regularization operators in the BlockDesc.
+        This will add the gradients of the regularizer function to the
+        gradients of the parameters and return these modified gradients. This
+        is the same as implementing weight decay in optimizers for
+        regularization.
+
+        Args:
+            parameters_and_grads: A list of (parameter, gradient) pairs that
+                need to be regularized.
+            regularization: A global regularizer. It is applied to a parameter
+                only if that parameter does not have its own regularizer.
+
+        Returns:
+            list[(Variable, Variable)]: list of (parameter, gradient) pairs
+            with the regularized gradients.
+
+        Raises:
+            Exception: Unknown regularization type.
+        """
+        params_and_grads = []
+        if framework.in_dygraph_mode():
+            for param, grad in parameters_and_grads:
+                new_grad = self._create_regularization_of_grad(param, grad,
+                                                               regularization)
+                params_and_grads.append((param, new_grad))
+        else:
+            repeate_regularizer = False
+            with framework.name_scope('regularization'):
+                for param, grad in parameters_and_grads:
+                    if not repeate_regularizer and param.regularizer is not None and regularization is not None:
+                        repeate_regularizer = True
+                        logging.info(
+                            "If the regularizer of a Parameter has already been set by 'fluid.ParamAttr' or 'fluid.WeightNormParamAttr', "
+                            "the Regularization[%s] in the Optimizer will not take effect on it and will only be applied to the other Parameters!"
+                            % regularization.__str__())
+                    with param.block.program._optimized_guard([param, grad]):
+                        new_grad = self._create_regularization_of_grad(
+                            param, grad, regularization)
+                    params_and_grads.append((param, new_grad))
+        return params_and_grads
+
     def apply_gradients(self, params_grads):
         """
         Second part of `minimize`, appending optimization operators for
@@ -916,8 +1002,8 @@ def apply_gradients(self, params_grads):
         params_grads = append_gradient_clip_ops(params_grads)

         # Add regularization if any
-        params_grads = append_regularization_ops(params_grads,
-                                                 self.regularization)
+        params_grads = self.append_regularization_ops(params_grads,
+                                                      self.regularization)

         optimize_ops = self._create_optimization_pass(params_grads)
         return optimize_ops
@@ -939,8 +1025,8 @@ def apply_optimize(self, loss, startup_program, params_grads):
                     framework.default_startup_program()):
                 if self._grad_clip is not None:
                     params_grads = self._grad_clip(params_grads)
-                params_grads = append_regularization_ops(params_grads,
-                                                         self.regularization)
+                params_grads = self.append_regularization_ops(
+                    params_grads, self.regularization)
                 optimize_ops = self._create_optimization_pass(params_grads)
         else:
             program = loss.block.program
@@ -1674,8 +1760,8 @@ def apply_gradients(self, params_grads):
         not_dgc_params_grads = append_gradient_clip_ops(
             not_dgc_params_grads)

-        not_dgc_params_grads = append_regularization_ops(not_dgc_params_grads,
-                                                         self.regularization)
+        not_dgc_params_grads = self.append_regularization_ops(
+            not_dgc_params_grads, self.regularization)

         params_grads = not_dgc_params_grads + dgc_params_grads
         params_grads = sorted(params_grads, key=lambda x: x[0].name)
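
With this change, regularization is appended through the Optimizer itself rather than through the module-level helper in paddle/fluid/regularizer.py. A minimal sketch of calling the new method directly in static-graph mode, mirroring the updated test_regularizer.py below; the network, parameter name, and decay coefficient here are illustrative and not taken from this commit, and it assumes a build where both the 2.x paddle.optimizer API and the legacy fluid static-graph layers are available:

import paddle
import paddle.fluid as fluid
from paddle.fluid.backward import append_backward

paddle.enable_static()
main_prog = fluid.Program()
with fluid.program_guard(main_prog, fluid.Program()):
    x = fluid.data(name='x', shape=[None, 10], dtype='float32')
    # Parameter-level regularizer; it takes precedence over a global one.
    fc = fluid.layers.fc(
        input=x,
        size=1,
        param_attr=fluid.ParamAttr(
            name='fc_w', regularizer=fluid.regularizer.L2Decay(0.1)))
    mean_out = fluid.layers.mean(fc)
    params_grads = append_backward(mean_out)

    # New call path: the method now lives on the optimizer instance
    # instead of being imported from paddle.fluid.regularizer.
    optimizer = paddle.optimizer.Adam()
    params_grads = optimizer.append_regularization_ops(params_grads)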

python/paddle/fluid/regularizer.py

Lines changed: 0 additions & 86 deletions
@@ -22,92 +22,6 @@
 __all__ = ['L1Decay', 'L2Decay', 'L1DecayRegularizer', 'L2DecayRegularizer']


-def _create_regularization_of_grad(param, grad, regularization=None):
-    """ Create and add backward regularization Operators
-
-    Function helper of append_regularization_ops.
-    """
-    # If no gradient or no regularization is specified, then we don't need to do anything
-    if grad is None or ((not hasattr(param, 'regularizer') or (
-            hasattr(param, 'regularizer') and param.regularizer is None)) and
-                        regularization is None):
-        return grad
-    regularization_term = None
-    if hasattr(param, 'regularizer') and param.regularizer is not None:
-        # Add variable for regularization term in grad block
-        regularization_term = param.regularizer(param, grad, grad.block)
-    elif regularization is not None:
-        regularization_term = regularization(param, grad, grad.block)
-
-    assert regularization_term is not None
-
-    new_grad = grad
-    if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
-        # FIXME(zcd): If the grad is SELECTED_ROWS, after regularization,
-        # the grad's type and name will be changed. But the gradient's name
-        # is used in ParallelExecutor Reduce mode, so I add a flag for
-        # the new_grad here.
-        new_grad = grad.block.create_var(
-            name=grad.name + core.kNewGradSuffix(),
-            dtype=param.dtype,
-            shape=param.shape,
-            lod_level=param.lod_level,
-            type=core.VarDesc.VarType.LOD_TENSOR)
-
-    inputs = {"X": [grad, regularization_term]}
-    outputs = {"Out": [new_grad]}
-    if in_dygraph_mode():
-        new_grad = core.ops.sum([grad, regularization_term])
-    else:
-        grad.block.append_op(type='sum', inputs=inputs, outputs=outputs)
-
-    return new_grad
-
-
-def append_regularization_ops(parameters_and_grads, regularization=None):
-    r"""Create and add backward regularization Operators
-
-    Creates and adds backward regularization operators in the BlockDesc.
-    This will add gradients of the regularizer function to the gradients
-    of the parameters and return these modified gradients. This is the
-    same as implementing weight decay in optimizers for regularization.
-
-    Args:
-        parameters_and_grads: A list of (parameters, gradients) pairs
-            that need to be regularized.
-        regularization: A global regularizer. If the parameter is not
-            set. It will be applied with regularizer.
-
-    Returns:
-        list[(Variable, Variable)]: list of (parameters, gradients) \
-        pair with the regularized gradient
-
-    Raises:
-        Exception: Unknown regularization type
-    """
-    params_and_grads = []
-    if in_dygraph_mode():
-        for param, grad in parameters_and_grads:
-            new_grad = _create_regularization_of_grad(param, grad,
-                                                      regularization)
-            params_and_grads.append((param, new_grad))
-    else:
-        repeate_regularizer = False
-        with framework.name_scope('regularization'):
-            for param, grad in parameters_and_grads:
-                if not repeate_regularizer and param.regularizer is not None and regularization is not None:
-                    repeate_regularizer = True
-                    logging.info(
-                        "If regularizer of a Parameter has been set by 'fluid.ParamAttr' or 'fluid.WeightNormParamAttr' already. "
-                        "The Regularization[%s] in Optimizer will not take effect, and it will only be applied to other Parameters!"
-                        % regularization.__str__())
-                with param.block.program._optimized_guard([param, grad]):
-                    new_grad = _create_regularization_of_grad(param, grad,
-                                                              regularization)
-                params_and_grads.append((param, new_grad))
-    return params_and_grads
-
-
 class WeightDecayRegularizer(object):
     """Base class for weight decay regularizers


python/paddle/fluid/tests/unittests/test_momentum_op.py

Lines changed: 52 additions & 28 deletions
@@ -614,41 +614,65 @@ def test_momentum_static(self):


 class TestFusedMomentumWithDecayAPI(unittest.TestCase):
-    def get_loss_and_optimizer(self, input):
+    def get_program(self, weight_attr):
+        main_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        with paddle.static.program_guard(
+                main_program=main_program, startup_program=startup_program):
+            x = paddle.static.data(name='x', shape=[10, 10])
+            linear = paddle.nn.Linear(
+                10, 10, weight_attr=weight_attr, bias_attr=False)
+            out = linear(x)
+            loss = paddle.mean(out)
+            optimizer = paddle.optimizer.Momentum(
+                learning_rate=0.01,
+                momentum=0.9,
+                weight_decay=paddle.regularizer.L2Decay(0.5))
+            optimizer.minimize(loss)
+        return main_program
+
+    def test_param_has_l2decay(self):
+        paddle.enable_static()
         weight_attr = paddle.ParamAttr(
             name="weight",
             initializer=paddle.nn.initializer.Constant(value=0.5),
             regularizer=paddle.regularizer.L2Decay(0.1))
-        linear = paddle.nn.Linear(10, 10, weight_attr=weight_attr)
-        out = linear(input)
-        loss = paddle.mean(out)
-
-        momentum = paddle.optimizer.Momentum(
-            learning_rate=0.01,
-            momentum=0.9,
-            parameters=linear.parameters(),
-            weight_decay=paddle.regularizer.L1Decay(0.5),
-            grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0))
-        return loss, momentum
+        program = self.get_program(weight_attr)
+        ops = program.global_block().ops

-    def test_dygraph(self):
-        paddle.disable_static()
-        inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
-        x = paddle.to_tensor(inp)
-        loss, optimizer = self.get_loss_and_optimizer(x)
-        loss.backward()
-        optimizer.step()
-        optimizer.clear_grad()
+        self.assertEqual(ops[-1].attr('regularization_method'), 'l2_decay')
+        self.assertEqual(ops[-1].attr('regularization_coeff'), np.float32(0.1))
+        for i in range(len(ops)):
+            self.assertTrue('sum' not in ops[i].type)
+            self.assertTrue('scale' not in ops[i].type)

-    def test_static(self):
+    def test_param_has_l1decay(self):
         paddle.enable_static()
-        inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
-        x = paddle.static.data(name='x', shape=[10, 10])
-        loss, optimizer = self.get_loss_and_optimizer(x)
-        optimizer.minimize(loss)
-        exe = paddle.static.Executor()
-        exe.run(paddle.static.default_startup_program())
-        exe.run(feed={"x": inp})
+        weight_attr = paddle.ParamAttr(
+            name="weight",
+            initializer=paddle.nn.initializer.Constant(value=0.5),
+            regularizer=paddle.regularizer.L1Decay(0.1))
+        program = self.get_program(weight_attr)
+        ops = program.global_block().ops
+        self.assertEqual(ops[-1].attr('regularization_method'), '')
+        self.assertEqual(ops[-1].attr('regularization_coeff'), 0)
+        self.assertEqual(ops[-2].type, 'sum')
+        self.assertEqual(ops[-3].type, 'scale')
+        self.assertEqual(ops[-4].type, 'sign')
+
+    def test_param_regularizer_is_none(self):
+        paddle.enable_static()
+        weight_attr = paddle.ParamAttr(
+            name="weight",
+            initializer=paddle.nn.initializer.Constant(value=0.5),
+            regularizer=None)
+        program = self.get_program(weight_attr)
+        ops = program.global_block().ops
+        self.assertEqual(ops[-1].attr('regularization_method'), 'l2_decay')
+        self.assertEqual(ops[-1].attr('regularization_coeff'), np.float32(0.5))
+        for i in range(len(ops)):
+            self.assertTrue('sum' not in ops[i].type)
+            self.assertTrue('scale' not in ops[i].type)


 class TestMomentumOpVsMomentumOpWithDecayAPI(unittest.TestCase):
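
The deleted test_dygraph case above still reflects the imperative-mode usage that remains supported through the public API. A minimal dygraph sketch under that assumption (the layer size and decay coefficient are illustrative; whether the decay flows through the new append_regularization_ops path or is fused into the momentum kernel is an implementation detail of the optimizer):

import numpy as np
import paddle

paddle.disable_static()
linear = paddle.nn.Linear(10, 10)
optimizer = paddle.optimizer.Momentum(
    learning_rate=0.01,
    momentum=0.9,
    parameters=linear.parameters(),
    weight_decay=paddle.regularizer.L2Decay(0.5))

x = paddle.to_tensor(np.random.uniform(-0.1, 0.1, [10, 10]).astype('float32'))
loss = paddle.mean(linear(x))
loss.backward()
optimizer.step()        # the configured weight decay is applied during this update
optimizer.clear_grad()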

python/paddle/fluid/tests/unittests/test_regularizer.py

Lines changed: 2 additions & 0 deletions
@@ -59,6 +59,7 @@ def test_l2decay_regularizer(self):
         params_grads = append_backward(mean_out)
         self.assertEqual(len(params_grads), 1)
         count_ops = len(block.ops)
+        optimizer = paddle.optimizer.Adam()
         params_grads = optimizer.append_regularization_ops(params_grads)
         self.assertEqual(len(params_grads), 1)
         self.assertEqual(len(block.ops), count_ops + 2)
@@ -97,6 +98,7 @@ def test_l2decay_regularizer(self):
         params_grads = append_backward(mean_out)
         self.assertEqual(len(params_grads), 1)
         count_ops = len(block.ops)
+        optimizer = paddle.optimizer.Adam()
         params_grads = optimizer.append_regularization_ops(params_grads)
         self.assertEqual(len(params_grads), 1)
         self.assertEqual(len(block.ops), count_ops + 3)
