1- # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
1+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
22#
33# Licensed under the Apache License, Version 2.0 (the "License");
44# you may not use this file except in compliance with the License.
@@ -48,11 +48,11 @@ def test_check_grad(self):
4848 core .CPUPlace (), ["X" ],
4949 "Out" ,
5050 check_dygraph = False ,
51- user_defined_grads = [self .x_bf16 ],
52- user_defined_grad_outputs = [self .x_fp32 ])
51+ user_defined_grads = [self .inputs [ 'X' ] ],
52+ user_defined_grad_outputs = [self .outputs [ 'Out' ] ])
5353
5454
55- class TestCastFP32ToBF16MKLDNNOp (OpTest ):
55+ class TestCastFP32ToBF16MKLDNNOp (TestCastBF16ToFP32MKLDNNOp ):
5656 def setUp (self ):
5757 self .x_fp32 = np .random .random (size = [2 , 6 ]).astype ("float32" )
5858 self .x_bf16 = convert_float_to_uint16 (self .x_fp32 )
@@ -66,19 +66,8 @@ def setUp(self):
6666 }
6767 self .op_type = 'cast'
6868
69- def test_check_output (self ):
70- self .check_output (check_dygraph = False )
71-
72- def test_check_grad (self ):
73- self .check_grad_with_place (
74- core .CPUPlace (), ["X" ],
75- "Out" ,
76- check_dygraph = False ,
77- user_defined_grads = [self .x_fp32 ],
78- user_defined_grad_outputs = [self .x_bf16 ])
79-
8069
81- class TestCastBF16ToBF16MKLDNNOp (OpTest ):
70+ class TestCastBF16ToBF16MKLDNNOp (TestCastBF16ToFP32MKLDNNOp ):
8271 def setUp (self ):
8372 self .x_fp32 = np .random .random (size = [6 , 13 ]).astype ("float32" )
8473 self .x_bf16 = convert_float_to_uint16 (self .x_fp32 )
@@ -92,19 +81,8 @@ def setUp(self):
9281 }
9382 self .op_type = 'cast'
9483
95- def test_check_output (self ):
96- self .check_output (check_dygraph = False )
9784
98- def test_check_grad (self ):
99- self .check_grad_with_place (
100- core .CPUPlace (), ["X" ],
101- "Out" ,
102- check_dygraph = False ,
103- user_defined_grads = [self .x_bf16 ],
104- user_defined_grad_outputs = [self .x_bf16 ])
105-
106-
107- class TestCastFP32ToFP32MKLDNNOp (OpTest ):
85+ class TestCastFP32ToFP32MKLDNNOp (TestCastBF16ToFP32MKLDNNOp ):
10886 def setUp (self ):
10987 self .x_fp32 = np .random .random (size = [7 , 15 ]).astype ("float32" )
11088
@@ -117,17 +95,6 @@ def setUp(self):
11795 }
11896 self .op_type = 'cast'
11997
120- def test_check_output (self ):
121- self .check_output (check_dygraph = False )
122-
123- def test_check_grad (self ):
124- self .check_grad_with_place (
125- core .CPUPlace (), ["X" ],
126- "Out" ,
127- check_dygraph = False ,
128- user_defined_grads = [self .x_fp32 ],
129- user_defined_grad_outputs = [self .x_fp32 ])
130-
13198
13299if __name__ == '__main__' :
133100 paddle .enable_static ()
0 commit comments