Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion paddle/fluid/pybind/arg_pre_process.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,39 @@
// processing of parameters originally done in the Python API
#include "paddle/fluid/pybind/arg_pre_process.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/pir/utils/general_functions.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/op_function_common.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace pybind {} // namespace pybind
namespace pybind {
void LogsumexpPreProcess(Tensor *x, std::vector<int> *axis, bool *reduce_all) {
/**
if axis == [] or len(axis) == len(x.shape):
reduce_all = True
else:
reduce_all = False
*/
if (axis->empty() || axis->size() == x->dims().size()) {
*reduce_all = true;
} else {
*reduce_all = false;
}
return;
}

void LogsumexpPreProcess(pir::Value *x,
std::vector<int> *axis,
bool *reduce_all) {
std::vector<int64_t> x_shape = pir::GetShapeFromValue(*x);
if (axis->empty() || axis->size() == x_shape.size()) {
*reduce_all = true;
} else {
*reduce_all = false;
}
return;
}
} // namespace pybind

} // namespace paddle
12 changes: 11 additions & 1 deletion paddle/fluid/pybind/arg_pre_process.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,19 @@
#pragma once

#include <Python.h>
#include <vector>
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/pir/include/core/value.h"

namespace paddle {

namespace pybind {} // namespace pybind
namespace pybind {
using Value = pir::Value;

void LogsumexpPreProcess(Tensor *x, std::vector<int> *axis, bool *reduce_all);
void LogsumexpPreProcess(Value *x, std::vector<int> *axis, bool *reduce_all);
} // namespace pybind

} // namespace paddle
4 changes: 3 additions & 1 deletion paddle/fluid/pybind/op_function_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,8 @@ std::vector<int> CastPyArg2Ints(PyObject* obj,
}
Py_DECREF(item);
}
} else if (PyObject_CheckLong(obj)) {
value.emplace_back(PyObject_ToInt32(obj));
} else {
PADDLE_THROW(common::errors::InvalidType(
"%s(): argument (position %d) must be "
Expand All @@ -666,7 +668,7 @@ std::vector<int> CastPyArg2Ints(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos,
std::vector<int> default_value) {
if (obj != nullptr) {
if (obj != nullptr && obj != Py_None) {
return CastPyArg2Ints(obj, op_type, arg_pos);
} else {
return default_value;
Expand Down
8 changes: 7 additions & 1 deletion paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3360,7 +3360,13 @@
traits : paddle::dialect::ForwardOnlyTrait

- op : logsumexp
args : (Tensor x, int[] axis={0}, bool keepdim=false, bool reduce_all=false)
args : (Tensor x, int[] axis={}, bool keepdim=false, bool reduce_all=false)
python_api:
name : [paddle.logsumexp,paddle.Tensor.logsumexp]
args_alias:
use_default_mapping : True
pre_process:
func : LogsumexpPreProcess(x, axis, reduce_all)
output : Tensor(out)
infer_meta :
func : LogsumexpInferMeta
Expand Down
63 changes: 63 additions & 0 deletions python/paddle/_paddle_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,69 @@ def all(
""",
)

add_doc_and_signature(
"logsumexp",
r"""
Calculates the log of the sum of exponentials of ``x`` along ``axis`` .

.. math::
logsumexp(x) = \log\sum exp(x)

Args:
x (Tensor): The input Tensor with data type bfloat16, float16, float32,
float64, uint8, int8, int16, int32, int64, which have no more than
4 dimensions.
axis (int|list|tuple|None, optional): The axis along which to perform
logsumexp calculations. ``axis`` should be int, list(int) or
tuple(int). If ``axis`` is a list/tuple of dimension(s), logsumexp
is calculated along all element(s) of ``axis`` . ``axis`` or
element(s) of ``axis`` should be in range [-D, D), where D is the
dimensions of ``x`` . If ``axis`` or element(s) of ``axis`` is
less than 0, it works the same way as :math:`axis + D` . If
``axis`` is None, logsumexp is calculated along all elements of
``x``. Default is None.
keepdim (bool, optional): Whether to reserve the reduced dimension(s)
in the output Tensor. If ``keep_dim`` is True, the dimensions of
the output Tensor is the same as ``x`` except in the reduced
dimensions(it is of size 1 in this case). Otherwise, the shape of
the output Tensor is squeezed in ``axis`` . Default is False.
name (str|None, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Keyword Args:
out (Tensor|optional): The output tensor.
Returns:
Tensor, results of logsumexp along ``axis`` of ``x``, with the same data
type as ``x`` (integer types are autocasted into float32).

Examples:

.. code-block:: python

>>> import paddle

>>> x = paddle.to_tensor([[-1.5, 0., 2.], [3., 1.2, -2.4]])
>>> out1 = paddle.logsumexp(x)
>>> out1
Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
3.46912265)
>>> out2 = paddle.logsumexp(x, 1)
>>> out2
Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
[2.15317822, 3.15684605])

""",
"""
def logsumexp(
x: Tensor,
axis: int | Sequence[int] | None = None,
keepdim: bool = False,
name: str | None = None,
*,
out: Tensor | None = None,
) -> Tensor
""",
)

# zhengsheng
add_doc_and_signature(
"isfinite",
Expand Down
90 changes: 1 addition & 89 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
isfinite,
isinf,
isnan,
logsumexp,
)
from paddle.base.libpaddle import DataType
from paddle.common_ops_import import VarDesc, dygraph_utils
Expand Down Expand Up @@ -3156,95 +3157,6 @@ def __check_input(x, y):
return out


def logsumexp(
x: Tensor,
axis: int | Sequence[int] | None = None,
keepdim: bool = False,
name: str | None = None,
*,
out: Tensor | None = None,
) -> Tensor:
r"""
Calculates the log of the sum of exponentials of ``x`` along ``axis`` .

.. math::
logsumexp(x) = \log\sum exp(x)

Args:
x (Tensor): The input Tensor with data type bfloat16, float16, float32,
float64, uint8, int8, int16, int32, int64, which have no more than
4 dimensions.
axis (int|list|tuple|None, optional): The axis along which to perform
logsumexp calculations. ``axis`` should be int, list(int) or
tuple(int). If ``axis`` is a list/tuple of dimension(s), logsumexp
is calculated along all element(s) of ``axis`` . ``axis`` or
element(s) of ``axis`` should be in range [-D, D), where D is the
dimensions of ``x`` . If ``axis`` or element(s) of ``axis`` is
less than 0, it works the same way as :math:`axis + D` . If
``axis`` is None, logsumexp is calculated along all elements of
``x``. Default is None.
keepdim (bool, optional): Whether to reserve the reduced dimension(s)
in the output Tensor. If ``keep_dim`` is True, the dimensions of
the output Tensor is the same as ``x`` except in the reduced
dimensions(it is of size 1 in this case). Otherwise, the shape of
the output Tensor is squeezed in ``axis`` . Default is False.
name (str|None, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
out (Tensor|None, optional): The output Tensor. If set, the result will be
stored in this Tensor.

Returns:
Tensor, results of logsumexp along ``axis`` of ``x``, with the same data
type as ``x`` (integer types are autocasted into float32).

Examples:

.. code-block:: python

>>> import paddle

>>> x = paddle.to_tensor([[-1.5, 0., 2.], [3., 1.2, -2.4]])
>>> out1 = paddle.logsumexp(x)
>>> out1
Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
3.46912265)
>>> out2 = paddle.logsumexp(x, 1)
>>> out2
Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
[2.15317822, 3.15684605])

"""
reduce_all, axis = _get_reduce_axis(axis, x)

if in_dynamic_or_pir_mode():
return _C_ops.logsumexp(x, axis, keepdim, reduce_all, out=out)
else:
check_variable_and_dtype(
x,
'x',
[
'float16',
'float32',
'float64',
'uint16',
'uint8',
'int8',
'int16',
'int32',
'int64',
],
'logsumexp',
)

helper = LayerHelper('logsumexp', **locals())
attrs = {'axis': axis, 'keepdim': keepdim, 'reduce_all': reduce_all}
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(
type='logsumexp', inputs={'X': x}, outputs={'Out': out}, attrs=attrs
)
return out


def inverse(x: Tensor, name: str | None = None) -> Tensor:
"""
Takes the inverse of the square matrix. A square matrix is a matrix with
Expand Down
77 changes: 77 additions & 0 deletions test/legacy_test/test_logsumexp.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,5 +396,82 @@ def test_logsumexp_out(self):
)


class TestLogsumexpAPI_Compatibility(unittest.TestCase):
def setUp(self):
np.random.seed(123)
paddle.enable_static()
self.shape = [5, 6]
self.dtype = 'float32'
self.init_data()

def init_data(self):
self.np_input = np.random.randint(0, 8, self.shape).astype(self.dtype)
self.np_ref_out = ref_logsumexp(
self.np_input, axis=[0, 1], keepdim=True, reduce_all=True
)

def test_dygraph_Compatibility(self):
paddle.disable_static()
x = paddle.to_tensor(self.np_input)
paddle_dygraph_out = []
# Position args (args)
out1 = paddle.logsumexp(x, [0, 1], True)
paddle_dygraph_out.append(out1)
# Key words args (kwargs) for paddle
out2 = paddle.logsumexp(x=x, axis=[0, 1], keepdim=True)
paddle_dygraph_out.append(out2)
# Key words args for torch
out3 = paddle.logsumexp(input=x, dim=[0, 1], keepdim=True)
paddle_dygraph_out.append(out3)
# Combined args and kwargs
out4 = paddle.logsumexp(x, dim=[0, 1], keepdim=True)
paddle_dygraph_out.append(out4)
# Tensor method args
out5 = x.logsumexp([0, 1], True)
paddle_dygraph_out.append(out5)
# Tensor method kwargs
out6 = x.logsumexp(dim=[0, 1], keepdim=True)
paddle_dygraph_out.append(out6)
# Test out
out7 = paddle.empty([])
paddle.logsumexp(x, [0, 1], True, out=out7)
paddle_dygraph_out.append(out7)
# Numpy reference out
ref_out = self.np_ref_out
# Check
for out in paddle_dygraph_out:
np.testing.assert_allclose(ref_out, out.numpy())
paddle.enable_static()

def test_static_Compatibility(self):
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.base.program_guard(main, startup):
x = paddle.static.data(name="x", shape=self.shape, dtype=self.dtype)
# Position args (args)
out1 = paddle.logsumexp(x, [0, 1], True)
# Key words args (kwargs) for paddle
out2 = paddle.logsumexp(x=x, axis=[0, 1], keepdim=True)
# Key words args for torch
out3 = paddle.logsumexp(input=x, dim=[0, 1], keepdim=True)
# Combined args and kwargs
out4 = paddle.logsumexp(x, dim=[0, 1], keepdim=True)
# Tensor method args
out5 = x.logsumexp([0, 1], True)
# Tensor method kwargs
out6 = x.logsumexp(dim=[0, 1], keepdim=True)
# Do not support out in static
# out7 = paddle.empty([])
exe = paddle.base.Executor(paddle.CPUPlace())
fetches = exe.run(
main,
feed={"x": self.np_input},
fetch_list=[out1, out2, out3, out4, out5, out6],
)
ref_out = self.np_ref_out
for out in fetches:
np.testing.assert_allclose(out, ref_out)


if __name__ == '__main__':
unittest.main()
Loading