Skip to content

Commit 894713f

Browse files
【PIR API adaptor No.105】 Migrate paddle.i0 into pir (#58603)
* ✨ Refactor: enable new ir op and added new ir test * Update python/paddle/tensor/math.py Co-authored-by: Lu Qi <[email protected]> * ♻️ Refactor: updated test * 🎨 Fix: updated code style --------- Co-authored-by: Lu Qi <[email protected]>
1 parent 004312c commit 894713f

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

python/paddle/tensor/math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6664,7 +6664,7 @@ def i0(x, name=None):
66646664
Tensor(shape=[5], dtype=float32, place=Place(cpu), stop_gradient=True,
66656665
[0.99999994 , 1.26606596 , 2.27958512 , 4.88079262 , 11.30192089])
66666666
"""
6667-
if in_dynamic_mode():
6667+
if in_dynamic_or_pir_mode():
66686668
return _C_ops.i0(x)
66696669
else:
66706670
check_variable_and_dtype(x, "x", ["float32", "float64"], "i0")

test/legacy_test/test_i0_op.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import paddle
2222
from paddle.base import core
23+
from paddle.pir_utils import test_with_pir_api
2324

2425
np.random.seed(100)
2526
paddle.seed(100)
@@ -40,10 +41,12 @@ class TestI0API(unittest.TestCase):
4041

4142
def setUp(self):
4243
self.x = np.array(self.DATA).astype(self.DTYPE)
44+
self.out_ref = output_i0(self.x)
4345
self.place = [paddle.CPUPlace()]
4446
if core.is_compiled_with_cuda():
4547
self.place.append(paddle.CUDAPlace(0))
4648

49+
@test_with_pir_api
4750
def test_api_static(self):
4851
def run(place):
4952
paddle.enable_static()
@@ -58,8 +61,7 @@ def run(place):
5861
feed={"x": self.x},
5962
fetch_list=[out],
6063
)
61-
out_ref = output_i0(self.x)
62-
np.testing.assert_allclose(res[0], out_ref, rtol=1e-5)
64+
np.testing.assert_allclose(res[0], self.out_ref, rtol=1e-5)
6365
paddle.disable_static()
6466

6567
for place in self.place:
@@ -130,13 +132,14 @@ def init_config(self):
130132
self.target = output_i0(self.inputs['x'])
131133

132134
def test_check_output(self):
133-
self.check_output()
135+
self.check_output(check_pir=True)
134136

135137
def test_check_grad(self):
136138
self.check_grad(
137139
['x'],
138140
'out',
139141
user_defined_grads=[ref_i0_grad(self.case, 1 / self.case.size)],
142+
check_pir=True,
140143
)
141144

142145

0 commit comments

Comments
 (0)