Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -6828,7 +6828,7 @@ def polygamma(x, n, name=None):
if n == 0:
return digamma(x)
else:
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.polygamma(x, n)
else:
check_variable_and_dtype(
Expand Down Expand Up @@ -6864,7 +6864,7 @@ def polygamma_(x, n, name=None):
if n == 0:
return digamma_(x)
else:
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.polygamma_(x, n)


Expand Down
5 changes: 4 additions & 1 deletion test/legacy_test/test_i0e_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import paddle
from paddle.base import core
from paddle.pir_utils import test_with_pir_api

np.random.seed(100)
paddle.seed(100)
Expand Down Expand Up @@ -48,6 +49,7 @@ def setUp(self):
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))

@test_with_pir_api
def test_api_static(self):
def run(place):
paddle.enable_static()
Expand Down Expand Up @@ -134,13 +136,14 @@ def init_config(self):
self.target = output_i0e(self.inputs['x'])

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad(self):
self.check_grad(
['x'],
'out',
user_defined_grads=[ref_i0e_grad(self.case, 1 / self.case.size)],
check_pir=True,
)


Expand Down
5 changes: 4 additions & 1 deletion test/legacy_test/test_polygamma_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import paddle
from paddle.base import core
from paddle.pir_utils import test_with_pir_api

np.random.seed(100)
paddle.seed(100)
Expand Down Expand Up @@ -64,6 +65,7 @@ def setUp(self):
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))

@test_with_pir_api
def test_api_static(self):
def run(place):
paddle.enable_static()
Expand Down Expand Up @@ -197,7 +199,7 @@ def init_config(self):
self.target = ref_polygamma(self.inputs['x'], self.order)

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad(self):
self.check_grad(
Expand All @@ -206,6 +208,7 @@ def test_check_grad(self):
user_defined_grads=[
ref_polygamma_grad(self.case, 1 / self.case.size, self.order)
],
check_pir=True,
)


Expand Down