Skip to content

Commit 8022b63

Browse files
authored
【PIR API adaptor No.224】 Migrate paddle.tan (#58737)
* feat(new-ir): support tan * feat(new-ir): update ut of tan
1 parent 326352b commit 8022b63

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

python/paddle/tensor/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1147,7 +1147,7 @@ def tan(x, name=None):
11471147
Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
11481148
[-0.42279324, -0.20271003, 0.10033467, 0.30933627])
11491149
"""
1150-
if in_dynamic_mode():
1150+
if in_dynamic_or_pir_mode():
11511151
return _C_ops.tan(x)
11521152
else:
11531153
check_variable_and_dtype(

test/legacy_test/test_activation_op.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1976,10 +1976,13 @@ def setUp(self):
19761976
def init_shape(self):
19771977
self.shape = [10, 12]
19781978

1979+
def test_check_output(self):
1980+
self.check_output(check_pir=True)
1981+
19791982
def test_check_grad(self):
19801983
if self.dtype == np.float16:
19811984
return
1982-
self.check_grad(['X'], 'Out')
1985+
self.check_grad(['X'], 'Out', check_pir=True)
19831986

19841987

19851988
class TestTan_float32(TestTan):
@@ -2020,6 +2023,7 @@ def test_dygraph_api(self):
20202023
out_ref = np.tan(self.x_np)
20212024
np.testing.assert_allclose(out_ref, out_test.numpy(), rtol=1e-05)
20222025

2026+
@test_with_pir_api
20232027
def test_static_api(self):
20242028
with static_guard():
20252029
with paddle.static.program_guard(paddle.static.Program()):
@@ -4815,7 +4819,7 @@ def test_check_grad(self):
48154819
check_pir=True,
48164820
)
48174821
create_test_act_fp16_class(TestCos, check_pir=True)
4818-
create_test_act_fp16_class(TestTan)
4822+
create_test_act_fp16_class(TestTan, check_pir=True)
48194823
create_test_act_fp16_class(TestCosh)
48204824
create_test_act_fp16_class(TestAcos, check_pir=True)
48214825
create_test_act_fp16_class(TestSin, check_pir=True)
@@ -4976,7 +4980,7 @@ def test_check_grad(self):
49764980
TestFloor, grad_check=False, check_prim=True, check_pir=True
49774981
)
49784982
create_test_act_bf16_class(TestCos, check_pir=True)
4979-
create_test_act_bf16_class(TestTan)
4983+
create_test_act_bf16_class(TestTan, check_pir=True)
49804984
create_test_act_bf16_class(TestCosh)
49814985
create_test_act_bf16_class(TestAcos, check_pir=True)
49824986
create_test_act_bf16_class(TestSin, check_pir=True)

0 commit comments

Comments
 (0)