1818from pass_test import PassTest
1919
2020import paddle
21+ from paddle .base import core
2122
2223paddle .enable_static ()
2324
2425
25- @unittest .skipIf (
26- not paddle .base .core .is_compiled_with_cuda (),
27- "core is not complied with CUDA" ,
28- )
2926class TestConv2dAddActFusePattern (PassTest ):
3027 r"""
3128 x_var f_var
@@ -47,10 +44,10 @@ def is_program_valid(self, program):
4744 return True
4845
4946 def build_ir_progam (self ):
50- pir_program = None
5147 with paddle .pir_utils .IrGuard ():
52- pir_program = paddle .static .Program ()
53- with paddle .pir .core .program_guard (pir_program ):
48+ main_prog = paddle .static .Program ()
49+ start_prog = paddle .static .Program ()
50+ with paddle .pir .core .program_guard (main_prog , start_prog ):
5451 x = paddle .static .data (
5552 name = 'x' , shape = [3 , 1 , 28 , 28 ], dtype = 'float32'
5653 )
@@ -67,23 +64,26 @@ def build_ir_progam(self):
6764 )
6865 act_op = paddle .nn .ReLU ()
6966 out = act_op (paddle .add (conv2d (x ), y ))
70-
71- self .pass_list = ['conv2d_add_act_fuse_pass' ]
72- self .feeds = {
73- "x" : np .random .random ((3 , 32 , 28 , 28 )).astype ("float32" ),
74- "y" : np .random .random ((3 , 32 , 28 , 28 )).astype ("float32" ),
75- }
76- self .fetch_list = [out ]
77- self .valid_op_map = {
78- "pd_op.add" : 0 ,
79- "pd_op.relu" : 0 ,
80- "pd_op.conv2d" : 0 ,
81- "pd_op.fused_conv2d_add_act" : 1 ,
82- }
83- return pir_program
67+ out = paddle . assign ( out )
68+ self .pass_list = ['conv2d_add_act_fuse_pass' ]
69+ self .feeds = {
70+ "x" : np .random .random ((3 , 1 , 28 , 28 )).astype ("float32" ),
71+ "y" : np .random .random ((3 , 32 , 28 , 28 )).astype ("float32" ),
72+ }
73+ self .fetch_list = [out ]
74+ self .valid_op_map = {
75+ "pd_op.add" : 0 ,
76+ "pd_op.relu" : 0 ,
77+ "pd_op.conv2d" : 0 ,
78+ "pd_op.fused_conv2d_add_act" : 1 ,
79+ }
80+ return [ main_prog , start_prog ]
8481
8582 def setUp (self ):
86- self .place_runtime = "gpu"
83+ if core .is_compiled_with_cuda ():
84+ self .places .append (paddle .CUDAPlace (0 ))
85+ # todo(bukejiyu): This pass will support accuracy verification in the future
86+ self .skip_accuracy_verification = True
8787
8888 def sample_program (self ):
8989 yield self .build_ir_progam (), False
@@ -92,15 +92,6 @@ def test_check_output(self):
9292 self .check_pass_correct ()
9393
9494
95- class TestConv2dAddActFusePatternWithCpu (TestConv2dAddActFusePattern ):
96- def setUp (self ):
97- self .place_runtime = "cpu"
98-
99-
100- @unittest .skipIf (
101- not paddle .base .core .is_compiled_with_cuda (),
102- "core is not complied with CUDA" ,
103- )
10495class TestConv2dAdd2ActFusePattern (PassTest ):
10596 r"""
10697 x_var f_var(persistable)
@@ -124,10 +115,10 @@ def is_program_valid(self, program):
124115 return True
125116
126117 def build_ir_progam (self ):
127- pir_program = None
128118 with paddle .pir_utils .IrGuard ():
129- pir_program = paddle .static .Program ()
130- with paddle .pir .core .program_guard (pir_program ):
119+ main_prog = paddle .static .Program ()
120+ start_prog = paddle .static .Program ()
121+ with paddle .pir .core .program_guard (main_prog , start_prog ):
131122 x = paddle .static .data (
132123 name = 'x' , shape = [3 , 1 , 28 , 28 ], dtype = 'float32'
133124 )
@@ -149,22 +140,29 @@ def build_ir_progam(self):
149140 out = act_op (
150141 paddle .add (residual_data , paddle .add (conv2d (x ), y ))
151142 )
152- self .pass_list = ['conv2d_add_act_fuse_pass' ]
153- self .feeds = {
154- "x" : np .random .random ((3 , 32 , 28 , 28 )).astype ("float32" ),
155- "y" : np .random .random ((3 , 32 , 28 , 28 )).astype ("float32" ),
156- }
157- self .fetch_list = [out ]
158- self .valid_op_map = {
159- "pd_op.add" : 0 ,
160- "pd_op.relu" : 0 ,
161- "pd_op.conv2d" : 0 ,
162- "pd_op.fused_conv2d_add_act" : 1 ,
163- }
164- return pir_program
143+ out = paddle .assign (out )
144+ self .pass_list = ['conv2d_add_act_fuse_pass' ]
145+ self .feeds = {
146+ "x" : np .random .random ((3 , 1 , 28 , 28 )).astype ("float32" ),
147+ "y" : np .random .random ((3 , 32 , 28 , 28 )).astype ("float32" ),
148+ "residual_data" : np .random .random ((3 , 32 , 28 , 28 )).astype (
149+ "float32"
150+ ),
151+ }
152+ self .fetch_list = [out ]
153+ self .valid_op_map = {
154+ "pd_op.add" : 0 ,
155+ "pd_op.relu" : 0 ,
156+ "pd_op.conv2d" : 0 ,
157+ "pd_op.fused_conv2d_add_act" : 1 ,
158+ }
159+ return [main_prog , start_prog ]
165160
166161 def setUp (self ):
167- self .place_runtime = "gpu"
162+ if core .is_compiled_with_cuda ():
163+ self .places .append (paddle .CUDAPlace (0 ))
164+ # todo(bukejiyu): This pass will support accuracy verification in the future
165+ self .skip_accuracy_verification = True
168166
169167 def sample_program (self ):
170168 yield self .build_ir_progam (), False
@@ -173,10 +171,5 @@ def test_check_output(self):
173171 self .check_pass_correct ()
174172
175173
176- class TestConv2dAdd2ActFusePatternWithCpu (TestConv2dAdd2ActFusePattern ):
177- def setUp (self ):
178- self .place_runtime = "cpu"
179-
180-
181174if __name__ == "__main__" :
182175 unittest .main ()
0 commit comments