Skip to content

Commit da292d3

Browse files
authored
【PIR API adaptor No.225、227、197、187、152】 Migrate tanhshrink/thresholded_relu/Selu/RRelu/maxout into pir (#58429)
1 parent f119a4e commit da292d3

5 files changed

Lines changed: 173 additions & 96 deletions

File tree

python/paddle/nn/functional/activation.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -712,7 +712,7 @@ def rrelu(x, lower=1.0 / 8.0, upper=1.0 / 3.0, training=True, name=None):
712712

713713
is_test = not training
714714

715-
if in_dynamic_mode():
715+
if in_dynamic_or_pir_mode():
716716
return _C_ops.rrelu(x, lower, upper, is_test)
717717
else:
718718
check_variable_and_dtype(
@@ -886,7 +886,7 @@ def maxout(x, groups, axis=1, name=None):
886886
[0.42400089, 0.40641287, 0.97020894, 0.74437362],
887887
[0.51785129, 0.73292869, 0.97786582, 0.92382854]]]])
888888
"""
889-
if in_dynamic_mode():
889+
if in_dynamic_or_pir_mode():
890890
return _C_ops.maxout(x, groups, axis)
891891
else:
892892
check_variable_and_dtype(
@@ -1007,7 +1007,7 @@ def selu(
10071007
f"The alpha must be no less than zero. Received: {alpha}."
10081008
)
10091009

1010-
if in_dynamic_mode():
1010+
if in_dynamic_or_pir_mode():
10111011
return _C_ops.selu(x, scale, alpha)
10121012
else:
10131013
check_variable_and_dtype(
@@ -1533,7 +1533,7 @@ def tanhshrink(x, name=None):
15331533
Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
15341534
[-0.02005100, -0.00262472, 0.00033201, 0.00868741])
15351535
"""
1536-
if in_dynamic_mode():
1536+
if in_dynamic_or_pir_mode():
15371537
return _C_ops.tanh_shrink(x)
15381538
else:
15391539
check_variable_and_dtype(
@@ -1583,7 +1583,7 @@ def thresholded_relu(x, threshold=1.0, name=None):
15831583
[2., 0., 0.])
15841584
"""
15851585

1586-
if in_dynamic_mode():
1586+
if in_dynamic_or_pir_mode():
15871587
return _C_ops.thresholded_relu(x, threshold)
15881588
else:
15891589
check_variable_and_dtype(

test/legacy_test/test_activation_op.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,7 +1162,10 @@ def setUp(self):
11621162
def test_check_grad(self):
11631163
if self.dtype == np.float16:
11641164
return
1165-
self.check_grad(['X'], 'Out')
1165+
self.check_grad(['X'], 'Out', check_pir=True)
1166+
1167+
def test_check_output(self):
1168+
self.check_output(check_pir=True)
11661169

11671170

11681171
class TestTanhshrink_ZeroDim(TestTanhshrink):
@@ -1181,6 +1184,7 @@ def setUp(self):
11811184
else paddle.CPUPlace()
11821185
)
11831186

1187+
@test_with_pir_api
11841188
def test_static_api(self):
11851189
with static_guard():
11861190
with paddle.static.program_guard(paddle.static.Program()):
@@ -4317,7 +4321,10 @@ def init_shape(self):
43174321
def test_check_grad(self):
43184322
if self.dtype == np.float16:
43194323
return
4320-
self.check_grad(['X'], 'Out')
4324+
self.check_grad(['X'], 'Out', check_pir=True)
4325+
4326+
def test_check_output(self):
4327+
self.check_output(check_pir=True)
43214328

43224329

43234330
class TestThresholdedRelu_ZeroDim(TestThresholdedRelu):
@@ -4338,6 +4345,7 @@ def setUp(self):
43384345
else paddle.CPUPlace()
43394346
)
43404347

4348+
@test_with_pir_api
43414349
def test_static_api(self):
43424350
with static_guard():
43434351
with paddle.static.program_guard(paddle.static.Program()):
@@ -4805,7 +4813,7 @@ def test_check_grad(self):
48054813
create_test_act_fp16_class(
48064814
TestTanh, check_prim=True, check_prim_pir=True, enable_cinn=True
48074815
)
4808-
create_test_act_fp16_class(TestTanhshrink)
4816+
create_test_act_fp16_class(TestTanhshrink, check_pir=True)
48094817
create_test_act_fp16_class(TestHardShrink, check_pir=True)
48104818
create_test_act_fp16_class(TestSoftshrink, check_pir=True)
48114819
create_test_act_fp16_class(
@@ -4980,7 +4988,7 @@ def test_check_grad(self):
49804988
create_test_act_bf16_class(TestSilu, check_prim=True, check_prim_pir=True)
49814989
create_test_act_bf16_class(TestLogSigmoid)
49824990
create_test_act_bf16_class(TestTanh, check_prim=True, check_prim_pir=True)
4983-
create_test_act_bf16_class(TestTanhshrink)
4991+
create_test_act_bf16_class(TestTanhshrink, check_pir=True)
49844992
create_test_act_bf16_class(TestHardShrink, check_pir=True)
49854993
create_test_act_bf16_class(TestSoftshrink, check_pir=True)
49864994
create_test_act_bf16_class(

test/legacy_test/test_maxout_op.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import paddle
2121
import paddle.nn.functional as F
2222
from paddle.base import core
23+
from paddle.pir_utils import test_with_pir_api
2324

2425
paddle.enable_static()
2526
np.random.seed(1)
@@ -57,10 +58,10 @@ def set_attrs(self):
5758
pass
5859

5960
def test_check_output(self):
60-
self.check_output()
61+
self.check_output(check_pir=True)
6162

6263
def test_check_grad(self):
63-
self.check_grad(['X'], 'Out')
64+
self.check_grad(['X'], 'Out', check_pir=True)
6465

6566

6667
class TestMaxOutOpAxis0(TestMaxOutOp):
@@ -95,6 +96,7 @@ def setUp(self):
9596
else paddle.CPUPlace()
9697
)
9798

99+
@test_with_pir_api
98100
def test_static_api(self):
99101
with paddle.static.program_guard(paddle.static.Program()):
100102
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
@@ -161,6 +163,7 @@ def setUp(self):
161163
self.axis = 1
162164
self.place = paddle.CUDAPlace(0)
163165

166+
@test_with_pir_api
164167
def test_static_api(self):
165168
with paddle.static.program_guard(paddle.static.Program()):
166169
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)

0 commit comments

Comments
 (0)