Skip to content

Commit 326352b

Browse files
authored
【PIR API adaptor No.59, 61, 63】Migrate some ops into pir (#58697)
* migrate dist, eigh and eigvalsh into pir * fix bug * fix bug
1 parent f1de995 commit 326352b

4 files changed

Lines changed: 17 additions & 13 deletions

File tree

python/paddle/tensor/linalg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -717,7 +717,7 @@ def dist(x, y, p=2, name=None):
717717
Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
718718
0.)
719719
"""
720-
if in_dynamic_mode():
720+
if in_dynamic_or_pir_mode():
721721
return _C_ops.dist(x, y, p)
722722

723723
check_variable_and_dtype(
@@ -2818,7 +2818,7 @@ def eigh(x, UPLO='L', name=None):
28182818
[ 0.3826833963394165j , -0.9238795042037964j ]])
28192819
28202820
"""
2821-
if in_dynamic_mode():
2821+
if in_dynamic_or_pir_mode():
28222822
return _C_ops.eigh(x, UPLO)
28232823
else:
28242824

@@ -3324,7 +3324,7 @@ def eigvalsh(x, UPLO='L', name=None):
33243324
Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
33253325
[0.17157286, 5.82842731])
33263326
"""
3327-
if in_dynamic_mode():
3327+
if in_dynamic_or_pir_mode():
33283328
values, _ = _C_ops.eigvalsh(x, UPLO, x.stop_gradient)
33293329
return values
33303330
else:

test/legacy_test/test_dist_op.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import paddle
2121
from paddle import base
2222
from paddle.base import core
23+
from paddle.pir_utils import test_with_pir_api
2324

2425
paddle.enable_static()
2526

@@ -113,13 +114,11 @@ def get_reduce_dims(x, y):
113114
return x_grad, y_grad
114115

115116
def test_check_output(self):
116-
self.check_output()
117+
self.check_output(check_pir=True)
117118

118119
def test_check_grad(self):
119120
self.check_grad(
120-
["X", "Y"],
121-
"Out",
122-
user_defined_grads=self.gradient,
121+
["X", "Y"], "Out", user_defined_grads=self.gradient, check_pir=True
123122
)
124123

125124

@@ -244,10 +243,11 @@ def init_data_type(self):
244243
'float32' if core.is_compiled_with_rocm() else 'float64'
245244
)
246245

246+
@test_with_pir_api
247247
def test_api(self):
248248
self.init_data_type()
249-
main_program = base.Program()
250-
startup_program = base.Program()
249+
main_program = paddle.static.Program()
250+
startup_program = paddle.static.Program()
251251
with base.program_guard(main_program, startup_program):
252252
x = paddle.static.data(
253253
name='x', shape=[2, 3, 4, 5], dtype=self.data_type
@@ -266,7 +266,7 @@ def test_api(self):
266266
)
267267
exe = base.Executor(place)
268268
out = exe.run(
269-
base.default_main_program(),
269+
main_program,
270270
feed={'x': x_i, 'y': y_i},
271271
fetch_list=[result],
272272
)

test/legacy_test/test_eigh_op.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from op_test import OpTest
1919

2020
import paddle
21+
from paddle.pir_utils import test_with_pir_api
2122

2223

2324
def valid_eigh_result(A, eigh_value, eigh_vector, uplo):
@@ -92,7 +93,7 @@ def init_input(self):
9293
# self.check_output(no_check_set=['Eigenvectors'])
9394

9495
def test_grad(self):
95-
self.check_grad(["X"], ["Eigenvalues"])
96+
self.check_grad(["X"], ["Eigenvalues"], check_pir=True)
9697

9798

9899
class TestEighUPLOCase(TestEighOp):
@@ -183,6 +184,7 @@ def check_static_complex_result(self):
183184
)
184185
valid_eigh_result(self.complex_symm, actual_w, actual_v, self.UPLO)
185186

187+
@test_with_pir_api
186188
def test_in_static_mode(self):
187189
paddle.enable_static()
188190
self.check_static_float_result()

test/legacy_test/test_eigvalsh_op.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from op_test import OpTest
1919

2020
import paddle
21+
from paddle.pir_utils import test_with_pir_api
2122

2223

2324
def compare_result(actual, expected):
@@ -72,10 +73,10 @@ def init_input(self):
7273

7374
def test_check_output(self):
7475
# Vectors in posetive or negative is equivalent
75-
self.check_output(no_check_set=['Eigenvectors'])
76+
self.check_output(no_check_set=['Eigenvectors'], check_pir=True)
7677

7778
def test_grad(self):
78-
self.check_grad(["X"], ["Eigenvalues"])
79+
self.check_grad(["X"], ["Eigenvalues"], check_pir=True)
7980

8081

8182
class TestEigvalshUPLOCase(TestEigvalshOp):
@@ -166,6 +167,7 @@ def check_static_complex_result(self):
166167
expected_w = np.linalg.eigvalsh(self.complex_symm)
167168
compare_result(actual_w[0], expected_w)
168169

170+
@test_with_pir_api
169171
def test_in_static_mode(self):
170172
paddle.enable_static()
171173
self.check_static_float_result()

0 commit comments

Comments
 (0)