Skip to content

Commit aac8b6d

Browse files
authored
【PIR API adaptor No.156、159、180、189】Migrate some ops into pir (#59041)
1 parent 43a9fe0 commit aac8b6d

File tree

5 files changed

+33
-17
lines changed

5 files changed

+33
-17
lines changed

python/paddle/tensor/math.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ def multiplex(inputs, index, name=None):
414414
[3., 4.]])
415415
416416
"""
417-
if in_dynamic_mode():
417+
if in_dynamic_or_pir_mode():
418418
return _C_ops.multiplex(inputs, index)
419419
else:
420420
helper = LayerHelper('multiplex', **locals())
@@ -2406,7 +2406,7 @@ def renorm(x, p, axis, max_norm):
24062406
)
24072407
)
24082408
axis = axis + len(input_shape)
2409-
if in_dynamic_mode():
2409+
if in_dynamic_or_pir_mode():
24102410
out = _C_ops.renorm(x, p, axis, max_norm)
24112411
return out
24122412
else:
@@ -5420,7 +5420,7 @@ def rad2deg(x, name=None):
54205420
57.29578018)
54215421
"""
54225422
rad2deg_scale = 180 / np.pi
5423-
if in_dynamic_mode():
5423+
if in_dynamic_or_pir_mode():
54245424
if convert_dtype(x.dtype) in ['int32', 'int64']:
54255425
x = cast(x, dtype="float32")
54265426
return _C_ops.scale(x, rad2deg_scale, 0.0, True)
@@ -6630,7 +6630,7 @@ def nextafter(x, y, name=None):
66306630
Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
66316631
[1.00000012, 1.99999988])
66326632
"""
6633-
if in_dynamic_mode():
6633+
if in_dynamic_or_pir_mode():
66346634
return _C_ops.nextafter(x, y)
66356635
else:
66366636
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'nextafter')

test/legacy_test/test_multiplex_op.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,19 +45,25 @@ def setUp(self):
4545
self.outputs = {'Out': output}
4646

4747
def test_check_output(self):
48-
self.check_output()
48+
self.check_output(check_pir=True)
4949

5050
def test_check_grad(self):
51-
self.check_grad(['x1', 'x2', 'x3', 'x4'], 'Out')
51+
self.check_grad(['x1', 'x2', 'x3', 'x4'], 'Out', check_pir=True)
5252

5353
def test_check_grad_ignore_x1(self):
54-
self.check_grad(['x2', 'x3', 'x4'], 'Out', no_grad_set=set('x1'))
54+
self.check_grad(
55+
['x2', 'x3', 'x4'], 'Out', no_grad_set=set('x1'), check_pir=True
56+
)
5557

5658
def test_check_grad_ignore_x1_x2(self):
57-
self.check_grad(['x3', 'x4'], 'Out', no_grad_set={'x1', 'x2'})
59+
self.check_grad(
60+
['x3', 'x4'], 'Out', no_grad_set={'x1', 'x2'}, check_pir=True
61+
)
5862

5963
def test_check_grad_ignore_x3(self):
60-
self.check_grad(['x1', 'x2', 'x4'], 'Out', no_grad_set=set('x3'))
64+
self.check_grad(
65+
['x1', 'x2', 'x4'], 'Out', no_grad_set=set('x3'), check_pir=True
66+
)
6167

6268

6369
class TestMultiplexOpError(unittest.TestCase):

test/legacy_test/test_nextafter_op.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from op_test import OpTest
1919

2020
import paddle
21+
from paddle.pir_utils import test_with_pir_api
2122

2223

2324
def ref_nextafter(x, y):
@@ -39,6 +40,7 @@ def setUp(self):
3940
else paddle.CPUPlace()
4041
)
4142

43+
@test_with_pir_api
4244
def test_static_api(self):
4345
paddle.enable_static()
4446
with paddle.static.program_guard(paddle.static.Program()):
@@ -103,7 +105,7 @@ def setUp(self):
103105
self.outputs = {'out': out}
104106

105107
def test_check_output(self):
106-
self.check_output()
108+
self.check_output(check_pir=True)
107109

108110
def init_dtype(self):
109111
self.dtype = np.float64

test/legacy_test/test_rad2deg.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import paddle
2020
from paddle import base
2121
from paddle.base import core
22+
from paddle.pir_utils import test_with_pir_api
2223

2324
paddle.enable_static()
2425

@@ -32,10 +33,11 @@ def setUp(self):
3233
self.x_shape = [6]
3334
self.out_np = np.rad2deg(self.x_np)
3435

36+
@test_with_pir_api
3537
def test_static_graph(self):
36-
startup_program = base.Program()
37-
train_program = base.Program()
38-
with base.program_guard(startup_program, train_program):
38+
startup_program = paddle.static.Program()
39+
train_program = paddle.static.Program()
40+
with paddle.static.program_guard(startup_program, train_program):
3941
x = paddle.static.data(
4042
name='input', dtype=self.x_dtype, shape=self.x_shape
4143
)
@@ -48,11 +50,10 @@ def test_static_graph(self):
4850
)
4951
exe = base.Executor(place)
5052
res = exe.run(
51-
base.default_main_program(),
5253
feed={'input': self.x_np},
5354
fetch_list=[out],
5455
)
55-
self.assertTrue((np.array(out[0]) == self.out_np).all())
56+
np.testing.assert_allclose(self.out_np, res[0], rtol=1e-05)
5657

5758
def test_dygraph(self):
5859
paddle.disable_static()
@@ -96,3 +97,7 @@ def test_dygraph(self):
9697
np.testing.assert_allclose(180 / np.pi, result2.numpy(), rtol=1e-05)
9798

9899
paddle.enable_static()
100+
101+
102+
if __name__ == "__main__":
103+
unittest.main()

test/legacy_test/test_renorm_op.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import paddle
2020
from paddle import base
21-
from paddle.base import Program, program_guard
21+
from paddle.pir_utils import test_with_pir_api
2222

2323
paddle.set_device('cpu')
2424

@@ -32,12 +32,15 @@ def input_data(self):
3232
self.dim = 2
3333
self.max_norm = 2.05
3434

35+
@test_with_pir_api
3536
def test_renorm_api(self):
3637
paddle.enable_static()
3738
self.input_data()
3839

3940
# case 1:
40-
with program_guard(Program(), Program()):
41+
with paddle.static.program_guard(
42+
paddle.static.Program(), paddle.static.Program()
43+
):
4144
x = paddle.static.data(name="x", shape=[-1, 2, 3], dtype='float64')
4245
z = paddle.renorm(x, self.p, self.dim, self.max_norm)
4346
exe = base.Executor(base.CPUPlace())

0 commit comments

Comments
 (0)