@@ -614,14 +614,14 @@ def test_momentum_static(self):
 
 
 class TestFusedMomentumWithDecayAPI(unittest.TestCase):
-    def get_program(self, weight_attr):
+    def get_program(self, weight_attr, bias_attr=False):
         main_program = paddle.static.Program()
         startup_program = paddle.static.Program()
         with paddle.static.program_guard(
                 main_program=main_program, startup_program=startup_program):
             x = paddle.static.data(name='x', shape=[10, 10])
             linear = paddle.nn.Linear(
-                10, 10, weight_attr=weight_attr, bias_attr=False)
+                10, 10, weight_attr=weight_attr, bias_attr=bias_attr)
             out = linear(x)
             loss = paddle.mean(out)
             optimizer = paddle.optimizer.Momentum(
@@ -637,7 +637,7 @@ def test_param_has_l2decay(self):
             name="weight",
             initializer=paddle.nn.initializer.Constant(value=0.5),
             regularizer=paddle.regularizer.L2Decay(0.1))
-        program = self.get_program(weight_attr)
+        program = self.get_program(weight_attr, bias_attr=False)
         ops = program.global_block().ops
 
         self.assertEqual(ops[-1].attr('regularization_method'), 'l2_decay')
@@ -652,21 +652,30 @@ def test_param_has_l1decay(self):
             name="weight",
             initializer=paddle.nn.initializer.Constant(value=0.5),
             regularizer=paddle.regularizer.L1Decay(0.1))
-        program = self.get_program(weight_attr)
+        bias_attr = paddle.ParamAttr(
+            name="bias",
+            initializer=paddle.nn.initializer.Constant(value=0.),
+            regularizer=None)
+        program = self.get_program(weight_attr, bias_attr)
         ops = program.global_block().ops
-        self.assertEqual(ops[-1].attr('regularization_method'), '')
-        self.assertEqual(ops[-1].attr('regularization_coeff'), 0)
-        self.assertEqual(ops[-2].type, 'sum')
-        self.assertEqual(ops[-3].type, 'scale')
-        self.assertEqual(ops[-4].type, 'sign')
 
-    def test_param_regularizer_is_none(self):
+        self.assertEqual(ops[-1].type, 'momentum')
+        self.assertEqual(ops[-2].type, 'momentum')
+        self.assertEqual(ops[-3].type, 'sum')
+        self.assertEqual(ops[-4].type, 'scale')
+        self.assertEqual(ops[-5].type, 'sign')
+        self.assertEqual(ops[-6].type, 'matmul_grad')
+        if 'weight' in ops[-1].input('Param'):
+            self.assertEqual(ops[-1].attr('regularization_method'), '')
+            self.assertEqual(ops[-1].attr('regularization_coeff'), 0)
+        if 'bias' in ops[-2].input('Param'):
+            self.assertEqual(ops[-2].attr('regularization_method'), 'l2_decay')
+            self.assertEqual(ops[-2].attr('regularization_coeff'),
+                             np.float32(0.5))
+
+    def test_param_has_no_regularizer(self):
         paddle.enable_static()
-        weight_attr = paddle.ParamAttr(
-            name="weight",
-            initializer=paddle.nn.initializer.Constant(value=0.5),
-            regularizer=None)
-        program = self.get_program(weight_attr)
+        program = self.get_program(weight_attr=None)
         ops = program.global_block().ops
         self.assertEqual(ops[-1].attr('regularization_method'), 'l2_decay')
         self.assertEqual(ops[-1].attr('regularization_coeff'), np.float32(0.5))
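
For context, a minimal runnable sketch of what get_program builds and what the updated assertions exercise, assuming Paddle's 2.x static-graph API. The learning_rate and momentum values below are placeholders not shown in the hunk, and the optimizer-level weight_decay=L2Decay(0.5) is inferred from the regularization_coeff assertions: a per-parameter regularizer (the L1Decay on the weight) overrides the optimizer-level decay, so that parameter's fused momentum op reports an empty regularization_method, while the bias, which has no regularizer of its own, falls back to the fused l2_decay.

import paddle

paddle.enable_static()
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(
        main_program=main_program, startup_program=startup_program):
    x = paddle.static.data(name='x', shape=[10, 10])
    weight_attr = paddle.ParamAttr(
        name="weight",
        initializer=paddle.nn.initializer.Constant(value=0.5),
        # Per-parameter regularizer: applied to the gradient via explicit
        # sign/scale/sum ops, so the fused momentum op for this parameter
        # carries regularization_method == '' and coeff 0.
        regularizer=paddle.regularizer.L1Decay(0.1))
    linear = paddle.nn.Linear(10, 10, weight_attr=weight_attr)
    loss = paddle.mean(linear(x))
    optimizer = paddle.optimizer.Momentum(
        learning_rate=0.1,  # placeholder value
        momentum=0.9,  # placeholder value
        # Optimizer-level decay: fused into the momentum op only for
        # parameters without their own regularizer (here, the bias).
        weight_decay=paddle.regularizer.L2Decay(0.5))
    optimizer.minimize(loss)

# Inspect the fused momentum ops, mirroring the assertions above.
for op in main_program.global_block().ops:
    if op.type == 'momentum':
        print(op.input('Param'), op.attr('regularization_method'),
              op.attr('regularization_coeff'))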