Skip to content

Commit 5b8a1a5

Browse files
update unit test for kaiming_uniform_
1 parent 87d9313 commit 5b8a1a5

File tree

1 file changed

+181
-11
lines changed

1 file changed

+181
-11
lines changed

test/legacy_test/test_nn_init_function.py

Lines changed: 181 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,12 @@
2121

2222
import paddle
2323
from paddle import nn
24+
from paddle.pir.core import ParameterMeta
2425

26+
DELTA = 0.00001
2527

26-
def get_uniform_min_and_max(weight):
27-
min_value = np.min(weight)
28-
max_value = np.max(weight)
29-
return min_value, max_value
3028

31-
32-
class TestKaimingUniform(unittest.TestCase):
29+
class TestKaimingUniformFunc(unittest.TestCase):
3330
def _test_kaiming_uniform_common(self, tensor):
3431
init = paddle.nn.init.kaiming_uniform_
3532
init(tensor, a=0, mode="fan_in", nonlinearity="leaky_relu")
@@ -54,12 +51,54 @@ def _is_uniform(self, tensor, a, b):
5451
def _random_float(self, a, b):
5552
return (b - a) * random.random() + a
5653

54+
def calculate_gain(self, nonlinearity, param=0):
    """Return the recommended gain for the given nonlinearity.

    Mirrors the table used by kaiming initialization: linear/conv-style
    layers have gain 1, while tanh / relu / leaky_relu / selu use their
    analytic values.

    Args:
        nonlinearity (str): Name of the nonlinearity, e.g. 'relu'.
        param (float): Negative slope, used only for 'leaky_relu'.
            Defaults to 0 so callers may omit it for every other
            nonlinearity (backward compatible with the old required arg).

    Returns:
        float: The recommended gain.

    Raises:
        KeyError: If `nonlinearity` is not in the supported table
            (same exception type as the original dict lookup).
    """
    # Compute the parameter-dependent entry lazily: the original eager
    # dict evaluated `param**2` on every call, which made `param`
    # mandatory (and forced it to be numeric) even for nonlinearities
    # that ignore it.
    if nonlinearity == 'leaky_relu':
        return math.sqrt(2.0 / (1 + param**2))

    fixed_gain = {
        'sigmoid': 1,
        'linear': 1,
        'conv1d': 1,
        'conv2d': 1,
        'conv3d': 1,
        'conv1d_transpose': 1,
        'conv_transpose1d': 1,
        'conv2d_transpose': 1,
        'conv_transpose2d': 1,
        'conv3d_transpose': 1,
        'conv_transpose3d': 1,
        'tanh': 5.0 / 3,
        'relu': math.sqrt(2.0),
        'selu': 3.0 / 4,
    }
    # Preserve the original KeyError behavior for unsupported names.
    return fixed_gain[nonlinearity]
73+
74+
def test_kaiming_uniform_nonlinearity(self):
    """kaiming_uniform_ must stay inside the analytic bound for each nonlinearity."""
    nonlinearities = (
        'conv_transpose1d',
        'conv_transpose2d',
        'conv_transpose3d',
        'relu',
        'leaky_relu',
    )
    for act in nonlinearities:
        weight = paddle.zeros([1024, 512])
        paddle.nn.init.kaiming_uniform_(weight, nonlinearity=act)

        # fan_in for a 2-D paddle weight is taken from dim 0
        # (paddle linear weights are [in, out]) — TODO confirm convention.
        fan_in = weight.shape[0]
        gain = self.calculate_gain(nonlinearity=act, param=0)

        # Uniform kaiming bound: gain * sqrt(3 / fan_in).
        bound = gain * math.sqrt(3.0 / float(fan_in))
        assert self._is_uniform(weight, -bound, bound)
95+
5796
def test_kaiming_uniform(self):
5897
for use_a in [True, False]:
59-
for dims in [2, 4]:
98+
for dims in [2, 3, 4]:
6099
for mode in ["fan_in", "fan_out"]:
61100
input_tensor = self._create_random_nd_tensor(
62-
dims, size_min=20, size_max=25
101+
dims, size_min=20, size_max=108
63102
)
64103
if use_a:
65104
a = self._random_float(0.1, 2)
@@ -86,11 +125,142 @@ def test_kaiming_uniform(self):
86125
n = fan_in
87126
else:
88127
n = fan_out
89-
90-
expected_std = math.sqrt(2.0 / ((1 + a**2) * n))
91-
bounds = expected_std * math.sqrt(3.0)
128+
expected_std = self.calculate_gain(
129+
nonlinearity='leaky_relu', param=a
130+
)
131+
bounds = expected_std * math.sqrt(3.0 / float(n))
92132
assert self._is_uniform(input_tensor, -bounds, bounds)
93133

