Merged
4 changes: 1 addition & 3 deletions python/paddle/__init__.py
@@ -210,6 +210,7 @@
real,
shape,
)
from .tensor.compat_softmax import softmax
from .tensor.creation import (
BFloat16Tensor,
BoolTensor,
@@ -629,9 +630,6 @@
where,
where_,
)
from .tensor.softmax import (
softmax,
)
from .tensor.stat import (
mean,
median,
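Note: the top-level entry point is unchanged by this diff; `paddle.softmax` is still importable from the package root, it just resolves to the new `compat_softmax` implementation. A minimal sketch of that sanity check:

```python
# Sketch: paddle.softmax remains a top-level export after the import swap
# (it now comes from paddle.tensor.compat_softmax instead of paddle.tensor.softmax).
import paddle

x = paddle.to_tensor([0.0, 1.0, 2.0])
print(paddle.softmax(x))  # default axis, resolved via the new import
```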
12 changes: 11 additions & 1 deletion python/paddle/compat.py
@@ -21,5 +21,15 @@
sort,
split,
)
from .tensor.compat_softmax import softmax

__all__ = ['split', 'sort', 'Unfold', 'min', 'max', 'median', 'nanmedian']
__all__ = [
'softmax',
'split',
'sort',
'Unfold',
'min',
'max',
'median',
'nanmedian',
]
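A minimal usage sketch for the newly exported `paddle.compat.softmax`. The exact keyword spelling (`axis` vs. a PyTorch-style `dim` handled by `softmax_param_alias`) lives in `compat_softmax`, which is not shown in this diff, so only the default call is used here:

```python
# Sketch, assuming the default reduction axis is the last dimension.
import paddle

x = paddle.to_tensor([[1.0, 2.0, 3.0],
                      [4.0, 5.0, 6.0]])

y = paddle.compat.softmax(x)
print(y)
print(y.sum(axis=-1))  # each row of probabilities should sum to ~1.0
```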
183 changes: 180 additions & 3 deletions python/paddle/nn/functional/activation.py
@@ -21,6 +21,7 @@
from paddle.framework import core, in_dynamic_or_pir_mode
from paddle.utils.decorator_utils import (
param_one_alias,
softmax_param_alias,
)
from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only

@@ -30,14 +31,11 @@
from ...tensor.manipulation import chunk
from ...tensor.math import tanh, tanh_ # noqa: F401
from ...tensor.ops import sigmoid
from ...tensor.softmax import softmax as softmax

if TYPE_CHECKING:
from paddle import Tensor
from paddle._typing import DataLayout2D, DTypeLike

__all__ = []


def celu(x: Tensor, alpha: float = 1.0, name: str | None = None) -> Tensor:
r"""
@@ -1137,6 +1135,185 @@ def silu(x: Tensor, name: str | None = None) -> Tensor:
return out


@softmax_param_alias
def softmax(
x: Tensor,
axis: int = -1,
dtype: DTypeLike | None = None,
name: str | None = None,
*,
out: Tensor | None = None,
) -> Tensor:
r"""
This operator implements the softmax layer. The calculation process is as follows:

1. The dimension :attr:`axis` of ``x`` will be permuted to the last.

2. Then ``x`` will be logically flattened to a 2-D matrix. The matrix's second
dimension (row length) is the same as the dimension :attr:`axis` of ``x``,
and the first dimension (column length) is the product of all the other dimensions
of ``x``. For each row of the matrix, the softmax operator squashes the
K-dimensional (K is the width of the matrix, which is also the size of ``x``'s
dimension :attr:`axis`) vector of arbitrary real values to a K-dimensional
vector of real values in the range [0, 1] that add up to 1.

3. After the softmax operation is completed, the inverse operations of steps 1 and 2
are performed to restore the two-dimensional matrix to the same shape as ``x``.

For the dimension :attr:`axis`, it computes the exponential of each element and
the sum of the exponential values of all elements along that dimension of the
K-dimensional vector input. The output of the softmax operator is then the
ratio of each element's exponential to that sum.

For each row :math:`i` and each column :math:`j` in the matrix, we have:

.. math::

softmax[i, j] = \frac{\exp(x[i, j])}{\sum_j \exp(x[i, j])}

Example:

.. code-block:: text

Case 1:
Input:
x.shape = [2, 3, 4]
x.data = [[[2.0, 3.0, 4.0, 5.0],
[3.0, 4.0, 5.0, 6.0],
[7.0, 8.0, 8.0, 9.0]],
[[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
[6.0, 7.0, 8.0, 9.0]]]

Attrs:
axis = -1

Output:
out.shape = [2, 3, 4]
out.data = [[[0.0320586 , 0.08714432, 0.23688282, 0.64391426],
[0.0320586 , 0.08714432, 0.23688282, 0.64391426],
[0.07232949, 0.19661193, 0.19661193, 0.53444665]],
[[0.0320586 , 0.08714432, 0.23688282, 0.64391426],
[0.0320586 , 0.08714432, 0.23688282, 0.64391426],
[0.0320586 , 0.08714432, 0.23688282, 0.64391426]]]

Case 2:
Input:
x.shape = [2, 3, 4]
x.data = [[[2.0, 3.0, 4.0, 5.0],
[3.0, 4.0, 5.0, 6.0],
[7.0, 8.0, 8.0, 9.0]],
[[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
[6.0, 7.0, 8.0, 9.0]]]
Attrs:
axis = 1

Output:
out.shape = [2, 3, 4]
out.data = [[[0.00657326, 0.00657326, 0.01714783, 0.01714783],
[0.01786798, 0.01786798, 0.04661262, 0.04661262],
[0.97555875, 0.97555875, 0.93623955, 0.93623955]],
[[0.00490169, 0.00490169, 0.00490169, 0.00490169],
[0.26762315, 0.26762315, 0.26762315, 0.26762315],
[0.72747516, 0.72747516, 0.72747516, 0.72747516]]]

Parameters:
x (Tensor): The input Tensor with data type bfloat16, float16, float32, float64.
axis (int, optional): The axis along which to perform softmax
calculations. It should be in range [-D, D), where D is the
rank of ``x`` . If ``axis`` < 0, it works the same way as
:math:`axis + D` . Default is -1.
dtype (str, optional): The data type of the output tensor, can be bfloat16, float16, float32, float64.
name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
out (Tensor|None, optional): The output Tensor. If provided, the result is written into it. Default: None.

Returns:
A Tensor with the same shape as ``x``. Its data type is the same as ``x``,
or ``dtype`` if ``dtype`` is specified.

Examples:
.. code-block:: python

>>> import paddle
>>> import paddle.nn.functional as F

>>> x = paddle.to_tensor([[[2.0, 3.0, 4.0, 5.0],
... [3.0, 4.0, 5.0, 6.0],
... [7.0, 8.0, 8.0, 9.0]],
... [[1.0, 2.0, 3.0, 4.0],
... [5.0, 6.0, 7.0, 8.0],
... [6.0, 7.0, 8.0, 9.0]]], dtype='float32')
>>> out1 = F.softmax(x)
>>> out2 = F.softmax(x, dtype='float64')
>>> # out1's data type is float32; out2's data type is float64
>>> # out1's and out2's values are as follows:
>>> print(out1)
>>> print(out2)
Tensor(shape=[2, 3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
[[[0.03205860, 0.08714432, 0.23688284, 0.64391428],
[0.03205860, 0.08714432, 0.23688284, 0.64391428],
[0.07232949, 0.19661194, 0.19661194, 0.53444666]],
[[0.03205860, 0.08714432, 0.23688284, 0.64391428],
[0.03205860, 0.08714432, 0.23688284, 0.64391428],
[0.03205860, 0.08714432, 0.23688284, 0.64391428]]])
Tensor(shape=[2, 3, 4], dtype=float64, place=Place(cpu), stop_gradient=True,
[[[0.03205860, 0.08714432, 0.23688282, 0.64391426],
[0.03205860, 0.08714432, 0.23688282, 0.64391426],
[0.07232949, 0.19661193, 0.19661193, 0.53444665]],
[[0.03205860, 0.08714432, 0.23688282, 0.64391426],
[0.03205860, 0.08714432, 0.23688282, 0.64391426],
[0.03205860, 0.08714432, 0.23688282, 0.64391426]]])
"""
if (
(dtype is not None)
and (not isinstance(dtype, core.VarDesc.VarType))
and (not isinstance(dtype, core.DataType))
):
dtype = convert_np_dtype_to_dtype_(dtype)
if in_dynamic_or_pir_mode():
outs_cast = x if dtype is None else _C_ops.cast(x, dtype)
return _C_ops.softmax(outs_cast, axis, out=out)
else:
use_cudnn = True
if dtype is None:
check_variable_and_dtype(
x, 'x', ['uint16', 'float16', 'float32', 'float64'], 'softmax'
)
else:
check_dtype(
dtype,
'dtype',
['uint16', 'float16', 'float32', 'float64'],
'softmax',
'If dtype is not None, it only supports uint16, float16, float32 or float64.',
)

helper = LayerHelper("softmax", **locals())
outs_cast = x
if dtype is not None:
outs_cast = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='cast',
inputs={'X': x},
outputs={'Out': outs_cast},
attrs={'in_dtype': x.dtype, 'out_dtype': dtype},
)

outs_softmax = helper.create_variable_for_type_inference(
outs_cast.dtype
)
helper.append_op(
type='softmax',
inputs={'X': outs_cast},
outputs={'Out': outs_softmax},
attrs={'axis': axis, 'use_cudnn': use_cudnn},
)

return outs_softmax


@inplace_apis_in_dygraph_only
def softmax_(
x: Tensor,
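For reference, the math documented in the new `softmax` docstring (optionally cast to `dtype` first, then exponentiate and normalise along `axis`) can be reproduced with a short NumPy sketch; the max-shift below is a standard numerical-stability trick that is not part of the documented formula but is mathematically equivalent:

```python
# Reference sketch of the documented behaviour, NumPy only.
import numpy as np

def ref_softmax(x, axis=-1, dtype=None):
    x = np.asarray(x, dtype=dtype) if dtype is not None else np.asarray(x)
    e = np.exp(x - x.max(axis=axis, keepdims=True))  # shift by max for stability
    return e / e.sum(axis=axis, keepdims=True)

x = np.array([[2.0, 3.0, 4.0, 5.0],
              [3.0, 4.0, 5.0, 6.0],
              [7.0, 8.0, 8.0, 9.0]])
print(ref_softmax(x, axis=-1))
# First row ~ [0.0320586, 0.08714432, 0.23688282, 0.64391426], matching Case 1 above.
```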
2 changes: 2 additions & 0 deletions python/paddle/special.py
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .tensor.compat_softmax import softmax
from .tensor.math import logsumexp

__all__ = [
"logsumexp",
"softmax",
]
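With this change `softmax` becomes reachable from the `paddle.special` namespace next to `logsumexp`; a short hedged sketch (call signature assumed to match `paddle.nn.functional.softmax`):

```python
import paddle

x = paddle.rand([3, 4])
p = paddle.special.softmax(x)                 # probabilities along the last axis (assumed default)
lse = paddle.special.logsumexp(x, axis=-1)    # existing export in the same namespace
```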
2 changes: 1 addition & 1 deletion python/paddle/tensor/__init__.py
@@ -31,6 +31,7 @@
real,
shape,
)
from .compat_softmax import softmax as softmax
Contributor:

This needs to be added to tensor_method_func as well; otherwise it will not be bound to paddle.Tensor.

Contributor (author):

tensor_method_func has always included it.

from .creation import ( # noqa: F401
MmapStorage,
arange,
@@ -481,7 +482,6 @@
where,
where_,
)
from .softmax import softmax as softmax
from .stat import ( # noqa: F401
mean,
median,
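Following up on the review thread above: assuming `softmax` is indeed registered in tensor_method_func (as the author states), it should also be usable as a method on `paddle.Tensor`, not only as a free function. A minimal check:

```python
# Sketch; the method form relies on the tensor_method_func registration
# discussed in the review comments.
import paddle

x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])
print(x.softmax())  # method form, default axis
```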
1 change: 0 additions & 1 deletion python/paddle/tensor/compat.py
@@ -32,7 +32,6 @@
Size2,
)


from paddle import nn
from paddle.utils.decorator_utils import ForbidKeywordsDecorator
