Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -6659,7 +6659,7 @@ def i0(x, name=None):
Tensor(shape=[5], dtype=float32, place=Place(cpu), stop_gradient=True,
[0.99999994 , 1.26606596 , 2.27958512 , 4.88079262 , 11.30192089])
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.i0(x)
else:
check_variable_and_dtype(x, "x", ["float32", "float64"], "i0")
Expand All @@ -6677,7 +6677,7 @@ def i0_(x, name=None):
Please refer to :ref:`api_paddle_i0`.
"""

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.i0_(x)


Expand Down
5 changes: 4 additions & 1 deletion test/legacy_test/test_i0_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import paddle
from paddle.base import core
from paddle.pir_utils import test_with_pir_api

np.random.seed(100)
paddle.seed(100)
Expand All @@ -44,6 +45,7 @@ def setUp(self):
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))

@test_with_pir_api
def test_api_static(self):
def run(place):
paddle.enable_static()
Expand Down Expand Up @@ -130,13 +132,14 @@ def init_config(self):
self.target = output_i0(self.inputs['x'])

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad(self):
self.check_grad(
['x'],
'out',
user_defined_grads=[ref_i0_grad(self.case, 1 / self.case.size)],
check_pir=True,
)


Expand Down