134+
@unittest.skipIf(
    not paddle.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
def test_kaiming_uniform_fp16(self):
    """kaiming_uniform_ on a float16 tensor keeps dtype and stays in bounds."""
    weight = paddle.zeros([1024, 512], dtype='float16')
    paddle.nn.init.kaiming_uniform_(weight)

    # fan_in taken from dim 0, matching the other tests in this class.
    fan_in = weight.shape[0]
    gain = self.calculate_gain(nonlinearity='leaky_relu', param=0)
    bound = gain * math.sqrt(3.0 / float(fan_in))

    assert self._is_uniform(weight, -bound, bound)
    # Initialization must not silently promote the dtype.
    assert weight.dtype == paddle.float16
147+
148+
149+
class TestKaimingUniformFuncPir(unittest.TestCase):
    """Check the PIR static-graph lowering of paddle.nn.init.kaiming_uniform_.

    Each test builds a program, runs the initializer on a ParameterMeta and
    then inspects the generated uniform op: there must be exactly one, its
    min/max operands must equal the analytic kaiming bound, and the seed
    attribute must be 0.
    """

    def setUp(self):
        # Name of the op the initializer is expected to lower to.
        self.init_uniform_op_name = 'pd_op.uniform'

    def get_operand_definition_op_attrs(self, cur_op, operand_name, attr_name):
        """Return attribute `attr_name` of the op defining operand `operand_name`."""
        input_names = cur_op.get_input_names()
        self.assertIn(operand_name, input_names)
        defining_op = (
            cur_op.operand(input_names.index(operand_name))
            .source()
            .get_defining_op()
        )
        return defining_op.attrs()[attr_name]

    def get_init_ops_by_op_name(self, block, op_name):
        """Collect all ops in `block` whose name matches `op_name`."""
        return [op for op in block.ops if op.name() == op_name]

    def _check_uniform_init(self, block, limit):
        """Assert `block` has exactly one uniform op bounded by +/-limit, seed 0.

        Extracted because the original three tests repeated this
        verification verbatim.
        """
        checked_ops = self.get_init_ops_by_op_name(
            block, self.init_uniform_op_name
        )
        self.assertEqual(len(checked_ops), 1)
        init_op = checked_ops[0]
        # `min`/`max` are fed by constant-producing ops; read their values.
        # (Renamed locals so they no longer shadow the builtins min/max.)
        min_value = self.get_operand_definition_op_attrs(
            init_op, "min", "value"
        )
        max_value = self.get_operand_definition_op_attrs(
            init_op, "max", "value"
        )
        self.assertAlmostEqual(min_value, -limit, delta=DELTA)
        self.assertAlmostEqual(max_value, limit, delta=DELTA)
        self.assertEqual(init_op.attrs()['seed'], 0)

    def test_kaiming_uniform_(self):
        with paddle.pir_utils.IrGuard():
            main = paddle.static.Program()
            with paddle.static.program_guard(main, paddle.static.Program()):
                parameter_meta = ParameterMeta([1024, 512], paddle.float32)
                init_result = paddle.nn.init.kaiming_uniform_(
                    parameter_meta, block=main.global_block()
                )
                # fan_in of a 2-D parameter: dim 0; bound = sqrt(6 / fan_in).
                limit = np.sqrt(6.0 / init_result.shape[0])
                self._check_uniform_init(main.global_block(), limit)

    def test_kaiming_uniform_conv(self):
        with paddle.pir_utils.IrGuard():
            main = paddle.static.Program()
            with paddle.static.program_guard(main, paddle.static.Program()):
                parameter_meta = ParameterMeta([5, 10, 15, 20], paddle.float32)
                init_result = paddle.nn.init.kaiming_uniform_(
                    parameter_meta, block=main.global_block()
                )
                # fan_in of a 4-D conv weight: dims 1*2*3 (channels * kernel).
                limit = np.sqrt(
                    6.0
                    / (
                        init_result.shape[1]
                        * init_result.shape[2]
                        * init_result.shape[3]
                    )
                )
                self._check_uniform_init(main.global_block(), limit)

    def test_kaiming_uniform_fan_out(self):
        with paddle.pir_utils.IrGuard():
            main = paddle.static.Program()
            with paddle.static.program_guard(main, paddle.static.Program()):
                parameter_meta = ParameterMeta([5, 10, 15, 20], paddle.float32)
                init_result = paddle.nn.init.kaiming_uniform_(
                    parameter_meta, mode='fan_out', block=main.global_block()
                )
                # fan_out of a 4-D conv weight: dims 0*2*3.
                limit = np.sqrt(
                    6.0
                    / (
                        init_result.shape[0]
                        * init_result.shape[2]
                        * init_result.shape[3]
                    )
                )
                self._check_uniform_init(main.global_block(), limit)
94264

95265
if __name__ == '__main__':
96266
unittest.main()

0 commit comments

Comments
 (0)