2121
2222import paddle
2323from paddle import nn
24+ from paddle .pir .core import ParameterMeta
2425
26+ DELTA = 0.00001
2527
26- def get_uniform_min_and_max (weight ):
27- min_value = np .min (weight )
28- max_value = np .max (weight )
29- return min_value , max_value
3028
31-
32- class TestKaimingUniform (unittest .TestCase ):
29+ class TestKaimingUniformFunc (unittest .TestCase ):
3330 def _test_kaiming_uniform_common (self , tensor ):
3431 init = paddle .nn .init .kaiming_uniform_
3532 init (tensor , a = 0 , mode = "fan_in" , nonlinearity = "leaky_relu" )
@@ -54,12 +51,54 @@ def _is_uniform(self, tensor, a, b):
5451 def _random_float (self , a , b ):
5552 return (b - a ) * random .random () + a
5653
54+ def calculate_gain (self , nonlinearity , param ):
55+ recommended_gain = {
56+ 'sigmoid' : 1 ,
57+ 'linear' : 1 ,
58+ 'conv1d' : 1 ,
59+ 'conv2d' : 1 ,
60+ 'conv3d' : 1 ,
61+ 'conv1d_transpose' : 1 ,
62+ 'conv_transpose1d' : 1 ,
63+ 'conv2d_transpose' : 1 ,
64+ 'conv_transpose2d' : 1 ,
65+ 'conv3d_transpose' : 1 ,
66+ 'conv_transpose3d' : 1 ,
67+ 'tanh' : 5.0 / 3 ,
68+ 'relu' : math .sqrt (2.0 ),
69+ 'leaky_relu' : math .sqrt (2.0 / (1 + param ** 2 )),
70+ 'selu' : 3.0 / 4 ,
71+ }
72+ return recommended_gain [nonlinearity ]
73+
74+ def test_kaiming_uniform_nonlinearity (self ):
75+ for nonlinearity in [
76+ 'conv_transpose1d' ,
77+ 'conv_transpose2d' ,
78+ 'conv_transpose3d' ,
79+ 'relu' ,
80+ 'leaky_relu' ,
81+ ]:
82+ input_tensor = paddle .zeros ([1024 , 512 ])
83+ paddle .nn .init .kaiming_uniform_ (
84+ input_tensor , nonlinearity = nonlinearity
85+ )
86+
87+ fan_in = input_tensor .shape [0 ]
88+
89+ expected_std = self .calculate_gain (
90+ nonlinearity = nonlinearity , param = 0
91+ )
92+
93+ bounds = expected_std * math .sqrt (3.0 / float (fan_in ))
94+ assert self ._is_uniform (input_tensor , - bounds , bounds )
95+
5796 def test_kaiming_uniform (self ):
5897 for use_a in [True , False ]:
59- for dims in [2 , 4 ]:
98+ for dims in [2 , 3 , 4 ]:
6099 for mode in ["fan_in" , "fan_out" ]:
61100 input_tensor = self ._create_random_nd_tensor (
62- dims , size_min = 20 , size_max = 25
101+ dims , size_min = 20 , size_max = 108
63102 )
64103 if use_a :
65104 a = self ._random_float (0.1 , 2 )
@@ -86,11 +125,142 @@ def test_kaiming_uniform(self):
86125 n = fan_in
87126 else :
88127 n = fan_out
89-
90- expected_std = math .sqrt (2.0 / ((1 + a ** 2 ) * n ))
91- bounds = expected_std * math .sqrt (3.0 )
128+ expected_std = self .calculate_gain (
129+ nonlinearity = 'leaky_relu' , param = a
130+ )
131+ bounds = expected_std * math .sqrt (3.0 / float (n ))
92132 assert self ._is_uniform (input_tensor , - bounds , bounds )
93133
134+ @unittest .skipIf (
135+ not paddle .is_compiled_with_cuda (), "core is not compiled with CUDA"
136+ )
137+ def test_kaiming_uniform_fp16 (self ):
138+ input_tensor = paddle .zeros ([1024 , 512 ], dtype = 'float16' )
139+ paddle .nn .init .kaiming_uniform_ (input_tensor )
140+ fan_in = input_tensor .shape [0 ]
141+
142+ expected_std = self .calculate_gain (nonlinearity = 'leaky_relu' , param = 0 )
143+
144+ bounds = expected_std * math .sqrt (3.0 / float (fan_in ))
145+ assert self ._is_uniform (input_tensor , - bounds , bounds )
146+ assert input_tensor .dtype == paddle .float16
147+
148+
149+ class TestKaimingUniformFuncPir (unittest .TestCase ):
150+ def setUp (self ):
151+ self .init_uniform_op_name = 'pd_op.uniform'
152+
153+ def get_operand_definition_op_attrs (self , cur_op , operand_name , attr_name ):
154+ input_names = cur_op .get_input_names ()
155+ self .assertIn (operand_name , input_names )
156+ attr = (
157+ cur_op .operand (input_names .index (operand_name ))
158+ .source ()
159+ .get_defining_op ()
160+ .attrs ()[attr_name ]
161+ )
162+ return attr
163+
164+ def get_init_ops_by_op_name (self , block , op_name ):
165+ checked_ops = []
166+ for op in block .ops :
167+ # get init op
168+ if op_name == op .name ():
169+ checked_ops .append (op )
170+ return checked_ops
171+
172+ def test_kaiming_uniform_ (self ):
173+ with paddle .pir_utils .IrGuard ():
174+ main = paddle .static .Program ()
175+ with paddle .static .program_guard (main , paddle .static .Program ()):
176+ parameter_meta = ParameterMeta ([1024 , 512 ], paddle .float32 )
177+ init_result = paddle .nn .init .kaiming_uniform_ (
178+ parameter_meta , block = main .global_block ()
179+ )
180+ block = main .global_block ()
181+ checked_ops = self .get_init_ops_by_op_name (
182+ block , self .init_uniform_op_name
183+ )
184+ self .assertEqual (len (checked_ops ), 1 )
185+ init_op = checked_ops [0 ]
186+ limit = np .sqrt (6.0 / init_result .shape [0 ])
187+
188+ min = self .get_operand_definition_op_attrs (
189+ init_op , "min" , "value"
190+ )
191+ max = self .get_operand_definition_op_attrs (
192+ init_op , "max" , "value"
193+ )
194+ self .assertAlmostEqual (min , - limit , delta = DELTA )
195+ self .assertAlmostEqual (max , limit , delta = DELTA )
196+ self .assertEqual (init_op .attrs ()['seed' ], 0 )
197+
198+ def test_kaiming_uniform_conv (self ):
199+ with paddle .pir_utils .IrGuard ():
200+ main = paddle .static .Program ()
201+ with paddle .static .program_guard (main , paddle .static .Program ()):
202+ parameter_meta = ParameterMeta ([5 , 10 , 15 , 20 ], paddle .float32 )
203+ init_result = paddle .nn .init .kaiming_uniform_ (
204+ parameter_meta , block = main .global_block ()
205+ )
206+ block = main .global_block ()
207+ checked_ops = self .get_init_ops_by_op_name (
208+ block , self .init_uniform_op_name
209+ )
210+ self .assertEqual (len (checked_ops ), 1 )
211+ init_op = checked_ops [0 ]
212+ limit = np .sqrt (
213+ 6.0
214+ / (
215+ init_result .shape [1 ]
216+ * init_result .shape [2 ]
217+ * init_result .shape [3 ]
218+ )
219+ )
220+
221+ min = self .get_operand_definition_op_attrs (
222+ init_op , "min" , "value"
223+ )
224+ max = self .get_operand_definition_op_attrs (
225+ init_op , "max" , "value"
226+ )
227+ self .assertAlmostEqual (min , - limit , delta = DELTA )
228+ self .assertAlmostEqual (max , limit , delta = DELTA )
229+ self .assertEqual (init_op .attrs ()['seed' ], 0 )
230+
231+ def test_kaiming_uniform_fan_out (self ):
232+ with paddle .pir_utils .IrGuard ():
233+ main = paddle .static .Program ()
234+ with paddle .static .program_guard (main , paddle .static .Program ()):
235+ parameter_meta = ParameterMeta ([5 , 10 , 15 , 20 ], paddle .float32 )
236+ init_result = paddle .nn .init .kaiming_uniform_ (
237+ parameter_meta , mode = 'fan_out' , block = main .global_block ()
238+ )
239+ block = main .global_block ()
240+ checked_ops = self .get_init_ops_by_op_name (
241+ block , self .init_uniform_op_name
242+ )
243+ self .assertEqual (len (checked_ops ), 1 )
244+ init_op = checked_ops [0 ]
245+ limit = np .sqrt (
246+ 6.0
247+ / (
248+ init_result .shape [0 ]
249+ * init_result .shape [2 ]
250+ * init_result .shape [3 ]
251+ )
252+ )
253+
254+ min = self .get_operand_definition_op_attrs (
255+ init_op , "min" , "value"
256+ )
257+ max = self .get_operand_definition_op_attrs (
258+ init_op , "max" , "value"
259+ )
260+ self .assertAlmostEqual (min , - limit , delta = DELTA )
261+ self .assertAlmostEqual (max , limit , delta = DELTA )
262+ self .assertEqual (init_op .attrs ()['seed' ], 0 )
263+
94264
95265if __name__ == '__main__' :
96266 unittest .main ()
0 commit comments