1919
2020import paddle
2121from paddle import base
22- from paddle .base import Program , core , program_guard
22+ from paddle .base import core
23+ from paddle .pir_utils import test_with_pir_api
2324
2425
2526class TestRollOp (OpTest ):
@@ -48,10 +49,10 @@ def init_dtype_type(self):
4849 self .axis = [0 , - 2 ]
4950
5051 def test_check_output (self ):
51- self .check_output (check_prim = True )
52+ self .check_output (check_prim = True , check_pir = True )
5253
5354 def test_check_grad_normal (self ):
54- self .check_grad (['X' ], 'Out' , check_prim = True )
55+ self .check_grad (['X' ], 'Out' , check_prim = True , check_pir = True )
5556
5657
5758class TestRollOpCase2 (TestRollOp ):
@@ -108,10 +109,14 @@ def init_dtype_type(self):
108109 self .place = core .CUDAPlace (0 )
109110
110111 def test_check_output (self ):
111- self .check_output_with_place (self .place , check_prim = True )
112+ self .check_output_with_place (
113+ self .place , check_prim = True , check_pir = True
114+ )
112115
113116 def test_check_grad_normal (self ):
114- self .check_grad_with_place (self .place , ['X' ], 'Out' , check_prim = True )
117+ self .check_grad_with_place (
118+ self .place , ['X' ], 'Out' , check_prim = True , check_pir = True
119+ )
115120
116121
117122@unittest .skipIf (
@@ -128,10 +133,14 @@ def init_dtype_type(self):
128133 self .place = core .CUDAPlace (0 )
129134
130135 def test_check_output (self ):
131- self .check_output_with_place (self .place , check_prim = True )
136+ self .check_output_with_place (
137+ self .place , check_prim = True , check_pir = True
138+ )
132139
133140 def test_check_grad_normal (self ):
134- self .check_grad_with_place (self .place , ['X' ], 'Out' , check_prim = True )
141+ self .check_grad_with_place (
142+ self .place , ['X' ], 'Out' , check_prim = True , check_pir = True
143+ )
135144
136145
137146@unittest .skipIf (
@@ -148,10 +157,14 @@ def init_dtype_type(self):
148157 self .place = core .CUDAPlace (0 )
149158
150159 def test_check_output (self ):
151- self .check_output_with_place (self .place , check_prim = True )
160+ self .check_output_with_place (
161+ self .place , check_prim = True , check_pir = True
162+ )
152163
153164 def test_check_grad_normal (self ):
154- self .check_grad_with_place (self .place , ['X' ], 'Out' , check_prim = True )
165+ self .check_grad_with_place (
166+ self .place , ['X' ], 'Out' , check_prim = True , check_pir = True
167+ )
155168
156169
157170class TestRollAPI (unittest .TestCase ):
@@ -160,37 +173,53 @@ def input_data(self):
160173 [[1.0 , 2.0 , 3.0 ], [4.0 , 5.0 , 6.0 ], [7.0 , 8.0 , 9.0 ]]
161174 )
162175
163- def test_roll_op_api (self ):
164- self .input_data ()
165-
176+ @test_with_pir_api
177+ def test_roll_op_api_case1 (self ):
166178 paddle .enable_static ()
167- # case 1:
168- with program_guard (Program (), Program ()):
179+ with paddle .static .program_guard (
180+ paddle .static .Program (), paddle .static .Program ()
181+ ):
169182 x = paddle .static .data (name = 'x' , shape = [- 1 , 3 ], dtype = 'float32' )
170- x .desc .set_need_check_feed (False )
183+ data_x = np .array (
184+ [[1.0 , 2.0 , 3.0 ], [4.0 , 5.0 , 6.0 ], [7.0 , 8.0 , 9.0 ]]
185+ ).astype ('float32' )
171186 z = paddle .roll (x , shifts = 1 )
172- exe = base . Executor (base .CPUPlace ())
187+ exe = paddle . static . Executor (paddle .CPUPlace ())
173188 (res ,) = exe .run (
174- feed = {'x' : self .data_x }, fetch_list = [z .name ], return_numpy = False
189+ paddle .static .default_main_program (),
190+ feed = {'x' : data_x },
191+ fetch_list = [z ],
192+ return_numpy = False ,
175193 )
176194 expect_out = np .array (
177195 [[9.0 , 1.0 , 2.0 ], [3.0 , 4.0 , 5.0 ], [6.0 , 7.0 , 8.0 ]]
178196 )
179- np .testing .assert_allclose (expect_out , np .array (res ), rtol = 1e-05 )
197+ np .testing .assert_allclose (expect_out , np .array (res ), rtol = 1e-05 )
198+ paddle .disable_static ()
180199
181- # case 2:
182- with program_guard (Program (), Program ()):
200+ @test_with_pir_api
201+ def test_roll_op_api_case2 (self ):
202+ paddle .enable_static ()
203+ with paddle .static .program_guard (
204+ paddle .static .Program (), paddle .static .Program ()
205+ ):
183206 x = paddle .static .data (name = 'x' , shape = [- 1 , 3 ], dtype = 'float32' )
184- x .desc .set_need_check_feed (False )
207+ data_x = np .array (
208+ [[1.0 , 2.0 , 3.0 ], [4.0 , 5.0 , 6.0 ], [7.0 , 8.0 , 9.0 ]]
209+ ).astype ('float32' )
185210 z = paddle .roll (x , shifts = 1 , axis = 0 )
186- exe = base . Executor (base .CPUPlace ())
211+ exe = paddle . static . Executor (paddle .CPUPlace ())
187212 (res ,) = exe .run (
188- feed = {'x' : self .data_x }, fetch_list = [z .name ], return_numpy = False
213+ paddle .static .default_main_program (),
214+ feed = {'x' : data_x },
215+ fetch_list = [z ],
216+ return_numpy = False ,
217+ )
218+ expect_out = np .array (
219+ [[7.0 , 8.0 , 9.0 ], [1.0 , 2.0 , 3.0 ], [4.0 , 5.0 , 6.0 ]]
189220 )
190- expect_out = np .array (
191- [[7.0 , 8.0 , 9.0 ], [1.0 , 2.0 , 3.0 ], [4.0 , 5.0 , 6.0 ]]
192- )
193221 np .testing .assert_allclose (expect_out , np .array (res ), rtol = 1e-05 )
222+ paddle .disable_static ()
194223
195224 def test_dygraph_api (self ):
196225 self .input_data ()
@@ -214,22 +243,27 @@ def test_dygraph_api(self):
214243 )
215244 np .testing .assert_allclose (expect_out , np_z , rtol = 1e-05 )
216245
246+ @test_with_pir_api
217247 def test_roll_op_false (self ):
218- self .input_data ()
219-
220248 def test_axis_out_range ():
221- with program_guard (Program (), Program ()):
249+ paddle .enable_static ()
250+ with paddle .static .program_guard (
251+ paddle .static .Program (), paddle .static .Program ()
252+ ):
222253 x = paddle .static .data (name = 'x' , shape = [- 1 , 3 ], dtype = 'float32' )
223- x .desc .set_need_check_feed (False )
254+ data_x = np .array (
255+ [[1.0 , 2.0 , 3.0 ], [4.0 , 5.0 , 6.0 ], [7.0 , 8.0 , 9.0 ]]
256+ ).astype ('float32' )
224257 z = paddle .roll (x , shifts = 1 , axis = 10 )
225258 exe = base .Executor (base .CPUPlace ())
226259 (res ,) = exe .run (
227- feed = {'x' : self . data_x },
228- fetch_list = [z . name ],
260+ feed = {'x' : data_x },
261+ fetch_list = [z ],
229262 return_numpy = False ,
230263 )
231264
232265 self .assertRaises (ValueError , test_axis_out_range )
266+ paddle .disable_static ()
233267
234268 def test_shifts_as_tensor_dygraph (self ):
235269 with base .dygraph .guard ():
@@ -241,23 +275,28 @@ def test_shifts_as_tensor_dygraph(self):
241275 expected_out = np .array ([[8 , 6 , 7 ], [2 , 0 , 1 ], [5 , 3 , 4 ]])
242276 np .testing .assert_allclose (out , expected_out , rtol = 1e-05 )
243277
278+ @test_with_pir_api
244279 def test_shifts_as_tensor_static (self ):
245- with program_guard (Program (), Program ()):
280+ paddle .enable_static ()
281+ with paddle .static .program_guard (
282+ paddle .static .Program (), paddle .static .Program ()
283+ ):
246284 x = paddle .arange (9 ).reshape ([3 , 3 ]).astype ('float32' )
247285 shape = paddle .shape (x )
248286 shifts = shape // 2
249287 axes = [0 , 1 ]
250288 out = paddle .roll (x , shifts = shifts , axis = axes )
251289 expected_out = np .array ([[8 , 6 , 7 ], [2 , 0 , 1 ], [5 , 3 , 4 ]])
252290
253- exe = base . Executor (base .CPUPlace ())
291+ exe = paddle . static . Executor (paddle .CPUPlace ())
254292 [out_np ] = exe .run (fetch_list = [out ])
255293 np .testing .assert_allclose (out_np , expected_out , rtol = 1e-05 )
256294
257295 if paddle .is_compiled_with_cuda ():
258296 exe = base .Executor (base .CPUPlace ())
259297 [out_np ] = exe .run (fetch_list = [out ])
260298 np .testing .assert_allclose (out_np , expected_out , rtol = 1e-05 )
299+ paddle .disable_static ()
261300
262301
263302if __name__ == "__main__" :
0 commit comments