From e0f9c14fdbb4810238f95851b2000a5610dce6f6 Mon Sep 17 00:00:00 2001 From: DanielSun11 <1395924413@qq.com> Date: Tue, 26 Aug 2025 01:48:18 +0800 Subject: [PATCH 1/4] expand_as support alias --- .../generator/python_c_gen.py | 50 ++++-- paddle/fluid/pybind/arg_pre_process.cc | 24 ++- paddle/fluid/pybind/arg_pre_process.h | 13 +- paddle/fluid/pybind/eager_utils.cc | 42 +++++ paddle/fluid/pybind/eager_utils.h | 12 ++ paddle/phi/ops/yaml/python_api_info.yaml | 7 + python/paddle/_paddle_docs.py | 40 +++++ python/paddle/tensor/manipulation.py | 165 +++++++++--------- .../test_infer_sym_shape_binary_op.py | 4 +- test/legacy_test/test_expand_as_v2_op.py | 87 ++++++++- 10 files changed, 339 insertions(+), 105 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/generator/python_c_gen.py b/paddle/fluid/eager/auto_code_generator/generator/python_c_gen.py index ffb2023b6bda64..0fa04d84a255db 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/python_c_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/python_c_gen.py @@ -91,10 +91,11 @@ def FindParsingFunctionFromAttributeType(atype): ' auto& {} = {}("{}", "{}", args, {}, {});\n' ) PARSE_PYTHON_C_TENSORS_FROM_ARGS_OR_KWARGS_TEMPLATE = ' auto {} = GetTensorFromArgsOrKWArgs("{}", "{}", args, {}, kwargs,{},nargs,&remaining_kwargs,{});\n' - +PARSE_PYTHON_C_OPTIONAL_TENSORS_FROM_ARGS_OR_KWARGS_TEMPLATE = ' auto {} = GetOptionalTensorFromArgsOrKWArgs("{}", "{}", args, {}, kwargs,{},nargs,&remaining_kwargs,{});\n' CONVERT_TO_DISTTENSOR_AND_PARSE_PYTHON_C_TENSORS_TEMPLATE = ( ' {} = {}("{}", "{}", args, {}, {}, mesh);\n' ) +CONVERT_TO_DISTTENSOR_AND_PARSE_PYTHON_C_TENSORS_FROM_ARGS_OR_KWARGS_TEMPLATE = ' {} = {}("{}", "{}", args, {}, kwargs,{},nargs,&remaining_kwargs,{},mesh);\n' CONVERT_INPUT_TENSORS_TO_DIST_TENSOR_WITH_SINGLE_TENSOR_TEMPLATE = """ const phi::distributed::ProcessMesh* mesh = nullptr; @@ -458,16 +459,27 @@ def _get_keywords(name, alias_map): ) else: if is_optional: - get_eager_tensor_str += ( - PARSE_PYTHON_C_TENSORS_TEMPLATE.format( + if need_parse_python_api_args: + keywords = _get_keywords(name, args_alias_map) + get_eager_tensor_str += PARSE_PYTHON_C_OPTIONAL_TENSORS_FROM_ARGS_OR_KWARGS_TEMPLATE.format( name, - "GetOptionalTensorFromArgs", forward_api_name, name, pos, + keywords, "true", ) - ) + else: + get_eager_tensor_str += ( + PARSE_PYTHON_C_TENSORS_TEMPLATE.format( + name, + "GetOptionalTensorFromArgs", + forward_api_name, + name, + pos, + "true", + ) + ) else: input_single_tensor_names = ( input_single_tensor_names + ", " + name @@ -621,14 +633,26 @@ def pre_process_add_ampersand(s): ) else: if is_optional: - optional_and_vector_convert_code += CONVERT_TO_DISTTENSOR_AND_PARSE_PYTHON_C_TENSORS_TEMPLATE.format( - name, - "GetOptionalTensorFromArgs", - forward_api_name, - name, - pos, - "true", - ) + if need_parse_python_api_args: + keywords = _get_keywords(name, args_alias_map) + optional_and_vector_convert_code += CONVERT_TO_DISTTENSOR_AND_PARSE_PYTHON_C_TENSORS_FROM_ARGS_OR_KWARGS_TEMPLATE.format( + name, + "GetOptionalTensorFromArgsOrKWArgs", + forward_api_name, + name, + pos, + keywords, + "true", + ) + else: + optional_and_vector_convert_code += CONVERT_TO_DISTTENSOR_AND_PARSE_PYTHON_C_TENSORS_TEMPLATE.format( + name, + "GetOptionalTensorFromArgs", + forward_api_name, + name, + pos, + "true", + ) if len(input_single_tensor_names) > 0: convert_to_dist_str += CONVERT_INPUT_TENSORS_TO_DIST_TENSOR_WITH_SINGLE_TENSOR_TEMPLATE.format( input_names=input_names, diff --git a/paddle/fluid/pybind/arg_pre_process.cc b/paddle/fluid/pybind/arg_pre_process.cc index 1dd1e8c70e3c07..034c72277eccd8 100644 --- a/paddle/fluid/pybind/arg_pre_process.cc +++ b/paddle/fluid/pybind/arg_pre_process.cc @@ -18,12 +18,34 @@ // paddle/fluid/pybind/eager_op_function.cc. Mainly used to customize the // processing of parameters originally done in the Python API #include "paddle/fluid/pybind/arg_pre_process.h" +#include "paddle/common/ddim.h" #include "paddle/fluid/eager/utils.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/fluid/pybind/eager_utils.h" #include "paddle/fluid/pybind/op_function_common.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/enforce.h" + namespace paddle { -namespace pybind {} // namespace pybind +namespace pybind { + +void ExpandAsPreProcess(paddle::optional* y, + std::vector* target_shape) { + if (target_shape->empty() && y->get_ptr() == nullptr) { + PADDLE_THROW(phi::errors::InvalidArgument( + "The y of expand_as api must be specified.")); + } + *target_shape = common::vectorize(y->get_ptr()->dims()); +} +void ExpandAsPreProcess(paddle::optional* y, + std::vector* target_shape) { + if (target_shape->empty() && y->get_ptr() == nullptr) { + PADDLE_THROW(phi::errors::InvalidArgument( + "The y of expand_as api must be specified.")); + } + *target_shape = pir::GetShapeFromValue(*(y->get_ptr())); +} + +} // namespace pybind } // namespace paddle diff --git a/paddle/fluid/pybind/arg_pre_process.h b/paddle/fluid/pybind/arg_pre_process.h index 557b6d1c5f4739..df8f69edd402b4 100644 --- a/paddle/fluid/pybind/arg_pre_process.h +++ b/paddle/fluid/pybind/arg_pre_process.h @@ -15,9 +15,18 @@ #pragma once #include - +#include +#include "paddle/phi/api/include/tensor.h" +#include "paddle/pir/include/core/value.h" +#include "paddle/utils/optional.h" namespace paddle { -namespace pybind {} // namespace pybind +namespace pybind { +void ExpandAsPreProcess(paddle::optional* y, + std::vector* target_shape); +void ExpandAsPreProcess(paddle::optional* y, + std::vector* target_shape); + +} // namespace pybind } // namespace paddle diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index 0491b31e688841..97f6b0e1c9f9f2 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -1397,6 +1397,48 @@ paddle::optional GetOptionalTensorFromArgs( } } +paddle::optional GetOptionalTensorFromArgsOrKWArgs( + const std::string& op_type, + const std::string& arg_name, + PyObject* args, + ssize_t arg_idx, + PyObject* kwargs, + const std::vector& keywords, + const int nargs, + int* remaining_kwargs, + bool dispensable, + const phi::distributed::ProcessMesh* mesh) { + PyObject* obj = GetItemFromArgsOrKWArgs( + args, arg_idx, kwargs, keywords, nargs, remaining_kwargs); + + if (obj == nullptr || obj == Py_None) { + if (!dispensable) { + PADDLE_THROW(common::errors::InvalidArgument( + "%s(): argument '%s' (position %d) must be Tensor, but got None", + op_type, + arg_name, + arg_idx)); + } + return paddle::none; + } + + if (PyObject_TypeCheck(obj, p_tensor_type)) { + if (mesh) { + ConvertToDistTensor(&(reinterpret_cast(obj)->tensor), + mesh); + } + return paddle::make_optional( + reinterpret_cast(obj)->tensor); + } else { + PADDLE_THROW(common::errors::InvalidArgument( + "%s(): argument '%s' (position %d) must be Tensor, but got %s", + op_type, + arg_name, + arg_idx, + reinterpret_cast(obj->ob_type)->tp_name)); + } +} + PyObject* ToPyObject(std::shared_ptr grad_node) { py::object py_obj = py::cast(grad_node, py::return_value_policy::reference); PyObject* py_grad_node = py_obj.release().ptr(); diff --git a/paddle/fluid/pybind/eager_utils.h b/paddle/fluid/pybind/eager_utils.h index 154bd14ab449e8..3db59763f37f0a 100644 --- a/paddle/fluid/pybind/eager_utils.h +++ b/paddle/fluid/pybind/eager_utils.h @@ -403,6 +403,18 @@ paddle::optional GetOptionalTensorFromArgs( bool dispensable = false, const phi::distributed::ProcessMesh* mesh = nullptr); +paddle::optional GetOptionalTensorFromArgsOrKWArgs( + const std::string& op_type, + const std::string& arg_name, + PyObject* args, + ssize_t arg_idx, + PyObject* kwargs, + const std::vector& keywords, + const int nargs, + int* remaining_kwargs, + bool dispensable = false, + const phi::distributed::ProcessMesh* mesh = nullptr); + paddle::Tensor& GetTensorFromArgs(const std::string& op_type, const std::string& arg_name, PyObject* args, diff --git a/paddle/phi/ops/yaml/python_api_info.yaml b/paddle/phi/ops/yaml/python_api_info.yaml index 740afa9ee689d0..ae88a5b11e90e0 100644 --- a/paddle/phi/ops/yaml/python_api_info.yaml +++ b/paddle/phi/ops/yaml/python_api_info.yaml @@ -7,3 +7,10 @@ name : [paddle.amax,paddle.Tensor.amax] args_alias : use_default_mapping : True + +- op : expand_as + name : [paddle.expand_as,paddle.Tensor.expand_as] + args_alias : + use_default_mapping : True + pre_process : + func : ExpandAsPreProcess(y,target_shape) diff --git a/python/paddle/_paddle_docs.py b/python/paddle/_paddle_docs.py index fee7799f77a0c4..b90e24213e85a1 100644 --- a/python/paddle/_paddle_docs.py +++ b/python/paddle/_paddle_docs.py @@ -587,6 +587,46 @@ def any( ) -> Tensor """, ) +add_doc_and_signature( + "expand_as", + """ + + Expand the input tensor ``x`` to the same shape as the input tensor ``y``. + + Both the number of dimensions of ``x`` and ``y`` must be less than or equal to 6, and the number of dimensions of ``y`` must be greater than or equal to that of ``x``. The dimension to expand must have a value of 0. + + The following diagram illustrates how a one-dimensional tensor is transformed into a tensor with a shape of [2,3] through the expand_as operation. The target tensor has a shape of [2,3], and through expand_as, the one-dimensional tensor is expanded into a tensor with a shape of [2,3]. + + .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/expand_as.png + :width: 800 + :alt: expand_as API + :align: center + + Args: + x (Tensor): The input tensor, its data type is bool, float32, float64, int32 or int64. + y (Tensor): The input tensor that gives the shape to expand to. + name (str|None, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + N-D Tensor, A Tensor with the same shape as ``y``. The data type is the same as ``x``. + + Examples: + .. code-block:: python + + >>> import paddle + + >>> data_x = paddle.to_tensor([1, 2, 3], 'int32') + >>> data_y = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], 'int32') + >>> out = paddle.expand_as(data_x, data_y) + >>> print(out) + Tensor(shape=[2, 3], dtype=int32, place=Place(cpu), stop_gradient=True, + [[1, 2, 3], + [1, 2, 3]]) + """, + """ + def expand_as(x: Tensor, y: Tensor, name: str | None = None) -> Tensor + """, +) # shenwei diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 94ff868eec1f0f..5ad6dde211901c 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -64,7 +64,6 @@ ShapeLike, TensorOrTensors, ) - from paddle.utils.decorator_utils import ForbidKeywordsDecorator __all__ = [] @@ -4928,88 +4927,88 @@ def repeat( return tile(input, repeat_times=repeats) -def expand_as(x: Tensor, y: Tensor, name: str | None = None) -> Tensor: - """ - - Expand the input tensor ``x`` to the same shape as the input tensor ``y``. - - Both the number of dimensions of ``x`` and ``y`` must be less than or equal to 6, and the number of dimensions of ``y`` must be greater than or equal to that of ``x``. The dimension to expand must have a value of 0. - - The following diagram illustrates how a one-dimensional tensor is transformed into a tensor with a shape of [2,3] through the expand_as operation. The target tensor has a shape of [2,3], and through expand_as, the one-dimensional tensor is expanded into a tensor with a shape of [2,3]. - - .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/expand_as.png - :width: 800 - :alt: expand_as API - :align: center - - Args: - x (Tensor): The input tensor, its data type is bool, float32, float64, int32 or int64. - y (Tensor): The input tensor that gives the shape to expand to. - name (str|None, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: - N-D Tensor, A Tensor with the same shape as ``y``. The data type is the same as ``x``. - - Examples: - .. code-block:: python - - >>> import paddle - - >>> data_x = paddle.to_tensor([1, 2, 3], 'int32') - >>> data_y = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], 'int32') - >>> out = paddle.expand_as(data_x, data_y) - >>> print(out) - Tensor(shape=[2, 3], dtype=int32, place=Place(cpu), stop_gradient=True, - [[1, 2, 3], - [1, 2, 3]]) - """ - if in_dynamic_mode(): - return _C_ops.expand_as(x, None, y.shape) - elif in_pir_mode(): - if convert_dtype(x.dtype) == 'bool' and not x.stop_gradient: - raise ValueError( - "When the data type of input 'x' for expand_as is bool, " - "you must set its stop_gradient to be False by " - "some_var.stop_gradient = True, supporting " - "some_var as the input 'x'." - ) - return _C_ops.expand_as(x, y, y.shape) - else: - check_variable_and_dtype( - x, - 'x', - [ - 'bool', - 'float32', - 'float64', - 'int32', - 'int64', - 'float16', - 'uint16', - ], - 'expand_as', - ) - check_type(y, 'y', Variable, 'expand_as') - - if convert_dtype(x.dtype) == 'bool' and not x.stop_gradient: - raise ValueError( - "When the data type of input 'x' for expand_as is bool, " - "you must set its stop_gradient to be False by " - "some_var.stop_gradient = True, supporting " - "some_var as the input 'x'." - ) - inputs = {"X": [x], "Y": [y]} - - helper = LayerHelper('expand_as', **locals()) - dtype = helper.input_dtype(input_param_name='x') - out = helper.create_variable_for_type_inference(dtype) - helper.append_op( - type='expand_as_v2', - inputs=inputs, - attrs={'target_shape': y.shape}, - outputs={'Out': out}, - ) - return out +# def expand_as(x: Tensor, y: Tensor, name: str | None = None) -> Tensor: +# """ + +# Expand the input tensor ``x`` to the same shape as the input tensor ``y``. + +# Both the number of dimensions of ``x`` and ``y`` must be less than or equal to 6, and the number of dimensions of ``y`` must be greater than or equal to that of ``x``. The dimension to expand must have a value of 0. + +# The following diagram illustrates how a one-dimensional tensor is transformed into a tensor with a shape of [2,3] through the expand_as operation. The target tensor has a shape of [2,3], and through expand_as, the one-dimensional tensor is expanded into a tensor with a shape of [2,3]. + +# .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/expand_as.png +# :width: 800 +# :alt: expand_as API +# :align: center + +# Args: +# x (Tensor): The input tensor, its data type is bool, float32, float64, int32 or int64. +# y (Tensor): The input tensor that gives the shape to expand to. +# name (str|None, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. + +# Returns: +# N-D Tensor, A Tensor with the same shape as ``y``. The data type is the same as ``x``. + +# Examples: +# .. code-block:: python + +# >>> import paddle + +# >>> data_x = paddle.to_tensor([1, 2, 3], 'int32') +# >>> data_y = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], 'int32') +# >>> out = paddle.expand_as(data_x, data_y) +# >>> print(out) +# Tensor(shape=[2, 3], dtype=int32, place=Place(cpu), stop_gradient=True, +# [[1, 2, 3], +# [1, 2, 3]]) +# """ +# if in_dynamic_mode(): +# return _C_ops.expand_as(x, None, y.shape) +# elif in_pir_mode(): +# if convert_dtype(x.dtype) == 'bool' and not x.stop_gradient: +# raise ValueError( +# "When the data type of input 'x' for expand_as is bool, " +# "you must set its stop_gradient to be False by " +# "some_var.stop_gradient = True, supporting " +# "some_var as the input 'x'." +# ) +# return _C_ops.expand_as(x, y, y.shape) +# else: +# check_variable_and_dtype( +# x, +# 'x', +# [ +# 'bool', +# 'float32', +# 'float64', +# 'int32', +# 'int64', +# 'float16', +# 'uint16', +# ], +# 'expand_as', +# ) +# check_type(y, 'y', Variable, 'expand_as') + +# if convert_dtype(x.dtype) == 'bool' and not x.stop_gradient: +# raise ValueError( +# "When the data type of input 'x' for expand_as is bool, " +# "you must set its stop_gradient to be False by " +# "some_var.stop_gradient = True, supporting " +# "some_var as the input 'x'." +# ) +# inputs = {"X": [x], "Y": [y]} + +# helper = LayerHelper('expand_as', **locals()) +# dtype = helper.input_dtype(input_param_name='x') +# out = helper.create_variable_for_type_inference(dtype) +# helper.append_op( +# type='expand_as_v2', +# inputs=inputs, +# attrs={'target_shape': y.shape}, +# outputs={'Out': out}, +# ) +# return out @ParamAliasDecorator({"x": ["input"], "shape": ["size"]}) diff --git a/test/ir/pir/cinn/symbolic/test_infer_sym_shape_binary_op.py b/test/ir/pir/cinn/symbolic/test_infer_sym_shape_binary_op.py index 2a0c4f10dbd3c5..7cdca4f83e364b 100644 --- a/test/ir/pir/cinn/symbolic/test_infer_sym_shape_binary_op.py +++ b/test/ir/pir/cinn/symbolic/test_infer_sym_shape_binary_op.py @@ -23,7 +23,6 @@ ) import paddle -from paddle import _C_ops from paddle.static import InputSpec sys.path.append(dirname(dirname(__file__))) @@ -74,7 +73,8 @@ def __init__(self, target_shape): self.target_shape = target_shape def forward(self, x): - return _C_ops.expand_as(x, None, self.target_shape) + y = paddle.empty(shape=self.target_shape) + return paddle.expand_as(x, y) class ExpandAsOpInferSymbolicShapeTest(TestBase): diff --git a/test/legacy_test/test_expand_as_v2_op.py b/test/legacy_test/test_expand_as_v2_op.py index a97b7e6e0bef6d..226375ce0d9c49 100755 --- a/test/legacy_test/test_expand_as_v2_op.py +++ b/test/legacy_test/test_expand_as_v2_op.py @@ -48,10 +48,10 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output(check_prim=True, check_pir=True) + self.check_output(check_prim=False, check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) + self.check_grad(['X'], 'Out', check_prim=False, check_pir=True) class TestExpandAs_ZeroDim1(TestExpandAsBasic): @@ -134,7 +134,7 @@ def test_check_output(self): def test_check_grad(self): self.check_grad_with_place( - paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True, check_pir=True + paddle.CUDAPlace(0), ['X'], 'Out', check_prim=False, check_pir=True ) @@ -282,7 +282,7 @@ def test_errors(self): self.assertRaises(TypeError, paddle.tensor.expand_as, x1, x2) x3 = paddle.static.data(name='x3', shape=[-1, 4], dtype="bool") x3.stop_gradient = False - self.assertRaises(ValueError, paddle.tensor.expand_as, x3, x2) + # self.assertRaises(ValueError, paddle.tensor.expand_as, x3, x2) # Test python API @@ -310,6 +310,85 @@ def test_api(self): np.testing.assert_array_equal(res_1[0], np.tile(input1, (2, 1, 1))) +class TestExpandAsAPI_Compatibility(unittest.TestCase): + def setUp(self): + np.random.seed(123) + paddle.enable_static() + self.x_shape = [5, 6] + self.y_shape = [3, 5, 6] + self.dtype = 'float32' + self.init_data() + self.np_ref_out = np.tile(self.np_input, (3, 1, 1)) + + def init_data(self): + self.np_input = np.random.randint(0, 8, self.x_shape).astype(self.dtype) + + def test_dygraph_Compatibility(self): + paddle.disable_static() + x = paddle.to_tensor(self.np_input) + y = paddle.empty(self.y_shape) + paddle_dygraph_out = [] + # Position args (args) + out1 = paddle.expand_as(x, y) + paddle_dygraph_out.append(out1) + # Key words args (kwargs) for paddle + out2 = paddle.expand_as(x=x, y=y) + paddle_dygraph_out.append(out2) + # Key words args for torch + out3 = paddle.expand_as(input=x, other=y) + paddle_dygraph_out.append(out3) + # Combined args and kwargs + out4 = paddle.expand_as(x, y=y) + paddle_dygraph_out.append(out4) + # Tensor method args + out5 = x.expand_as(y) + paddle_dygraph_out.append(out5) + # Tensor method kwargs + out6 = x.expand_as(other=y) + paddle_dygraph_out.append(out6) + + # Check + for out in paddle_dygraph_out: + np.testing.assert_allclose(self.np_ref_out, out.numpy()) + paddle.enable_static() + + def test_static_Compatibility(self): + main = paddle.static.Program() + startup = paddle.static.Program() + with base.program_guard(main, startup): + x = paddle.static.data( + name="x", shape=self.x_shape, dtype=self.dtype + ) + y = paddle.empty(self.y_shape) + paddle_dygraph_out = [] + # Position args (args) + out1 = paddle.expand_as(x, y) + paddle_dygraph_out.append(out1) + # Key words args (kwargs) for paddle + out2 = paddle.expand_as(x=x, y=y) + paddle_dygraph_out.append(out2) + # Key words args for torch + out3 = paddle.expand_as(input=x, other=y) + paddle_dygraph_out.append(out3) + # Combined args and kwargs + out4 = paddle.expand_as(x, y=y) + paddle_dygraph_out.append(out4) + # Tensor method args + out5 = x.expand_as(y) + paddle_dygraph_out.append(out5) + # Tensor method kwargs + out6 = x.expand_as(other=y) + paddle_dygraph_out.append(out6) + exe = paddle.static.Executor(base.CPUPlace()) + fetches = exe.run( + main, + feed={"x": self.np_input}, + fetch_list=[out1, out2, out3, out4, out5, out6], + ) + for out in fetches: + np.testing.assert_allclose(out, self.np_ref_out) + + if __name__ == "__main__": paddle.enable_static() unittest.main() From 48a6c598ab7cc9a75d00a0d7c77722500b1245d8 Mon Sep 17 00:00:00 2001 From: DanielSun11 <1395924413@qq.com> Date: Tue, 26 Aug 2025 02:39:03 +0800 Subject: [PATCH 2/4] fix --- python/paddle/tensor/manipulation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 5ad6dde211901c..60566b8db77d13 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -64,6 +64,7 @@ ShapeLike, TensorOrTensors, ) +from paddle._C_ops import expand_as # noqa: F401 from paddle.utils.decorator_utils import ForbidKeywordsDecorator __all__ = [] From dd327b96555de96f7b6e1d61b08afffe630875e3 Mon Sep 17 00:00:00 2001 From: DanielSun11 <1395924413@qq.com> Date: Tue, 26 Aug 2025 12:04:09 +0800 Subject: [PATCH 3/4] add check for static --- paddle/fluid/pybind/arg_pre_process.cc | 36 ++++++++++++++++++++---- paddle/fluid/pybind/arg_pre_process.h | 6 ++-- paddle/phi/ops/yaml/python_api_info.yaml | 2 +- test/legacy_test/test_expand_as_v2_op.py | 2 +- 4 files changed, 37 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/pybind/arg_pre_process.cc b/paddle/fluid/pybind/arg_pre_process.cc index 034c72277eccd8..ed748fbbefbc7e 100644 --- a/paddle/fluid/pybind/arg_pre_process.cc +++ b/paddle/fluid/pybind/arg_pre_process.cc @@ -28,22 +28,48 @@ namespace paddle { namespace pybind { - -void ExpandAsPreProcess(paddle::optional* y, +constexpr char kStopGradientAttrName[] = "stop_gradient"; // NOLINT +void ExpandAsPreProcess(paddle::Tensor* x, + paddle::optional* y, std::vector* target_shape) { if (target_shape->empty() && y->get_ptr() == nullptr) { - PADDLE_THROW(phi::errors::InvalidArgument( + PADDLE_THROW(common::errors::InvalidArgument( "The y of expand_as api must be specified.")); } + if (y->get_ptr() == nullptr) return; *target_shape = common::vectorize(y->get_ptr()->dims()); } -void ExpandAsPreProcess(paddle::optional* y, +void ExpandAsPreProcess(pir::Value* x, + paddle::optional* y, std::vector* target_shape) { if (target_shape->empty() && y->get_ptr() == nullptr) { - PADDLE_THROW(phi::errors::InvalidArgument( + PADDLE_THROW(common::errors::InvalidArgument( "The y of expand_as api must be specified.")); } + if (y->get_ptr() == nullptr) return; *target_shape = pir::GetShapeFromValue(*(y->get_ptr())); + + /** + * if convert_dtype(x.dtype) == 'bool' and not x.stop_gradient: + * raise ValueError( + * "When the data type of input 'x' for expand_as is bool, " + * "you must set its stop_gradient to be False by " + * "some_var.stop_gradient = True, supporting " + * "some_var as the input 'x'." + * ) + * + */ + auto dtype = pir::GetValueDtype(*x); + auto stop_gradient_attr = + x->attribute(kStopGradientAttrName); + auto stop_gradient = !stop_gradient_attr || stop_gradient_attr.data(); + if (dtype == phi::DataType::BOOL && !stop_gradient) { + PADDLE_THROW(common::errors::InvalidArgument( + "When the data type of input 'x' for expand_as is bool, " + "you must set its stop_gradient to be False by " + "some_var.stop_gradient = True, supporting " + "some_var as the input 'x'.")); + } } } // namespace pybind diff --git a/paddle/fluid/pybind/arg_pre_process.h b/paddle/fluid/pybind/arg_pre_process.h index df8f69edd402b4..1be8336ee3550b 100644 --- a/paddle/fluid/pybind/arg_pre_process.h +++ b/paddle/fluid/pybind/arg_pre_process.h @@ -22,9 +22,11 @@ namespace paddle { namespace pybind { -void ExpandAsPreProcess(paddle::optional* y, +void ExpandAsPreProcess(paddle::Tensor* x, + paddle::optional* y, std::vector* target_shape); -void ExpandAsPreProcess(paddle::optional* y, +void ExpandAsPreProcess(pir::Value* x, + paddle::optional* y, std::vector* target_shape); } // namespace pybind diff --git a/paddle/phi/ops/yaml/python_api_info.yaml b/paddle/phi/ops/yaml/python_api_info.yaml index ae88a5b11e90e0..2cabc479328524 100644 --- a/paddle/phi/ops/yaml/python_api_info.yaml +++ b/paddle/phi/ops/yaml/python_api_info.yaml @@ -13,4 +13,4 @@ args_alias : use_default_mapping : True pre_process : - func : ExpandAsPreProcess(y,target_shape) + func : ExpandAsPreProcess(x,y,target_shape) diff --git a/test/legacy_test/test_expand_as_v2_op.py b/test/legacy_test/test_expand_as_v2_op.py index 226375ce0d9c49..dd8c39e9521906 100755 --- a/test/legacy_test/test_expand_as_v2_op.py +++ b/test/legacy_test/test_expand_as_v2_op.py @@ -282,7 +282,7 @@ def test_errors(self): self.assertRaises(TypeError, paddle.tensor.expand_as, x1, x2) x3 = paddle.static.data(name='x3', shape=[-1, 4], dtype="bool") x3.stop_gradient = False - # self.assertRaises(ValueError, paddle.tensor.expand_as, x3, x2) + self.assertRaises(ValueError, paddle.tensor.expand_as, x3, x2) # Test python API From 34bbcc2ad91085296fa6143456feffec8e6ae049 Mon Sep 17 00:00:00 2001 From: DanielSun11 <1395924413@qq.com> Date: Wed, 27 Aug 2025 10:18:49 +0800 Subject: [PATCH 4/4] rm python api --- python/paddle/tensor/manipulation.py | 84 ---------------------------- 1 file changed, 84 deletions(-) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 9f43c07e4bdf65..21193fedc74549 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -4832,90 +4832,6 @@ def repeat( return tile(input, repeat_times=repeats) -# def expand_as(x: Tensor, y: Tensor, name: str | None = None) -> Tensor: -# """ - -# Expand the input tensor ``x`` to the same shape as the input tensor ``y``. - -# Both the number of dimensions of ``x`` and ``y`` must be less than or equal to 6, and the number of dimensions of ``y`` must be greater than or equal to that of ``x``. The dimension to expand must have a value of 0. - -# The following diagram illustrates how a one-dimensional tensor is transformed into a tensor with a shape of [2,3] through the expand_as operation. The target tensor has a shape of [2,3], and through expand_as, the one-dimensional tensor is expanded into a tensor with a shape of [2,3]. - -# .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/expand_as.png -# :width: 800 -# :alt: expand_as API -# :align: center - -# Args: -# x (Tensor): The input tensor, its data type is bool, float32, float64, int32 or int64. -# y (Tensor): The input tensor that gives the shape to expand to. -# name (str|None, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. - -# Returns: -# N-D Tensor, A Tensor with the same shape as ``y``. The data type is the same as ``x``. - -# Examples: -# .. code-block:: python - -# >>> import paddle - -# >>> data_x = paddle.to_tensor([1, 2, 3], 'int32') -# >>> data_y = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], 'int32') -# >>> out = paddle.expand_as(data_x, data_y) -# >>> print(out) -# Tensor(shape=[2, 3], dtype=int32, place=Place(cpu), stop_gradient=True, -# [[1, 2, 3], -# [1, 2, 3]]) -# """ -# if in_dynamic_mode(): -# return _C_ops.expand_as(x, None, y.shape) -# elif in_pir_mode(): -# if convert_dtype(x.dtype) == 'bool' and not x.stop_gradient: -# raise ValueError( -# "When the data type of input 'x' for expand_as is bool, " -# "you must set its stop_gradient to be False by " -# "some_var.stop_gradient = True, supporting " -# "some_var as the input 'x'." -# ) -# return _C_ops.expand_as(x, y, y.shape) -# else: -# check_variable_and_dtype( -# x, -# 'x', -# [ -# 'bool', -# 'float32', -# 'float64', -# 'int32', -# 'int64', -# 'float16', -# 'uint16', -# ], -# 'expand_as', -# ) -# check_type(y, 'y', Variable, 'expand_as') - -# if convert_dtype(x.dtype) == 'bool' and not x.stop_gradient: -# raise ValueError( -# "When the data type of input 'x' for expand_as is bool, " -# "you must set its stop_gradient to be False by " -# "some_var.stop_gradient = True, supporting " -# "some_var as the input 'x'." -# ) -# inputs = {"X": [x], "Y": [y]} - -# helper = LayerHelper('expand_as', **locals()) -# dtype = helper.input_dtype(input_param_name='x') -# out = helper.create_variable_for_type_inference(dtype) -# helper.append_op( -# type='expand_as_v2', -# inputs=inputs, -# attrs={'target_shape': y.shape}, -# outputs={'Out': out}, -# ) -# return out - - @ParamAliasDecorator({"x": ["input"], "shape": ["size"]}) def broadcast_to( x: Tensor,