
Commit 688df62

refine _create_regularization_of_grad of momentum
1 parent cb9dfe1 commit 688df62

File tree

1 file changed: +5 -38 lines


python/paddle/optimizer/momentum.py

Lines changed: 5 additions & 38 deletions
@@ -257,46 +257,13 @@ def _create_regularization_of_grad(self, param, grad, regularization=None):
 
         Function helper of append_regularization_ops.
         """
-        # (1) If no gradient or no regularization is specified, then we don't need to do anything.
-        # (2) If ParamAttr is set to L2Decay, we skip doing regularization here. And then we fused
+        # If ParamAttr is set to L2Decay, we skip doing regularization here. And then we fused
         # L2Decay with momentum which can refer to _append_optimize_op below.
-        no_regularization = (not hasattr(param, 'regularizer') or (
-            hasattr(param, 'regularizer') and
-            param.regularizer is None)) and regularization is None
-        param_has_L2Decay = hasattr(param, 'regularizer') and isinstance(
-            param.regularizer, L2DecayRegularizer)
-        if grad is None or no_regularization or param_has_L2Decay:
+        if hasattr(param, 'regularizer') and isinstance(param.regularizer,
+                                                        L2DecayRegularizer):
             return grad
-        regularization_term = None
-        if hasattr(param, 'regularizer') and param.regularizer is not None:
-            # Add variable for regularization term in grad block
-            regularization_term = param.regularizer(param, grad, grad.block)
-        elif regularization is not None:
-            regularization_term = regularization(param, grad, grad.block)
-
-        assert regularization_term is not None
-
-        new_grad = grad
-        if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
-            # FIXME(zcd): If the grad is SELECTED_ROWS, after regularization,
-            # the grad's type and name will be changed. But the gradient's name
-            # is used in ParallelExecutor Reduce mode, so I add a flag for
-            # the new_grad here.
-            new_grad = grad.block.create_var(
-                name=grad.name + core.kNewGradSuffix(),
-                dtype=param.dtype,
-                shape=param.shape,
-                lod_level=param.lod_level,
-                type=core.VarDesc.VarType.LOD_TENSOR)
-
-        inputs = {"X": [grad, regularization_term]}
-        outputs = {"Out": [new_grad]}
-        if framework.in_dygraph_mode():
-            new_grad = core.ops.sum([grad, regularization_term])
-        else:
-            grad.block.append_op(type='sum', inputs=inputs, outputs=outputs)
-
-        return new_grad
+        return super(Momentum, self)._create_regularization_of_grad(
+            param, grad, regularization)
 
     def _append_optimize_op(self, block, param_and_grad):
         assert isinstance(block, framework.Block)
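
The idea behind the change: for L2Decay the regularization term coeff * param can be folded into the momentum update itself, so emitting a separate sum op on the gradient is unnecessary. Below is a minimal NumPy sketch of that equivalence. It is illustrative only, not Paddle code; the function names and the lr/mu/coeff values are made up for the example, and it only shows that pre-regularizing the gradient and fusing the decay term into the step produce identical updates.

    import numpy as np

    # Illustrative sketch, not Paddle's kernel: all names here are hypothetical.

    def momentum_step_with_pre_regularized_grad(param, grad, velocity, lr, mu, coeff):
        # Two stages: first build grad + coeff * param (what a separate
        # regularization op would produce), then run a plain momentum step.
        reg_grad = grad + coeff * param
        velocity_out = mu * velocity + reg_grad
        param_out = param - lr * velocity_out
        return param_out, velocity_out

    def momentum_step_with_fused_l2_decay(param, grad, velocity, lr, mu, coeff):
        # One stage: the decay term is applied inside the update itself,
        # so no extra op is needed on the gradient beforehand.
        velocity_out = mu * velocity + (grad + coeff * param)
        param_out = param - lr * velocity_out
        return param_out, velocity_out

    rng = np.random.default_rng(0)
    param = rng.standard_normal(4)
    grad = rng.standard_normal(4)
    velocity = np.zeros(4)

    p1, v1 = momentum_step_with_pre_regularized_grad(param, grad, velocity, 0.1, 0.9, 1e-4)
    p2, v2 = momentum_step_with_fused_l2_decay(param, grad, velocity, 0.1, 0.9, 1e-4)
    assert np.allclose(p1, p2) and np.allclose(v1, v2)

For any regularizer other than L2Decay, the refined method simply defers to the base Optimizer implementation, which is what the super(Momentum, self)._create_regularization_of_grad(...) call in the added lines does.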
