Skip to content

Commit 7806bff

Browse files
authored
【PIR API adaptor No.132】 Migrate paddle.Tensor.logcumsumexp (#58695)
1 parent 1ad64dd commit 7806bff

File tree

2 files changed

+20
-11
lines changed

2 files changed

+20
-11
lines changed

python/paddle/tensor/math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4203,7 +4203,7 @@ def logcumsumexp(x, axis=None, dtype=None, name=None):
42034203
if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype):
42044204
x = cast(x, dtype)
42054205

4206-
if in_dynamic_mode():
4206+
if in_dynamic_or_pir_mode():
42074207
if axis is None:
42084208
axis = -1
42094209
return _C_ops.logcumsumexp(x, axis, flatten, False, False)

test/legacy_test/test_logcumsumexp_op.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import paddle
2323
from paddle import base
2424
from paddle.base import core
25+
from paddle.pir_utils import test_with_pir_api
2526

2627

2728
def np_naive_logcumsumexp(x: np.ndarray, axis: Optional[int] = None):
@@ -145,7 +146,9 @@ def run_imperative(self):
145146
np.testing.assert_allclose(z, y.numpy(), rtol=1e-05)
146147

147148
def run_static(self, use_gpu=False):
148-
with base.program_guard(base.Program()):
149+
main = paddle.static.Program()
150+
startup = paddle.static.Program()
151+
with paddle.static.program_guard(main, startup):
149152
data_np = np.random.random((5, 4)).astype(np.float32)
150153
x = paddle.static.data('X', [5, 4])
151154
y = paddle.logcumsumexp(x)
@@ -156,15 +159,15 @@ def run_static(self, use_gpu=False):
156159

157160
place = base.CUDAPlace(0) if use_gpu else base.CPUPlace()
158161
exe = base.Executor(place)
159-
exe.run(base.default_startup_program())
160162
out = exe.run(
163+
main,
161164
feed={'X': data_np},
162165
fetch_list=[
163-
y.name,
164-
y2.name,
165-
y3.name,
166-
y4.name,
167-
y5.name,
166+
y,
167+
y2,
168+
y3,
169+
y4,
170+
y5,
168171
],
169172
)
170173

@@ -178,13 +181,15 @@ def run_static(self, use_gpu=False):
178181
z = np_logcumsumexp(data_np, axis=-2)
179182
np.testing.assert_allclose(z, out[4], rtol=1e-05)
180183

184+
@test_with_pir_api
181185
def test_cpu(self):
182186
paddle.disable_static(paddle.base.CPUPlace())
183187
self.run_imperative()
184188
paddle.enable_static()
185189

186190
self.run_static()
187191

192+
@test_with_pir_api
188193
def test_gpu(self):
189194
if not base.core.is_compiled_with_cuda():
190195
return
@@ -194,23 +199,26 @@ def test_gpu(self):
194199

195200
self.run_static(use_gpu=True)
196201

202+
# @test_with_pir_api
197203
def test_name(self):
198204
with base.program_guard(base.Program()):
199205
x = paddle.static.data('x', [3, 4])
200206
y = paddle.logcumsumexp(x, name='out')
201207
self.assertTrue('out' in y.name)
202208

209+
@test_with_pir_api
203210
def test_type_error(self):
204-
with base.program_guard(base.Program()):
211+
main = paddle.static.Program()
212+
startup = paddle.static.Program()
213+
with paddle.static.program_guard(main, startup):
205214
with self.assertRaises(TypeError):
206215
data_np = np.random.random((100, 100), dtype=np.int32)
207216
x = paddle.static.data('X', [100, 100], dtype='int32')
208217
y = paddle.logcumsumexp(x)
209218

210219
place = base.CUDAPlace(0)
211220
exe = base.Executor(place)
212-
exe.run(base.default_startup_program())
213-
out = exe.run(feed={'X': data_np}, fetch_list=[y.name])
221+
out = exe.run(main, feed={'X': data_np}, fetch_list=[y])
214222

215223

216224
def logcumsumexp_wrapper(
@@ -296,6 +304,7 @@ def check_main(self, x_np, dtype, axis=None):
296304
paddle.enable_static()
297305
return y_np, x_g_np
298306

307+
@test_with_pir_api
299308
def test_main(self):
300309
if not paddle.is_compiled_with_cuda():
301310
return

0 commit comments

Comments
 (0)