1919
2020import paddle
2121import paddle .nn .functional as F
22- from paddle import base
2322from paddle .base import core
23+ from paddle .pir_utils import test_with_pir_api
2424
2525
2626def channel_shuffle_np (x , groups , data_format = "NCHW" ):
@@ -71,10 +71,10 @@ def init_data_format(self):
7171 self .format = "NCHW"
7272
7373 def test_check_output (self ):
74- self .check_output ()
74+ self .check_output (check_pir = True )
7575
7676 def test_check_grad (self ):
77- self .check_grad (['X' ], 'Out' )
77+ self .check_grad (['X' ], 'Out' , check_pir = True )
7878
7979
8080class TestChannelLast (TestChannelShuffleOp ):
@@ -84,84 +84,122 @@ def init_data_format(self):
8484
8585class TestChannelShuffleAPI (unittest .TestCase ):
8686 def setUp (self ):
87- self .x_1_np = np .random .random ([2 , 9 , 4 , 4 ]).astype ("float64" )
8887 self .x_2_np = np .random .random ([2 , 4 , 4 , 9 ]).astype ("float64" )
89- self .out_1_np = channel_shuffle_np (self .x_1_np , 3 )
9088 self .out_2_np = channel_shuffle_np (self .x_2_np , 3 , "NHWC" )
89+ self .x_1_np = np .random .random ([2 , 9 , 4 , 4 ]).astype ("float64" )
90+ self .out_1_np = channel_shuffle_np (self .x_1_np , 3 )
9191
92+ @test_with_pir_api
9293 def test_static_graph_functional (self ):
93- for use_cuda in (
94- [ False , True ] if core . is_compiled_with_cuda () else [ False ]
94+ with paddle . static . program_guard (
95+ paddle . static . Program (), paddle . static . Program ()
9596 ):
96- place = paddle .CUDAPlace (0 ) if use_cuda else paddle .CPUPlace ()
97-
98- paddle .enable_static ()
99- x_1 = paddle .static .data (
100- name = "x" , shape = [2 , 9 , 4 , 4 ], dtype = "float64"
101- )
102- x_2 = paddle .static .data (
103- name = "x2" , shape = [2 , 4 , 4 , 9 ], dtype = "float64"
104- )
105- out_1 = F .channel_shuffle (x_1 , 3 )
106- out_2 = F .channel_shuffle (x_2 , 3 , "NHWC" )
107-
108- exe = paddle .static .Executor (place = place )
109- res_1 = exe .run (
110- base .default_main_program (),
111- feed = {"x" : self .x_1_np },
112- fetch_list = out_1 ,
113- use_prune = True ,
114- )
115-
116- res_2 = exe .run (
117- base .default_main_program (),
118- feed = {"x2" : self .x_2_np },
119- fetch_list = out_2 ,
120- use_prune = True ,
121- )
97+ for use_cuda in (
98+ [False , True ] if core .is_compiled_with_cuda () else [False ]
99+ ):
100+ place = paddle .CUDAPlace (0 ) if use_cuda else paddle .CPUPlace ()
101+
102+ paddle .enable_static ()
103+ x_1 = paddle .static .data (
104+ name = "x" , shape = [2 , 9 , 4 , 4 ], dtype = "float64"
105+ )
106+ out_1 = F .channel_shuffle (x_1 , 3 )
107+
108+ exe = paddle .static .Executor (place = place )
109+ res_1 = exe .run (
110+ paddle .static .default_main_program (),
111+ feed = {"x" : self .x_1_np },
112+ fetch_list = out_1 ,
113+ use_prune = True ,
114+ )
122115
123- np .testing .assert_allclose (res_1 [0 ], self .out_1_np )
124- np .testing .assert_allclose (res_2 [0 ], self .out_2_np )
116+ np .testing .assert_allclose (res_1 [0 ], self .out_1_np )
125117
126118 # same test between layer and functional in this op.
119+ @test_with_pir_api
127120 def test_static_graph_layer (self ):
128- for use_cuda in (
129- [ False , True ] if core . is_compiled_with_cuda () else [ False ]
121+ with paddle . static . program_guard (
122+ paddle . static . Program (), paddle . static . Program ()
130123 ):
131- place = paddle .CUDAPlace (0 ) if use_cuda else paddle .CPUPlace ()
124+ for use_cuda in (
125+ [False , True ] if core .is_compiled_with_cuda () else [False ]
126+ ):
127+ place = paddle .CUDAPlace (0 ) if use_cuda else paddle .CPUPlace ()
128+
129+ paddle .enable_static ()
130+ x_1 = paddle .static .data (
131+ name = "x" , shape = [2 , 9 , 4 , 4 ], dtype = "float64"
132+ )
133+ # init instance
134+ ps_1 = paddle .nn .ChannelShuffle (3 )
135+ out_1 = ps_1 (x_1 )
136+ out_1_np = channel_shuffle_np (self .x_1_np , 3 )
137+
138+ exe = paddle .static .Executor (place = place )
139+ res_1 = exe .run (
140+ paddle .static .default_main_program (),
141+ feed = {"x" : self .x_1_np },
142+ fetch_list = out_1 ,
143+ use_prune = True ,
144+ )
132145
133- paddle .enable_static ()
134- x_1 = paddle .static .data (
135- name = "x" , shape = [2 , 9 , 4 , 4 ], dtype = "float64"
136- )
137- x_2 = paddle .static .data (
138- name = "x2" , shape = [2 , 4 , 4 , 9 ], dtype = "float64"
139- )
140- # init instance
141- ps_1 = paddle .nn .ChannelShuffle (3 )
142- ps_2 = paddle .nn .ChannelShuffle (3 , "NHWC" )
143- out_1 = ps_1 (x_1 )
144- out_2 = ps_2 (x_2 )
145- out_1_np = channel_shuffle_np (self .x_1_np , 3 )
146- out_2_np = channel_shuffle_np (self .x_2_np , 3 , "NHWC" )
147-
148- exe = paddle .static .Executor (place = place )
149- res_1 = exe .run (
150- base .default_main_program (),
151- feed = {"x" : self .x_1_np },
152- fetch_list = out_1 ,
153- use_prune = True ,
154- )
146+ np .testing .assert_allclose (res_1 [0 ], out_1_np )
155147
156- res_2 = exe .run (
157- base .default_main_program (),
158- feed = {"x2" : self .x_2_np },
159- fetch_list = out_2 ,
160- use_prune = True ,
161- )
148+ @test_with_pir_api
149+ def test_static_graph_functional_new (self ):
150+ with paddle .static .program_guard (
151+ paddle .static .Program (), paddle .static .Program ()
152+ ):
153+ for use_cuda in (
154+ [False , True ] if core .is_compiled_with_cuda () else [False ]
155+ ):
156+ place = paddle .CUDAPlace (0 ) if use_cuda else paddle .CPUPlace ()
157+
158+ paddle .enable_static ()
159+ x_2 = paddle .static .data (
160+ name = "x2" , shape = [2 , 4 , 4 , 9 ], dtype = "float64"
161+ )
162+ out_2 = F .channel_shuffle (x_2 , 3 , "NHWC" )
163+
164+ exe = paddle .static .Executor (place = place )
165+ res_2 = exe .run (
166+ paddle .static .default_main_program (),
167+ feed = {"x2" : self .x_2_np },
168+ fetch_list = out_2 ,
169+ use_prune = True ,
170+ )
171+
172+ np .testing .assert_allclose (res_2 [0 ], self .out_2_np )
173+
174+ @test_with_pir_api
175+ def test_static_graph_layer_new (self ):
176+ with paddle .static .program_guard (
177+ paddle .static .Program (), paddle .static .Program ()
178+ ):
179+ for use_cuda in (
180+ [False , True ] if core .is_compiled_with_cuda () else [False ]
181+ ):
182+ place = paddle .CUDAPlace (0 ) if use_cuda else paddle .CPUPlace ()
183+
184+ paddle .enable_static ()
185+ x_2 = paddle .static .data (
186+ name = "x2" , shape = [2 , 4 , 4 , 9 ], dtype = "float64"
187+ )
188+ # init instance
189+ ps_2 = paddle .nn .ChannelShuffle (3 , "NHWC" )
190+ out_2 = ps_2 (x_2 )
191+ out_2_np = channel_shuffle_np (self .x_2_np , 3 , "NHWC" )
192+
193+ exe = paddle .static .Executor (place = place )
194+
195+ res_2 = exe .run (
196+ paddle .static .default_main_program (),
197+ feed = {"x2" : self .x_2_np },
198+ fetch_list = out_2 ,
199+ use_prune = True ,
200+ )
162201
163- np .testing .assert_allclose (res_1 [0 ], out_1_np )
164- np .testing .assert_allclose (res_2 [0 ], out_2_np )
202+ np .testing .assert_allclose (res_2 [0 ], out_2_np )
165203
166204 def run_dygraph (self , groups , data_format ):
167205 n , c , h , w = 2 , 9 , 4 , 4
@@ -209,6 +247,7 @@ def test_dygraph2(self):
209247
210248
211249class TestChannelShuffleError (unittest .TestCase ):
250+ @test_with_pir_api
212251 def test_error_functional (self ):
213252 def error_input ():
214253 with paddle .base .dygraph .guard ():
@@ -240,6 +279,7 @@ def error_data_format():
240279
241280 self .assertRaises (ValueError , error_data_format )
242281
282+ @test_with_pir_api
243283 def test_error_layer (self ):
244284 def error_input_layer ():
245285 with paddle .base .dygraph .guard ():
@@ -308,15 +348,11 @@ def init_data_format(self):
308348
309349 def test_check_output (self ):
310350 place = core .CUDAPlace (0 )
311- self .check_output_with_place (place )
351+ self .check_output_with_place (place , check_pir = True )
312352
313353 def test_check_grad (self ):
314354 place = core .CUDAPlace (0 )
315- self .check_grad_with_place (
316- place ,
317- ['X' ],
318- 'Out' ,
319- )
355+ self .check_grad_with_place (place , ['X' ], 'Out' , check_pir = True )
320356
321357
322358if __name__ == '__main__' :
0 commit comments