@@ -257,46 +257,13 @@ def _create_regularization_of_grad(self, param, grad, regularization=None):
 
         Function helper of append_regularization_ops.
         """
-        # (1) If no gradient or no regularization is specified, then we don't need to do anything.
-        # (2) If ParamAttr is set to L2Decay, we skip doing regularization here. And then we fused
+        # If ParamAttr is set to L2Decay, we skip doing regularization here. And then we fused
         # L2Decay with momentum which can refer to _append_optimize_op below.
-        no_regularization = (not hasattr(param, 'regularizer') or (
-            hasattr(param, 'regularizer') and
-            param.regularizer is None)) and regularization is None
-        param_has_L2Decay = hasattr(param, 'regularizer') and isinstance(
-            param.regularizer, L2DecayRegularizer)
-        if grad is None or no_regularization or param_has_L2Decay:
+        if hasattr(param, 'regularizer') and isinstance(param.regularizer,
+                                                        L2DecayRegularizer):
             return grad
-        regularization_term = None
-        if hasattr(param, 'regularizer') and param.regularizer is not None:
-            # Add variable for regularization term in grad block
-            regularization_term = param.regularizer(param, grad, grad.block)
-        elif regularization is not None:
-            regularization_term = regularization(param, grad, grad.block)
-
-        assert regularization_term is not None
-
-        new_grad = grad
-        if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
-            # FIXME(zcd): If the grad is SELECTED_ROWS, after regularization,
-            # the grad's type and name will be changed. But the gradient's name
-            # is used in ParallelExecutor Reduce mode, so I add a flag for
-            # the new_grad here.
-            new_grad = grad.block.create_var(
-                name=grad.name + core.kNewGradSuffix(),
-                dtype=param.dtype,
-                shape=param.shape,
-                lod_level=param.lod_level,
-                type=core.VarDesc.VarType.LOD_TENSOR)
-
-        inputs = {"X": [grad, regularization_term]}
-        outputs = {"Out": [new_grad]}
-        if framework.in_dygraph_mode():
-            new_grad = core.ops.sum([grad, regularization_term])
-        else:
-            grad.block.append_op(type='sum', inputs=inputs, outputs=outputs)
-
-        return new_grad
+        return super(Momentum, self)._create_regularization_of_grad(
+            param, grad, regularization)
 
     def _append_optimize_op(self, block, param_and_grad):
         assert isinstance(block, framework.Block)
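The change relies on L2 decay being applied inside the fused momentum op instead of being summed into the gradient beforehand. A minimal numpy sketch of that equivalence, assuming a plain SGD-with-momentum update and an L2 coefficient `coeff` (names and update form are illustrative, not Paddle's actual kernel):

```python
import numpy as np

def momentum_step_presummed(param, grad, velocity, lr, mu, coeff):
    # Old path: the L2 term is summed into the gradient (the removed sum-op
    # logic) before the ordinary momentum update runs.
    reg_grad = grad + coeff * param
    velocity = mu * velocity + reg_grad
    return param - lr * velocity, velocity

def momentum_step_fused(param, grad, velocity, lr, mu, coeff):
    # New path: the momentum op itself adds coeff * param (fused L2 decay),
    # so the regularization helper can return the gradient untouched.
    velocity = mu * velocity + (grad + coeff * param)
    return param - lr * velocity, velocity

p, g, v = np.random.rand(3), np.random.rand(3), np.zeros(3)
assert all(np.allclose(a, b) for a, b in zip(
    momentum_step_presummed(p, g, v, lr=0.1, mu=0.9, coeff=1e-4),
    momentum_step_fused(p, g, v, lr=0.1, mu=0.9, coeff=1e-4)))
```

Both paths produce the same parameter and velocity, which is why `Momentum` can safely skip the standalone regularization step for `L2DecayRegularizer` and delegate every other case to the base class.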