Skip to content

Commit 3fc1550

Browse files
authored
【PIR api adaptor No.233、234】 Migrate paddle.trunc/frac into pir (#58675)
1 parent 38e314e commit 3fc1550

File tree

3 files changed: +21 −12 lines changed

3 files changed: +21 −12 lines changed

python/paddle/tensor/math.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import paddle
2222
from paddle import _C_ops, _legacy_C_ops
23+
from paddle.base.libpaddle import DataType
2324
from paddle.common_ops_import import VarDesc, dygraph_utils
2425
from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only
2526

@@ -2026,7 +2027,7 @@ def trunc(input, name=None):
20262027
[[ 0., 1.],
20272028
[-0., -2.]])
20282029
'''
2029-
if in_dynamic_mode():
2030+
if in_dynamic_or_pir_mode():
20302031
return _C_ops.trunc(input)
20312032
else:
20322033
inputs = {"X": input}
@@ -6061,11 +6062,15 @@ def frac(x, name=None):
60616062
paddle.int64,
60626063
paddle.float32,
60636064
paddle.float64,
6065+
DataType.INT32,
6066+
DataType.INT64,
6067+
DataType.FLOAT32,
6068+
DataType.FLOAT64,
60646069
]:
60656070
raise TypeError(
60666071
f"The data type of input must be one of ['int32', 'int64', 'float32', 'float64'], but got {x.dtype}"
60676072
)
6068-
if in_dynamic_mode():
6073+
if in_dynamic_or_pir_mode():
60696074
y = _C_ops.trunc(x)
60706075
return _C_ops.subtract(x, y)
60716076
else:

test/legacy_test/test_frac_api.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818

1919
import paddle
2020
from paddle import base
21-
from paddle.base import Program, core, program_guard
21+
from paddle.base import core
22+
from paddle.pir_utils import test_with_pir_api
2223

2324

2425
def ref_frac(x):
@@ -40,15 +41,13 @@ def setUp(self):
4041
else paddle.CPUPlace()
4142
)
4243

44+
@test_with_pir_api
4345
def test_api_static(self):
4446
paddle.enable_static()
45-
with program_guard(Program()):
47+
with paddle.static.program_guard(paddle.static.Program()):
4648
input = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
4749
out = paddle.frac(input)
48-
place = base.CPUPlace()
49-
if base.core.is_compiled_with_cuda():
50-
place = base.CUDAPlace(0)
51-
exe = base.Executor(place)
50+
exe = base.Executor(self.place)
5251
(res,) = exe.run(feed={'X': self.x_np}, fetch_list=[out])
5352
out_ref = ref_frac(self.x_np)
5453
np.testing.assert_allclose(out_ref, res, rtol=1e-05)
@@ -101,6 +100,7 @@ def setUp(self):
101100
else paddle.CPUPlace()
102101
)
103102

103+
@test_with_pir_api
104104
def test_static_error(self):
105105
paddle.enable_static()
106106
with paddle.static.program_guard(paddle.static.Program()):

test/legacy_test/test_trunc_op.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import paddle
2121
from paddle.base import core
22+
from paddle.pir_utils import test_with_pir_api
2223

2324
paddle.enable_static()
2425

@@ -36,10 +37,10 @@ def init_dtype_type(self):
3637
self.dtype = np.float64
3738

3839
def test_check_output(self):
39-
self.check_output()
40+
self.check_output(check_pir=True)
4041

4142
def test_check_grad(self):
42-
self.check_grad(['X'], 'Out', numeric_grad_delta=1e-5)
43+
self.check_grad(['X'], 'Out', numeric_grad_delta=1e-5, check_pir=True)
4344

4445

4546
class TestFloatTruncOp(TestTruncOp):
@@ -66,6 +67,7 @@ def setUp(self):
6667
self.x = np.random.random((20, 20)).astype(np.float32)
6768
self.place = paddle.CPUPlace()
6869

70+
@test_with_pir_api
6971
def test_api_static(self):
7072
paddle.enable_static()
7173
with paddle.static.program_guard(paddle.static.Program()):
@@ -114,11 +116,13 @@ def setUp(self):
114116

115117
def test_check_output(self):
116118
place = core.CUDAPlace(0)
117-
self.check_output_with_place(place)
119+
self.check_output_with_place(place, check_pir=True)
118120

119121
def test_check_grad(self):
120122
place = core.CUDAPlace(0)
121-
self.check_grad_with_place(place, ['X'], 'Out', numeric_grad_delta=1e-5)
123+
self.check_grad_with_place(
124+
place, ['X'], 'Out', numeric_grad_delta=1e-5, check_pir=True
125+
)
122126

123127

124128
if __name__ == "__main__":

0 commit comments

Comments (0)