From f138fc790316671a31a4dcd07738251d593f073e Mon Sep 17 00:00:00 2001
From: DanielSun11 <1395924413@qq.com>
Date: Sun, 24 Aug 2025 23:11:02 +0800
Subject: [PATCH 1/3] sink logsumexp to cpp

---
 paddle/fluid/pybind/arg_pre_process.cc    | 30 +++++++-
 paddle/fluid/pybind/arg_pre_process.h     | 12 +++-
 paddle/fluid/pybind/op_function_common.cc |  4 +-
 paddle/phi/ops/yaml/ops.yaml              |  8 ++-
 python/paddle/_paddle_docs.py             | 63 +++++++++++++++++
 python/paddle/tensor/math.py              | 86 +----------------------
 test/legacy_test/test_logsumexp.py        | 83 +++++++++++++++++++++-
 7 files changed, 194 insertions(+), 92 deletions(-)

diff --git a/paddle/fluid/pybind/arg_pre_process.cc b/paddle/fluid/pybind/arg_pre_process.cc
index 1dd1e8c70e3c07..b1e19be512a6f5 100644
--- a/paddle/fluid/pybind/arg_pre_process.cc
+++ b/paddle/fluid/pybind/arg_pre_process.cc
@@ -19,11 +19,39 @@
 // processing of parameters originally done in the Python API
 #include "paddle/fluid/pybind/arg_pre_process.h"
 #include "paddle/fluid/eager/utils.h"
+#include "paddle/fluid/pir/utils/general_functions.h"
 #include "paddle/fluid/pybind/eager_utils.h"
 #include "paddle/fluid/pybind/op_function_common.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/enforce.h"

 namespace paddle {
-namespace pybind {}  // namespace pybind
+namespace pybind {
+void LogsumexpPreProcess(Tensor *x, std::vector<int> *axis, bool *reduce_all) {
+  /**
+    if axis == [] or len(axis) == len(x.shape):
+        reduce_all = True
+    else:
+        reduce_all = False
+  */
+  if (axis->empty() || axis->size() == x->dims().size()) {
+    *reduce_all = true;
+  } else {
+    *reduce_all = false;
+  }
+  return;
+}
+
+void LogsumexpPreProcess(pir::Value *x,
+                         std::vector<int> *axis,
+                         bool *reduce_all) {
+  std::vector<int64_t> x_shape = pir::GetShapeFromValue(*x);
+  if (axis->empty() || axis->size() == x_shape.size()) {
+    *reduce_all = true;
+  } else {
+    *reduce_all = false;
+  }
+  return;
+}
+}  // namespace pybind
 }  // namespace paddle
diff --git a/paddle/fluid/pybind/arg_pre_process.h b/paddle/fluid/pybind/arg_pre_process.h
index 557b6d1c5f4739..e3051ecc00139b 100644
--- a/paddle/fluid/pybind/arg_pre_process.h
+++ b/paddle/fluid/pybind/arg_pre_process.h
@@ -15,9 +15,19 @@
 #pragma once

 #include <Python.h>
+#include <vector>
+
+#include "paddle/phi/api/include/tensor.h"
+#include "paddle/phi/common/data_type.h"
+#include "paddle/phi/common/scalar.h"
+#include "paddle/pir/include/core/value.h"

 namespace paddle {
-namespace pybind {}  // namespace pybind
+namespace pybind {
+using Value = pir::Value;
+
+void LogsumexpPreProcess(Tensor *x, std::vector<int> *axis, bool *reduce_all);
+void LogsumexpPreProcess(Value *x, std::vector<int> *axis, bool *reduce_all);
+}  // namespace pybind
 }  // namespace paddle
diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc
index 81a64d056b0a32..5786c64b922075 100644
--- a/paddle/fluid/pybind/op_function_common.cc
+++ b/paddle/fluid/pybind/op_function_common.cc
@@ -651,6 +651,8 @@ std::vector<int> CastPyArg2Ints(PyObject* obj,
       }
       Py_DECREF(item);
     }
+  } else if (PyObject_CheckLong(obj)) {
+    value.emplace_back(PyObject_ToInt32(obj));
   } else {
     PADDLE_THROW(common::errors::InvalidType(
         "%s(): argument (position %d) must be "
@@ -666,7 +668,7 @@ std::vector<int> CastPyArg2Ints(PyObject* obj,
                                 const std::string& op_type,
                                 ssize_t arg_pos,
                                 std::vector<int> default_value) {
-  if (obj != nullptr) {
+  if (obj != nullptr && obj != Py_None) {
     return CastPyArg2Ints(obj, op_type, arg_pos);
   } else {
     return default_value;
   }
 }
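Note: the two CastPyArg2Ints changes above are what let the generated logsumexp
binding accept a bare Python int for `axis` (the new PyObject_CheckLong branch
wraps it into a one-element vector) and treat an explicit None the same as an
omitted argument (the `obj != Py_None` check falls back to the default from
ops.yaml). A rough sketch of the call patterns this enables; the comments
describe assumed binding behavior, not the generated code itself:

    import paddle

    x = paddle.rand([2, 3, 4])
    paddle.logsumexp(x, axis=1)       # bare int -> cast to a one-element int vector
    paddle.logsumexp(x, axis=None)    # None -> falls back to the axis={} default
    paddle.logsumexp(x, axis=[0, 2])  # list of ints, handled as before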
diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml
index d89552ba46ac47..e06cc97750ee76 100644
--- a/paddle/phi/ops/yaml/ops.yaml
+++ b/paddle/phi/ops/yaml/ops.yaml
@@ -3368,7 +3368,13 @@
   traits : paddle::dialect::ForwardOnlyTrait

 - op : logsumexp
-  args : (Tensor x, int[] axis={0}, bool keepdim=false, bool reduce_all=false)
+  args : (Tensor x, int[] axis={}, bool keepdim=false, bool reduce_all=false)
+  python_api:
+    name : [paddle.logsumexp, paddle.Tensor.logsumexp]
+    args_alias:
+      use_default_mapping : True
+    pre_process:
+      func : LogsumexpPreProcess(x, axis, reduce_all)
   output : Tensor(out)
   infer_meta :
     func : LogsumexpInferMeta
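Note: the new `pre_process` hook routes the generated binding through
LogsumexpPreProcess, which reproduces in C++ the reduce_all rule that the
removed Python wrapper computed via _get_reduce_axis. A minimal Python sketch
of that rule (illustrative names, not the generated code):

    def compute_reduce_all(x_ndim: int, axis: list[int]) -> bool:
        # Reduce over every dimension when axis is empty ([] is what an
        # omitted/None axis becomes after argument parsing) or when it
        # names all dimensions of x.
        return len(axis) == 0 or len(axis) == x_ndim

    assert compute_reduce_all(3, []) is True
    assert compute_reduce_all(3, [0, 1, 2]) is True
    assert compute_reduce_all(3, [1]) is False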
diff --git a/python/paddle/_paddle_docs.py b/python/paddle/_paddle_docs.py
index 02212c974e43e6..5b0b996b557a3e 100644
--- a/python/paddle/_paddle_docs.py
+++ b/python/paddle/_paddle_docs.py
@@ -401,6 +401,69 @@ def all(
     """,
 )

+add_doc_and_signature(
+    "logsumexp",
+    r"""
+    Calculates the log of the sum of exponentials of ``x`` along ``axis`` .
+
+    .. math::
+        logsumexp(x) = \log\sum exp(x)
+
+    Args:
+        x (Tensor): The input Tensor with data type bfloat16, float16, float32,
+            float64, uint8, int8, int16, int32, int64, with no more than
+            4 dimensions.
+        axis (int|list|tuple|None, optional): The axis along which to perform
+            logsumexp calculations. ``axis`` should be int, list(int) or
+            tuple(int). If ``axis`` is a list/tuple of dimension(s), logsumexp
+            is calculated along all element(s) of ``axis`` . ``axis`` or
+            element(s) of ``axis`` should be in range [-D, D), where D is the
+            dimensions of ``x`` . If ``axis`` or element(s) of ``axis`` is
+            less than 0, it works the same way as :math:`axis + D` . If
+            ``axis`` is None, logsumexp is calculated along all elements of
+            ``x``. Default is None.
+        keepdim (bool, optional): Whether to reserve the reduced dimension(s)
+            in the output Tensor. If ``keepdim`` is True, the dimensions of
+            the output Tensor are the same as ``x`` except in the reduced
+            dimensions (it is of size 1 in this case). Otherwise, the shape of
+            the output Tensor is squeezed in ``axis`` . Default is False.
+        name (str|None, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Keyword Args:
+        out (Tensor|None, optional): The output tensor. Default is None.
+
+    Returns:
+        Tensor, results of logsumexp along ``axis`` of ``x``, with the same data
+        type as ``x`` (integer types are automatically cast to float32).
+
+    Examples:
+
+        .. code-block:: python
+
+            >>> import paddle
+
+            >>> x = paddle.to_tensor([[-1.5, 0., 2.], [3., 1.2, -2.4]])
+            >>> out1 = paddle.logsumexp(x)
+            >>> out1
+            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
+                   3.46912265)
+            >>> out2 = paddle.logsumexp(x, 1)
+            >>> out2
+            Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
+                   [2.15317822, 3.15684605])
+
+    """,
+    """
+def logsumexp(
+    x: Tensor,
+    axis: int | Sequence[int] | None = None,
+    keepdim: bool = False,
+    name: str | None = None,
+    *,
+    out: Tensor | None = None,
+) -> Tensor
+    """,
+)
+
 # zhengsheng
 add_doc_and_signature(
     "isfinite",
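Note: the docstring formula logsumexp(x) = log sum exp(x) is usually evaluated
in max-shifted form for numerical stability, which is presumably also what the
tests' ref_logsumexp helper computes. A self-contained NumPy sketch (handles an
int-or-None axis only, for brevity):

    import numpy as np

    def np_logsumexp(x, axis=None, keepdims=False):
        # Factor out the max so exp() never overflows:
        # log(sum(exp(x))) == m + log(sum(exp(x - m))), where m = max(x)
        m = np.max(x, axis=axis, keepdims=True)
        out = m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))
        return out if keepdims else np.squeeze(out, axis=axis)

    x = np.array([[-1.5, 0., 2.], [3., 1.2, -2.4]], dtype=np.float32)
    print(np_logsumexp(x))     # ~3.4691226, matching the docstring example
    print(np_logsumexp(x, 1))  # ~[2.1531782, 3.1568460]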
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index 9497a2eb3a477a..112c56a9e2b9c8 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -30,6 +30,7 @@
     isfinite,
     isinf,
     isnan,
+    logsumexp,
 )
 from paddle.base.libpaddle import DataType
 from paddle.common_ops_import import VarDesc, dygraph_utils
@@ -3147,91 +3148,6 @@ def __check_input(x, y):
     return out


-def logsumexp(
-    x: Tensor,
-    axis: int | Sequence[int] | None = None,
-    keepdim: bool = False,
-    name: str | None = None,
-) -> Tensor:
-    r"""
-    Calculates the log of the sum of exponentials of ``x`` along ``axis`` .
-
-    .. math::
-        logsumexp(x) = \log\sum exp(x)
-
-    Args:
-        x (Tensor): The input Tensor with data type bfloat16, float16, float32,
-            float64, uint8, int8, int16, int32, int64, which have no more than
-            4 dimensions.
-        axis (int|list|tuple|None, optional): The axis along which to perform
-            logsumexp calculations. ``axis`` should be int, list(int) or
-            tuple(int). If ``axis`` is a list/tuple of dimension(s), logsumexp
-            is calculated along all element(s) of ``axis`` . ``axis`` or
-            element(s) of ``axis`` should be in range [-D, D), where D is the
-            dimensions of ``x`` . If ``axis`` or element(s) of ``axis`` is
-            less than 0, it works the same way as :math:`axis + D` . If
-            ``axis`` is None, logsumexp is calculated along all elements of
-            ``x``. Default is None.
-        keepdim (bool, optional): Whether to reserve the reduced dimension(s)
-            in the output Tensor. If ``keep_dim`` is True, the dimensions of
-            the output Tensor is the same as ``x`` except in the reduced
-            dimensions(it is of size 1 in this case). Otherwise, the shape of
-            the output Tensor is squeezed in ``axis`` . Default is False.
-        name (str|None, optional): Name for the operation (optional, default is None).
-            For more information, please refer to :ref:`api_guide_Name`.
-
-    Returns:
-        Tensor, results of logsumexp along ``axis`` of ``x``, with the same data
-        type as ``x`` (integer types are autocasted into float32).
-
-    Examples:
-
-        .. code-block:: python
-
-            >>> import paddle
-
-            >>> x = paddle.to_tensor([[-1.5, 0., 2.], [3., 1.2, -2.4]])
-            >>> out1 = paddle.logsumexp(x)
-            >>> out1
-            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
-                   3.46912265)
-            >>> out2 = paddle.logsumexp(x, 1)
-            >>> out2
-            Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
-                   [2.15317822, 3.15684605])
-
-    """
-    reduce_all, axis = _get_reduce_axis(axis, x)
-
-    if in_dynamic_or_pir_mode():
-        return _C_ops.logsumexp(x, axis, keepdim, reduce_all)
-    else:
-        check_variable_and_dtype(
-            x,
-            'x',
-            [
-                'float16',
-                'float32',
-                'float64',
-                'uint16',
-                'uint8',
-                'int8',
-                'int16',
-                'int32',
-                'int64',
-            ],
-            'logsumexp',
-        )
-
-        helper = LayerHelper('logsumexp', **locals())
-        attrs = {'axis': axis, 'keepdim': keepdim, 'reduce_all': reduce_all}
-        out = helper.create_variable_for_type_inference(x.dtype)
-        helper.append_op(
-            type='logsumexp', inputs={'X': x}, outputs={'Out': out}, attrs=attrs
-        )
-        return out
-
-
 def inverse(x: Tensor, name: str | None = None) -> Tensor:
     """
     Takes the inverse of the square matrix. A square matrix is a matrix with
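Note: the test changes that follow exercise the gradient path through the new
binding. For a hand check, the gradient of a full-reduction logsumexp is the
softmax of the input, so the dygraph result can be verified directly (a sketch
assuming the patched binding; the tolerance is illustrative):

    import numpy as np
    import paddle

    x = paddle.to_tensor([[-1.5, 0., 2.], [3., 1.2, -2.4]], stop_gradient=False)
    out = paddle.logsumexp(x)  # scalar: reduces over all elements
    out.backward()
    # d logsumexp/dx_i = exp(x_i) / sum_j exp(x_j), i.e. softmax over all elements
    expected = np.exp(x.numpy()) / np.exp(x.numpy()).sum()
    np.testing.assert_allclose(x.grad.numpy(), expected, rtol=1e-5)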
diff --git a/test/legacy_test/test_logsumexp.py b/test/legacy_test/test_logsumexp.py
index 7f4b34379040ef..836eea528f0317 100644
--- a/test/legacy_test/test_logsumexp.py
+++ b/test/legacy_test/test_logsumexp.py
@@ -43,8 +43,8 @@ def logsumexp_op_grad(x, axis=None, keepdim=False, reduce_all=False):
     tensor_x = paddle.to_tensor(x)
     tensor_x.stop_gradient = False
     out = logsumexp_wrapper(tensor_x, axis, keepdim, reduce_all)
-    grad = paddle.grad(out, [tensor_x])
-    x_grad = grad[0].numpy()
+    out.backward()
+    x_grad = tensor_x.grad.numpy()
     paddle.enable_static()
     return x_grad

@@ -276,7 +276,7 @@ def api_case(self, axis=None, keepdim=False):

     def test_api(self):
         self.api_case()
-        self.api_case(2)
+        self.api_case([2])
         self.api_case([-1])
         self.api_case([2, -3])
         self.api_case((0, 1, -1))
@@ -340,5 +340,82 @@ def set_attrs(self):
         self.axis = [1]  # out return shape [2, 0]


+class TestLogsumexpAPI_Compatibility(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(123)
+        paddle.enable_static()
+        self.shape = [5, 6]
+        self.dtype = 'float32'
+        self.init_data()
+
+    def init_data(self):
+        self.np_input = np.random.randint(0, 8, self.shape).astype(self.dtype)
+        self.np_ref_out = ref_logsumexp(
+            self.np_input, axis=[0, 1], keepdim=True, reduce_all=True
+        )
+
+    def test_dygraph_Compatibility(self):
+        paddle.disable_static()
+        x = paddle.to_tensor(self.np_input)
+        paddle_dygraph_out = []
+        # Position args (args)
+        out1 = paddle.logsumexp(x, [0, 1], True)
+        paddle_dygraph_out.append(out1)
+        # Key words args (kwargs) for paddle
+        out2 = paddle.logsumexp(x=x, axis=[0, 1], keepdim=True)
+        paddle_dygraph_out.append(out2)
+        # Key words args for torch
+        out3 = paddle.logsumexp(input=x, dim=[0, 1], keepdim=True)
+        paddle_dygraph_out.append(out3)
+        # Combined args and kwargs
+        out4 = paddle.logsumexp(x, dim=[0, 1], keepdim=True)
+        paddle_dygraph_out.append(out4)
+        # Tensor method args
+        out5 = x.logsumexp([0, 1], True)
+        paddle_dygraph_out.append(out5)
+        # Tensor method kwargs
+        out6 = x.logsumexp(dim=[0, 1], keepdim=True)
+        paddle_dygraph_out.append(out6)
+        # Test out
+        out7 = paddle.empty([])
+        paddle.logsumexp(x, [0, 1], True, out=out7)
+        paddle_dygraph_out.append(out7)
+        # Numpy reference out
+        ref_out = self.np_ref_out
+        # Check
+        for out in paddle_dygraph_out:
+            np.testing.assert_allclose(ref_out, out.numpy())
+        paddle.enable_static()
+
+    def test_static_Compatibility(self):
+        main = paddle.static.Program()
+        startup = paddle.static.Program()
+        with paddle.base.program_guard(main, startup):
+            x = paddle.static.data(name="x", shape=self.shape, dtype=self.dtype)
+            # Position args (args)
+            out1 = paddle.logsumexp(x, [0, 1], True)
+            # Key words args (kwargs) for paddle
+            out2 = paddle.logsumexp(x=x, axis=[0, 1], keepdim=True)
+            # Key words args for torch
+            out3 = paddle.logsumexp(input=x, dim=[0, 1], keepdim=True)
+            # Combined args and kwargs
+            out4 = paddle.logsumexp(x, dim=[0, 1], keepdim=True)
+            # Tensor method args
+            out5 = x.logsumexp([0, 1], True)
+            # Tensor method kwargs
+            out6 = x.logsumexp(dim=[0, 1], keepdim=True)
+            # Do not support out in static
+            # out7 = paddle.empty([])
+        exe = paddle.base.Executor(paddle.CPUPlace())
+        fetches = exe.run(
+            main,
+            feed={"x": self.np_input},
+            fetch_list=[out1, out2, out3, out4, out5, out6],
+        )
+        ref_out = self.np_ref_out
+        for out in fetches:
+            np.testing.assert_allclose(out, ref_out)
+
+
 if __name__ == '__main__':
     unittest.main()

From 50c8150ecb83cb0d8e608ffece98fd300f64971b Mon Sep 17 00:00:00 2001
From: DanielSun11 <1395924413@qq.com>
Date: Sun, 24 Aug 2025 23:17:00 +0800
Subject: [PATCH 2/3] fix unit test

---
 test/legacy_test/test_logsumexp.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/legacy_test/test_logsumexp.py b/test/legacy_test/test_logsumexp.py
index 836eea528f0317..9fd9c0d68eda92 100644
--- a/test/legacy_test/test_logsumexp.py
+++ b/test/legacy_test/test_logsumexp.py
@@ -43,8 +43,8 @@ def logsumexp_op_grad(x, axis=None, keepdim=False, reduce_all=False):
     tensor_x = paddle.to_tensor(x)
     tensor_x.stop_gradient = False
     out = logsumexp_wrapper(tensor_x, axis, keepdim, reduce_all)
-    out.backward()
-    x_grad = tensor_x.grad.numpy()
+    grad = paddle.grad(out, [tensor_x])
+    x_grad = grad[0].numpy()
     paddle.enable_static()
     return x_grad

From 172f0e0fcadb9cf29ca5bc4259c8b20f712bdce2 Mon Sep 17 00:00:00 2001
From: DanielSun11 <1395924413@qq.com>
Date: Sun, 24 Aug 2025 23:18:18 +0800
Subject: [PATCH 3/3] fix unit test

---
 test/legacy_test/test_logsumexp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/legacy_test/test_logsumexp.py b/test/legacy_test/test_logsumexp.py
index 9fd9c0d68eda92..e6529dea4c8f4e 100644
--- a/test/legacy_test/test_logsumexp.py
+++ b/test/legacy_test/test_logsumexp.py
@@ -276,7 +276,7 @@ def api_case(self, axis=None, keepdim=False):

     def test_api(self):
         self.api_case()
-        self.api_case([2])
+        self.api_case(2)
         self.api_case([-1])
         self.api_case([2, -3])
         self.api_case((0, 1, -1))
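Note: after all three commits, TestLogsumexpAPI_Compatibility pins down the
accepted calling conventions: paddle-style x/axis, torch-style input/dim
aliases (enabled by use_default_mapping in ops.yaml), the Tensor method, and a
dygraph-only out= argument. A condensed usage sketch of the same surface,
mirroring the test rather than documenting guaranteed behavior:

    import paddle

    x = paddle.rand([5, 6])
    a = paddle.logsumexp(x, axis=[0, 1], keepdim=True)       # paddle kwargs
    b = paddle.logsumexp(input=x, dim=[0, 1], keepdim=True)  # torch-style aliases
    c = x.logsumexp(dim=[0, 1], keepdim=True)                # Tensor method
    res = paddle.empty([])
    paddle.logsumexp(x, [0, 1], True, out=res)               # dygraph only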