2222import paddle
2323from paddle import base
2424from paddle .base import core
25+ from paddle .pir_utils import test_with_pir_api
2526
2627
2728def np_naive_logcumsumexp (x : np .ndarray , axis : Optional [int ] = None ):
@@ -145,7 +146,9 @@ def run_imperative(self):
145146 np .testing .assert_allclose (z , y .numpy (), rtol = 1e-05 )
146147
147148 def run_static (self , use_gpu = False ):
148- with base .program_guard (base .Program ()):
149+ main = paddle .static .Program ()
150+ startup = paddle .static .Program ()
151+ with paddle .static .program_guard (main , startup ):
149152 data_np = np .random .random ((5 , 4 )).astype (np .float32 )
150153 x = paddle .static .data ('X' , [5 , 4 ])
151154 y = paddle .logcumsumexp (x )
@@ -156,15 +159,15 @@ def run_static(self, use_gpu=False):
156159
157160 place = base .CUDAPlace (0 ) if use_gpu else base .CPUPlace ()
158161 exe = base .Executor (place )
159- exe .run (base .default_startup_program ())
160162 out = exe .run (
163+ main ,
161164 feed = {'X' : data_np },
162165 fetch_list = [
163- y . name ,
164- y2 . name ,
165- y3 . name ,
166- y4 . name ,
167- y5 . name ,
166+ y ,
167+ y2 ,
168+ y3 ,
169+ y4 ,
170+ y5 ,
168171 ],
169172 )
170173
@@ -178,13 +181,15 @@ def run_static(self, use_gpu=False):
178181 z = np_logcumsumexp (data_np , axis = - 2 )
179182 np .testing .assert_allclose (z , out [4 ], rtol = 1e-05 )
180183
184+ @test_with_pir_api
181185 def test_cpu (self ):
182186 paddle .disable_static (paddle .base .CPUPlace ())
183187 self .run_imperative ()
184188 paddle .enable_static ()
185189
186190 self .run_static ()
187191
192+ @test_with_pir_api
188193 def test_gpu (self ):
189194 if not base .core .is_compiled_with_cuda ():
190195 return
@@ -194,23 +199,26 @@ def test_gpu(self):
194199
195200 self .run_static (use_gpu = True )
196201
202+ # @test_with_pir_api
197203 def test_name (self ):
198204 with base .program_guard (base .Program ()):
199205 x = paddle .static .data ('x' , [3 , 4 ])
200206 y = paddle .logcumsumexp (x , name = 'out' )
201207 self .assertTrue ('out' in y .name )
202208
209+ @test_with_pir_api
203210 def test_type_error (self ):
204- with base .program_guard (base .Program ()):
211+ main = paddle .static .Program ()
212+ startup = paddle .static .Program ()
213+ with paddle .static .program_guard (main , startup ):
205214 with self .assertRaises (TypeError ):
206215 data_np = np .random .random ((100 , 100 ), dtype = np .int32 )
207216 x = paddle .static .data ('X' , [100 , 100 ], dtype = 'int32' )
208217 y = paddle .logcumsumexp (x )
209218
210219 place = base .CUDAPlace (0 )
211220 exe = base .Executor (place )
212- exe .run (base .default_startup_program ())
213- out = exe .run (feed = {'X' : data_np }, fetch_list = [y .name ])
221+ out = exe .run (main , feed = {'X' : data_np }, fetch_list = [y ])
214222
215223
216224def logcumsumexp_wrapper (
@@ -296,6 +304,7 @@ def check_main(self, x_np, dtype, axis=None):
296304 paddle .enable_static ()
297305 return y_np , x_g_np
298306
307+ @test_with_pir_api
299308 def test_main (self ):
300309 if not paddle .is_compiled_with_cuda ():
301310 return
0 commit comments