Commit cbedff7

[API Compatibility] paddle.sigmoid sink into C++ (#74802)
* feat(api sink): support paddle.sigmoid
* feat(api sink): support paddle.sigmoid
* feat(api sink): fix sigmoid doc
1 parent 5cb6b67 commit cbedff7

3 files changed: 206 additions, 0 deletions

paddle/phi/ops/yaml/ops.yaml

Lines changed: 4 additions & 0 deletions
@@ -4954,6 +4954,10 @@

 - op : sigmoid
   args : (Tensor x)
+  python_api:
+    name : [paddle.sigmoid,paddle.Tensor.sigmoid,paddle.nn.functional.sigmoid]
+    args_alias:
+      use_default_mapping : True
   output : Tensor
   infer_meta :
     func : UnchangedInferMeta
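For context (not part of the diff): the `python_api` entry registers the C++ sigmoid kernel under the three listed Python names, and `use_default_mapping : True` lets the binding accept the compatibility spellings exercised by the new test further below. A minimal sketch of the resulting call surface, assuming a Paddle build that includes this registration:

import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])

# All three registered aliases dispatch to the same C++ sigmoid kernel.
a = paddle.sigmoid(x)   # paddle.sigmoid
b = x.sigmoid()         # paddle.Tensor.sigmoid
c = F.sigmoid(x)        # paddle.nn.functional.sigmoid

# With the default argument mapping, the torch-style keyword `input`
# is accepted as an alias for `x`.
d = paddle.sigmoid(input=x)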

python/paddle/_paddle_docs.py

Lines changed: 41 additions & 0 deletions
@@ -590,6 +590,47 @@ def any(

 # shenwei

+add_doc_and_signature(
+    "sigmoid",
+    r"""
+Sigmoid Activation.
+
+.. math::
+    out = \\frac{1}{1 + e^{-x}}
+
+Args:
+    x (Tensor): Input of Sigmoid operator, an N-D Tensor, with data type bfloat16, float16, float32, float64,
+        uint8, int8, int16, int32, int64, complex64 or complex128.
+    name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+Keyword Args:
+    out (Tensor|optional): The output tensor.
+
+Returns:
+    Tensor. Output of Sigmoid operator, a Tensor with shape same as input
+    (integer types are autocasted into float32).
+
+Examples:
+    .. code-block:: python
+
+        >>> import paddle
+        >>> import paddle.nn.functional as F
+
+        >>> x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
+        >>> out = F.sigmoid(x)
+        >>> print(out)
+        Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
+        [0.40131235, 0.45016602, 0.52497917, 0.57444251])
+""",
+    """
+def sigmoid(
+    x: paddle.Tensor,
+    name: str | None = None,
+    *,
+    out: Tensor | None = None,
+) -> paddle.Tensor
+""",
+)
+
 # zhouxin

 # hehongyu
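The documented signature adds a keyword-only `out` parameter next to the existing `name` argument. A short sketch of the two call styles, following the pattern used in the new test below (the preallocated tensor here is illustrative):

import paddle

x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])

# Default: a fresh tensor is allocated and returned.
y = paddle.sigmoid(x)

# Keyword-only `out`: the result is written into an existing tensor,
# mirroring the `paddle.sigmoid(x, out=out5)` check in the test.
buf = paddle.empty([])
paddle.sigmoid(x, out=buf)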

test/legacy_test/test_sigmoid.py

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from op_test import get_places

import paddle
from paddle import base


class TestSigmoidAPI_Compatibility(unittest.TestCase):
    def setUp(self):
        np.random.seed(123)
        paddle.enable_static()
        self.places = get_places()
        self.init_data()

    def init_data(self):
        self.shape = [10, 15]
        self.dtype = "float32"
        self.np_input = np.random.uniform(-1, 1, self.shape).astype(self.dtype)

    def ref_forward(self, x):
        return 1 / (1 + np.exp(-x))

    def test_dygraph_Compatibility(self):
        paddle.disable_static()
        x = paddle.to_tensor(self.np_input)
        paddle_dygraph_out = []
        # Position args (args)
        out1 = paddle.sigmoid(x)
        paddle_dygraph_out.append(out1)
        # Key words args (kwargs) for paddle
        out2 = paddle.sigmoid(x=x)
        paddle_dygraph_out.append(out2)
        # Key words args for torch
        out3 = paddle.sigmoid(input=x)
        paddle_dygraph_out.append(out3)
        # Tensor method args
        out4 = x.sigmoid()
        paddle_dygraph_out.append(out4)
        # Test out
        out5 = paddle.empty([])
        paddle.sigmoid(x, out=out5)
        paddle_dygraph_out.append(out5)
        # Reference output
        ref_out = self.ref_forward(self.np_input)
        # Check
        for i in range(len(paddle_dygraph_out)):
            np.testing.assert_allclose(
                ref_out, paddle_dygraph_out[i].numpy(), rtol=1e-05
            )
        paddle.enable_static()

    def test_static_Compatibility(self):
        main = paddle.static.Program()
        startup = paddle.static.Program()
        with base.program_guard(main, startup):
            x = paddle.static.data(name="x", shape=self.shape, dtype=self.dtype)
            # Position args (args)
            out1 = paddle.sigmoid(x)
            # Key words args (kwargs) for paddle
            out2 = paddle.sigmoid(x=x)
            # Key words args for torch
            out3 = paddle.sigmoid(input=x)
            # Tensor method args
            out4 = x.sigmoid()
            exe = base.Executor(paddle.CPUPlace())
            fetches = exe.run(
                main,
                feed={"x": self.np_input},
                fetch_list=[out1, out2, out3, out4],
            )
        ref_out = self.ref_forward(self.np_input)
        for i in range(len(fetches)):
            np.testing.assert_allclose(fetches[i], ref_out, rtol=1e-05)


class TestTensorSigmoidAPI_Compatibility(unittest.TestCase):
    def setUp(self):
        np.random.seed(123)
        paddle.enable_static()
        self.places = get_places()
        self.init_data()

    def init_data(self):
        self.shape = [10, 15]
        self.dtype = "float32"
        self.np_input = np.random.uniform(-1, 1, self.shape).astype(self.dtype)

    def ref_forward(self, x):
        return 1 / (1 + np.exp(-x))

    def test_dygraph_Compatibility(self):
        paddle.disable_static()
        x = paddle.to_tensor(self.np_input)
        paddle_dygraph_out = []
        # Position args (args)
        out1 = paddle.Tensor.sigmoid(x)
        paddle_dygraph_out.append(out1)
        # Key words args (kwargs) for paddle
        out2 = paddle.Tensor.sigmoid(x=x)
        paddle_dygraph_out.append(out2)
        # Key words args for torch
        out3 = paddle.Tensor.sigmoid(input=x)
        paddle_dygraph_out.append(out3)
        # Tensor method args
        out4 = x.sigmoid()
        paddle_dygraph_out.append(out4)
        # Test out
        out5 = paddle.empty([])
        paddle.Tensor.sigmoid(x, out=out5)
        paddle_dygraph_out.append(out5)
        # Reference output
        ref_out = self.ref_forward(self.np_input)
        # Check
        for i in range(len(paddle_dygraph_out)):
            np.testing.assert_allclose(
                ref_out, paddle_dygraph_out[i].numpy(), rtol=1e-05
            )
        paddle.enable_static()

    def test_static_Compatibility(self):
        main = paddle.static.Program()
        startup = paddle.static.Program()
        with base.program_guard(main, startup):
            x = paddle.static.data(name="x", shape=self.shape, dtype=self.dtype)
            # Position args (args)
            out1 = paddle.Tensor.sigmoid(x)
            # Key words args (kwargs) for paddle
            out2 = paddle.Tensor.sigmoid(x=x)
            # Key words args for torch
            out3 = paddle.Tensor.sigmoid(input=x)
            # Tensor method args
            out4 = x.sigmoid()
            exe = base.Executor(paddle.CPUPlace())
            fetches = exe.run(
                main,
                feed={"x": self.np_input},
                fetch_list=[out1, out2, out3, out4],
            )
        ref_out = self.ref_forward(self.np_input)
        for i in range(len(fetches)):
            np.testing.assert_allclose(fetches[i], ref_out, rtol=1e-05)


if __name__ == '__main__':
    unittest.main()
