From 0e164caacbae90ba892655d28f6f936ccc0217ba Mon Sep 17 00:00:00 2001 From: zhanghonggeng Date: Tue, 22 Jul 2025 09:01:39 +0000 Subject: [PATCH 01/31] fix index_elemwentwise_get_gard bug slice-check --- paddle/phi/kernels/cpu/index_elementwise_get_grad_kernel.cc | 4 ++-- paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/phi/kernels/cpu/index_elementwise_get_grad_kernel.cc b/paddle/phi/kernels/cpu/index_elementwise_get_grad_kernel.cc index 65f0073df4b71c..ccb4d518fceeb8 100644 --- a/paddle/phi/kernels/cpu/index_elementwise_get_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/index_elementwise_get_grad_kernel.cc @@ -76,8 +76,8 @@ void CPUIndexElementwiseGetGrad(const phi::CPUContext& dev_ctx, funcs::IndexPutStride<3>(input_dims, input_strides, phi::SizeOf(input.dtype()), - std::vector(), - std::vector(), + common::vectorize(value.dims()), + common::vectorize(value.strides()), phi::SizeOf(value.dtype()), shape_tmp, stride_tmp, diff --git a/paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu b/paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu index 7e443f87de37e6..135f01834c2845 100644 --- a/paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu @@ -92,8 +92,8 @@ void GPUIndexElementwiseGetGrad(const phi::GPUContext& dev_ctx, funcs::IndexPutStride<3>(input_dims, input_strides, phi::SizeOf(input.dtype()), - std::vector(), - std::vector(), + common::vectorize(value.dims()), + common::vectorize(value.strides()), phi::SizeOf(value.dtype()), shape_tmp, stride_tmp, From 7513650571ae839a5738176f7b20fcea23cb5982 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Mon, 4 Aug 2025 20:27:24 +0800 Subject: [PATCH 02/31] enhance Tensor creation methods --- python/paddle/tensor/creation.py | 198 +++++++++++++++++++++++++----- test/legacy_test/test_creation.py | 171 ++++++++++++++++++++++++++ 2 files changed, 337 insertions(+), 32 deletions(-) create mode 100644 test/legacy_test/test_creation.py diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 5bb5254c5c7c37..579aefc6891a10 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1039,6 +1039,9 @@ def full_like( fill_value: bool | float, dtype: DTypeLike | None = None, name: str | None = None, + *, + device: PlaceLike | None = None, + requires_grad: bool = False, ) -> paddle.Tensor: """ @@ -1052,6 +1055,10 @@ def full_like( of bool, float16, float32, float64, int32, int64. The default value is None, which means the output data type is the same as input. name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + device(PlaceLike|None, optional): The desired device of returned tensor. + if None, uses the current device for the default tensor type (see paddle.device.set_device()). + device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. Default: None. + requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: False. Returns: Tensor: Tensor which is created according to ``x``, ``fill_value`` and ``dtype``. @@ -1073,11 +1080,19 @@ def full_like( else: if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)): dtype = convert_np_dtype_to_dtype_(dtype) + if requires_grad is None: + requires_grad = not x.stop_gradient + if device is None: + device = x.place - if in_dynamic_mode(): - return _C_ops.full_like(x, fill_value, dtype, x.place) - elif in_pir_mode(): - return _C_ops.full_like(x, fill_value, dtype, core.Place()) + if in_dynamic_or_pir_mode(): + if in_dynamic_mode(): + tensor = _C_ops.full_like(x, fill_value, dtype, x.place) + else: + tensor = _C_ops.full_like(x, fill_value, dtype, core.Place()) + tensor = tensor.to(device=device) + tensor.stop_gradient = not requires_grad + return tensor else: helper = LayerHelper("full_like", **locals()) check_variable_and_dtype( @@ -1235,7 +1250,12 @@ def fill_constant( def ones( - shape: ShapeLike, dtype: DTypeLike | None = None, name: str | None = None + shape: ShapeLike, + dtype: DTypeLike | None = None, + name: str | None = None, + *, + device: PlaceLike | None = None, + requires_grad: bool = False, ) -> paddle.Tensor: """ Create a Tensor of specified :attr:`shape` and :attr:`dtype` and fill it with 1. @@ -1247,6 +1267,10 @@ def ones( dtype (np.dtype|str, optional): Data type of output Tensor, it should be one of bool, float16, float32, float64, int32 and int64. If it is set to None, the data type will be float32. name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + device(PlaceLike|None, optional): The desired device of returned tensor. + if None, uses the current device for the default tensor type (see paddle.device.set_device()). + device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. Default: None. + requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: False. Returns: Tensor: A Tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements are 1. @@ -1281,11 +1305,22 @@ def ones( """ if dtype is None: dtype = paddle.get_default_dtype() - return fill_constant(value=1.0, shape=shape, dtype=dtype, name=name) + tensor = fill_constant(value=1.0, shape=shape, dtype=dtype, name=name) + + if device is not None: + tensor = tensor.to(device=device) + + tensor.stop_gradient = not requires_grad + return tensor def ones_like( - x: paddle.Tensor, dtype: DTypeLike | None = None, name: str | None = None + x: paddle.Tensor, + dtype: DTypeLike | None = None, + name: str | None = None, + *, + device: PlaceLike | None = None, + requires_grad: bool = False, ) -> paddle.Tensor: """ Returns a Tensor filled with the value 1, with the same shape and @@ -1299,6 +1334,10 @@ def ones_like( int32, int64. If ``dtype`` is None, the data type is the same as ``x``. Default is None. name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + device(PlaceLike|None, optional): The desired device of returned tensor. + if None, uses the current device for the default tensor type (see paddle.device.set_device()). + device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. Default: None. + requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: False. Returns: Tensor: A Tensor filled with the value 1, with the same shape and @@ -1318,13 +1357,27 @@ def ones_like( [1 1 1] """ - return full_like(x=x, fill_value=1, dtype=dtype, name=name) + if dtype is None: + dtype = x.dtype + if requires_grad is None: + requires_grad = not x.stop_gradient + if device is None: + device = x.place + + tensor = full_like(x=x, fill_value=1, dtype=dtype, name=name) + + tensor = tensor.to(device=device) + tensor.stop_gradient = not requires_grad + return tensor def zeros( shape: ShapeLike, dtype: DTypeLike | None = None, name: str | None = None, + *, + device: PlaceLike | None = None, + requires_grad: bool = False, ) -> paddle.Tensor: """ Creates a tensor of specified :attr:`shape` and :attr:`dtype`, and fills it with 0. @@ -1337,6 +1390,10 @@ def zeros( bool, float16, float32, float64, int32 and int64. Default: if None, the data type is float32. name(str|None, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. + device(PlaceLike|None, optional): The desired device of returned tensor. + if None, uses the current device for the default tensor type (see paddle.device.set_device()). + device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. Default: None. + requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: False. Returns: Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 0. @@ -1371,11 +1428,22 @@ def zeros( """ if dtype is None: dtype = paddle.get_default_dtype() - return fill_constant(value=0.0, shape=shape, dtype=dtype, name=name) + tensor = fill_constant(value=0.0, shape=shape, dtype=dtype, name=name) + + if device is not None: + tensor = tensor.to(device=device) + + tensor.stop_gradient = not requires_grad + return tensor def zeros_like( - x: paddle.Tensor, dtype: DTypeLike | None = None, name: str | None = None + x: paddle.Tensor, + dtype: DTypeLike | None = None, + name: str | None = None, + *, + device: PlaceLike | None = None, + requires_grad: bool = False, ) -> paddle.Tensor: """ Returns a Tensor filled with the value 0, with the same shape and @@ -1389,6 +1457,10 @@ def zeros_like( int32, int64. If ``dtype`` is None, the data type is the same as ``x``. Default is None. name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + device(PlaceLike|None, optional): The desired device of returned tensor. + if None, uses the current device for the default tensor type (see paddle.device.set_device()). + device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. Default: None. + requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: False. Returns: Tensor: A Tensor filled with the value 0, with the same shape and @@ -1409,7 +1481,18 @@ def zeros_like( [0 0 0] """ - return full_like(x=x, fill_value=0, dtype=dtype, name=name) + if dtype is None: + dtype = x.dtype + if requires_grad is None: + requires_grad = not x.stop_gradient + if device is None: + device = x.place + + tensor = full_like(x=x, fill_value=0, dtype=dtype, name=name) + + tensor = tensor.to(device=device) + tensor.stop_gradient = not requires_grad + return tensor def eye( @@ -1417,6 +1500,9 @@ def eye( num_columns: int | None = None, dtype: DTypeLike | None = None, name: str | None = None, + *, + device: PlaceLike | None = None, + requires_grad: bool = False, ) -> paddle.Tensor: """ @@ -1430,6 +1516,10 @@ def eye( It should be int32, int64, float16, float32, float64, complex64, complex128. Default: if None, the data type is float32. name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + device(PlaceLike|None, optional): The desired device of returned tensor. + if None, uses the current device for the default tensor type (see paddle.device.set_device()). + device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. Default: None. + requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: False. Returns: Tensor: An identity Tensor or DenseTensor of shape [num_rows, num_columns]. @@ -1470,9 +1560,14 @@ def _check_attr(attr, message): num_columns = num_rows if in_dynamic_or_pir_mode(): - out = _C_ops.eye( + tensor = _C_ops.eye( num_rows, num_columns, dtype, _current_expected_place() ) + if device is not None: + tensor = tensor.to(device=device) + + tensor.stop_gradient = not requires_grad + return tensor else: helper = LayerHelper("eye", **locals()) check_dtype( @@ -1512,6 +1607,9 @@ def full( fill_value: bool | float | paddle.Tensor, dtype: DTypeLike | None = None, name: str | None = None, + *, + device: PlaceLike | None = None, + requires_grad: bool = False, ) -> paddle.Tensor: """ @@ -1527,6 +1625,10 @@ def full( which can be float16, float32, float64, int32, int64, if dtype is `None`, the data type of created Tensor is `float32`. name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + device(PlaceLike|None, optional): The desired device of returned tensor. + if None, uses the current device for the default tensor type (see paddle.device.set_device()). + device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. Default: None. + requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: False. Returns: Tensor: Tensor which is created according to ``shape``, ``fill_value`` and ``dtype``. @@ -1576,7 +1678,14 @@ def full( else: dtype = paddle.get_default_dtype() - return fill_constant(shape=shape, dtype=dtype, value=fill_value, name=name) + tensor = fill_constant( + shape=shape, dtype=dtype, value=fill_value, name=name + ) + if device is not None: + tensor = tensor.to(device=device) + + tensor.stop_gradient = not requires_grad + return tensor def arange( @@ -2439,7 +2548,12 @@ def diag( def empty( - shape: ShapeLike, dtype: DTypeLike | None = None, name: str | None = None + shape: ShapeLike, + dtype: DTypeLike | None = None, + name: str | None = None, + *, + device: PlaceLike | None = None, + requires_grad: bool = False, ) -> paddle.Tensor: """ Returns a Tensor with uninitialized data which size is same as ``shape``. @@ -2453,6 +2567,10 @@ def empty( type of created Tensor use global default dtype (see ``get_default_dtype`` for details). name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + device(PlaceLike|None, optional): The desired device of returned tensor. + if None, uses the current device for the default tensor type (see paddle.device.set_device()). + device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. Default: None. + requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: False. Returns: Tensor: Tensor which is created according to ``shape`` and ``dtype``, and is uninitialized. @@ -2529,11 +2647,14 @@ def empty( else: raise TypeError("Shape only supports Value, or list, or tuple.") - out = _C_ops.empty( + tensor = _C_ops.empty( shape, convert_np_dtype_to_dtype_(dtype), _current_expected_place() ) - out.stop_gradient = True - return out + if device is not None: + tensor = tensor.to(device=device) + + tensor.stop_gradient = not requires_grad + return tensor else: helper = LayerHelper("empty", **locals()) inputs = {} @@ -2581,7 +2702,12 @@ def empty( def empty_like( - x: paddle.Tensor, dtype: DTypeLike | None = None, name: str | None = None + x: paddle.Tensor, + dtype: DTypeLike | None = None, + name: str | None = None, + *, + device: PlaceLike | None = None, + requires_grad: bool = False, ) -> paddle.Tensor: """ Returns a Tensor with uninitialized data which has identical shape of ``x`` and ``dtype``. @@ -2593,6 +2719,10 @@ def empty_like( of bool, float16, float32, float64, int32, int64. The default value is None, which means the output data type is the same as input. name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + device(PlaceLike|None, optional): The desired device of returned tensor. + if None, uses the current device for the default tensor type (see paddle.device.set_device()). + device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. Default: None. + requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: False. Returns: Tensor: Tensor which is created according to ``x`` and ``dtype``, and is uninitialized. @@ -2614,25 +2744,29 @@ def empty_like( if dtype is None: dtype = x.dtype + if requires_grad is None: + requires_grad = not x.stop_gradient + if device is None: + device = x.place + dtype = convert_dtype(dtype) - if in_dynamic_mode(): - out = _C_ops.empty( - x.shape, - convert_np_dtype_to_dtype_(dtype), - _current_expected_place(), - ) - out.stop_gradient = True - return out - elif in_pir_mode(): - shape = paddle.shape(x) - out = _C_ops.empty( - shape, + if in_dynamic_or_pir_mode(): + if in_dynamic_mode(): + x_shape = x.shape + else: + x_shape = paddle.shape(x) + + tensor = _C_ops.empty( + x_shape, convert_np_dtype_to_dtype_(dtype), _current_expected_place(), ) - out.stop_gradient = True - return out + + tensor = tensor.to(device=device) + tensor.stop_gradient = not requires_grad + return tensor + else: helper = LayerHelper("empty_like", **locals()) check_variable_and_dtype( diff --git a/test/legacy_test/test_creation.py b/test/legacy_test/test_creation.py new file mode 100644 index 00000000000000..2ff08a70f7794b --- /dev/null +++ b/test/legacy_test/test_creation.py @@ -0,0 +1,171 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle + +# from .utils import dygraph_guard + + +class TestTensorCreation(unittest.TestCase): + def setUp(self): + self.devices = [paddle.CPUPlace()] + if paddle.device.is_compiled_with_cuda(): + self.devices.append(paddle.CUDAPlace(0)) + + self.requires_grads = [True, False] + self.dtypes = ["float32", paddle.float32, "int32", paddle.int32] + + def test_ones(self): + for device in self.devices: + for requires_grad in self.requires_grads: + for dtype in self.dtypes: + x = paddle.ones( + [2], + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + + def test_zeros(self): + for device in self.devices: + for requires_grad in self.requires_grads: + for dtype in self.dtypes: + x = paddle.zeros( + [2], + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + + def test_full(self): + for device in self.devices: + for requires_grad in self.requires_grads: + for dtype in self.dtypes: + x = paddle.full( + [2], + fill_value=3.14, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + + def test_empty(self): + for device in self.devices: + for requires_grad in self.requires_grads: + for dtype in self.dtypes: + x = paddle.empty( + [2], + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + + def test_eye(self): + for device in self.devices: + for requires_grad in self.requires_grads: + for dtype in self.dtypes: + x = paddle.eye( + 3, + 3, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + + def test_ones_like(self): + for device in self.devices: + for requires_grad in self.requires_grads: + for dtype in self.dtypes: + x = paddle.ones_like( + paddle.randn([2, 2]), + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + + def test_zeros_like(self): + for device in self.devices: + for requires_grad in self.requires_grads: + for dtype in self.dtypes: + x = paddle.zeros_like( + paddle.randn([2, 2]), + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + + def test_full_like(self): + for device in self.devices: + for requires_grad in self.requires_grads: + for dtype in self.dtypes: + x = paddle.full_like( + paddle.randn([2, 2]), + 3.14, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + + def test_empty_like(self): + for device in self.devices: + for requires_grad in self.requires_grads: + for dtype in self.dtypes: + x = paddle.empty_like( + paddle.randn([2, 2]), + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + + +if __name__ == '__main__': + unittest.main() From 63920d851bdf00939b14c69bb17f4ec8bce369f7 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Mon, 4 Aug 2025 20:46:54 +0800 Subject: [PATCH 03/31] add static test --- python/paddle/tensor/creation.py | 24 +-- test/legacy_test/test_creation.py | 319 +++++++++++++++++++++--------- 2 files changed, 236 insertions(+), 107 deletions(-) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 579aefc6891a10..29b763adf15302 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1090,7 +1090,8 @@ def full_like( tensor = _C_ops.full_like(x, fill_value, dtype, x.place) else: tensor = _C_ops.full_like(x, fill_value, dtype, core.Place()) - tensor = tensor.to(device=device) + if in_dynamic_mode(): + tensor = tensor.to(device=device) tensor.stop_gradient = not requires_grad return tensor else: @@ -1307,7 +1308,7 @@ def ones( dtype = paddle.get_default_dtype() tensor = fill_constant(value=1.0, shape=shape, dtype=dtype, name=name) - if device is not None: + if device is not None and in_dynamic_mode(): tensor = tensor.to(device=device) tensor.stop_gradient = not requires_grad @@ -1365,8 +1366,8 @@ def ones_like( device = x.place tensor = full_like(x=x, fill_value=1, dtype=dtype, name=name) - - tensor = tensor.to(device=device) + if in_dynamic_mode(): + tensor = tensor.to(device=device) tensor.stop_gradient = not requires_grad return tensor @@ -1430,7 +1431,7 @@ def zeros( dtype = paddle.get_default_dtype() tensor = fill_constant(value=0.0, shape=shape, dtype=dtype, name=name) - if device is not None: + if device is not None and in_dynamic_mode(): tensor = tensor.to(device=device) tensor.stop_gradient = not requires_grad @@ -1490,7 +1491,8 @@ def zeros_like( tensor = full_like(x=x, fill_value=0, dtype=dtype, name=name) - tensor = tensor.to(device=device) + if in_dynamic_mode(): + tensor = tensor.to(device=device) tensor.stop_gradient = not requires_grad return tensor @@ -1563,7 +1565,7 @@ def _check_attr(attr, message): tensor = _C_ops.eye( num_rows, num_columns, dtype, _current_expected_place() ) - if device is not None: + if device is not None and in_dynamic_mode(): tensor = tensor.to(device=device) tensor.stop_gradient = not requires_grad @@ -1681,7 +1683,7 @@ def full( tensor = fill_constant( shape=shape, dtype=dtype, value=fill_value, name=name ) - if device is not None: + if device is not None and in_dynamic_mode(): tensor = tensor.to(device=device) tensor.stop_gradient = not requires_grad @@ -2650,7 +2652,7 @@ def empty( tensor = _C_ops.empty( shape, convert_np_dtype_to_dtype_(dtype), _current_expected_place() ) - if device is not None: + if device is not None and in_dynamic_mode(): tensor = tensor.to(device=device) tensor.stop_gradient = not requires_grad @@ -2762,8 +2764,8 @@ def empty_like( convert_np_dtype_to_dtype_(dtype), _current_expected_place(), ) - - tensor = tensor.to(device=device) + if in_dynamic_mode(): + tensor = tensor.to(device=device) tensor.stop_gradient = not requires_grad return tensor diff --git a/test/legacy_test/test_creation.py b/test/legacy_test/test_creation.py index 2ff08a70f7794b..e9a0091acaf03a 100644 --- a/test/legacy_test/test_creation.py +++ b/test/legacy_test/test_creation.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,9 +14,9 @@ import unittest -import paddle +from utils import dygraph_guard -# from .utils import dygraph_guard +import paddle class TestTensorCreation(unittest.TestCase): @@ -32,139 +32,266 @@ def test_ones(self): for device in self.devices: for requires_grad in self.requires_grads: for dtype in self.dtypes: - x = paddle.ones( - [2], - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.place, device) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) + with dygraph_guard(): + x = paddle.ones( + [2], + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + st_f = paddle.jit.to_static( + paddle.ones, full_graph=True + ) + x = st_f( + [2], + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) def test_zeros(self): for device in self.devices: for requires_grad in self.requires_grads: for dtype in self.dtypes: - x = paddle.zeros( - [2], - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.place, device) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) + with dygraph_guard(): + x = paddle.zeros( + [2], + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + st_f = paddle.jit.to_static( + paddle.zeros, full_graph=True + ) + x = st_f( + [2], + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) def test_full(self): for device in self.devices: for requires_grad in self.requires_grads: for dtype in self.dtypes: - x = paddle.full( - [2], - fill_value=3.14, - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.place, device) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) + with dygraph_guard(): + x = paddle.full( + [2], + fill_value=3.14, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + st_f = paddle.jit.to_static( + paddle.full, full_graph=True + ) + x = st_f( + [2], + fill_value=3.14, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) def test_empty(self): for device in self.devices: for requires_grad in self.requires_grads: for dtype in self.dtypes: - x = paddle.empty( - [2], - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.place, device) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) + with dygraph_guard(): + x = paddle.empty( + [2], + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + st_f = paddle.jit.to_static( + paddle.empty, full_graph=True + ) + x = st_f( + [2], + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) def test_eye(self): for device in self.devices: for requires_grad in self.requires_grads: for dtype in self.dtypes: - x = paddle.eye( - 3, - 3, - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.place, device) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) + with dygraph_guard(): + x = paddle.eye( + 3, + 3, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + st_f = paddle.jit.to_static(paddle.eye, full_graph=True) + x = st_f( + 3, + 3, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) def test_ones_like(self): for device in self.devices: for requires_grad in self.requires_grads: for dtype in self.dtypes: - x = paddle.ones_like( - paddle.randn([2, 2]), - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.place, device) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) + with dygraph_guard(): + x = paddle.ones_like( + paddle.randn([2, 2]), + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + st_f = paddle.jit.to_static( + paddle.ones_like, full_graph=True + ) + x = st_f( + paddle.randn([2, 2]), + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) def test_zeros_like(self): for device in self.devices: for requires_grad in self.requires_grads: for dtype in self.dtypes: - x = paddle.zeros_like( - paddle.randn([2, 2]), - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.place, device) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) + with dygraph_guard(): + x = paddle.zeros_like( + paddle.randn([2, 2]), + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + st_f = paddle.jit.to_static( + paddle.zeros_like, full_graph=True + ) + x = st_f( + paddle.randn([2, 2]), + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) def test_full_like(self): for device in self.devices: for requires_grad in self.requires_grads: for dtype in self.dtypes: - x = paddle.full_like( - paddle.randn([2, 2]), - 3.14, - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.place, device) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) + with dygraph_guard(): + x = paddle.full_like( + paddle.randn([2, 2]), + 3.14, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + st_f = paddle.jit.to_static( + paddle.full_like, full_graph=True + ) + x = st_f( + paddle.randn([2, 2]), + 3.14, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) def test_empty_like(self): for device in self.devices: for requires_grad in self.requires_grads: for dtype in self.dtypes: - x = paddle.empty_like( - paddle.randn([2, 2]), - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.place, device) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) + with dygraph_guard(): + x = paddle.empty_like( + paddle.randn([2, 2]), + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + st_f = paddle.jit.to_static( + paddle.empty_like, full_graph=True + ) + x = st_f( + paddle.randn([2, 2]), + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) if __name__ == '__main__': From 78499c82660e4c0a1e6977fd1b6b42676a92a14a Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 5 Aug 2025 10:47:46 +0800 Subject: [PATCH 04/31] fix UT --- test/legacy_test/test_creation.py | 15 +++++---------- test/legacy_test/test_empty_like_op.py | 10 ++++++++++ 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/test/legacy_test/test_creation.py b/test/legacy_test/test_creation.py index e9a0091acaf03a..47983740ee843c 100644 --- a/test/legacy_test/test_creation.py +++ b/test/legacy_test/test_creation.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,6 +24,10 @@ def setUp(self): self.devices = [paddle.CPUPlace()] if paddle.device.is_compiled_with_cuda(): self.devices.append(paddle.CUDAPlace(0)) + if paddle.device.is_compiled_with_xpu(): + self.devices.append(paddle.device.XPUPlace(0)) + if paddle.device.is_compiled_with_ipu(): + self.devices.append(paddle.device.IPUPlace()) self.requires_grads = [True, False] self.dtypes = ["float32", paddle.float32, "int32", paddle.int32] @@ -52,7 +56,6 @@ def test_ones(self): requires_grad=requires_grad, device=device, ) - self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) @@ -81,7 +84,6 @@ def test_zeros(self): requires_grad=requires_grad, device=device, ) - self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) @@ -112,7 +114,6 @@ def test_full(self): requires_grad=requires_grad, device=device, ) - self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) @@ -141,7 +142,6 @@ def test_empty(self): requires_grad=requires_grad, device=device, ) - self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) @@ -170,7 +170,6 @@ def test_eye(self): requires_grad=requires_grad, device=device, ) - self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) @@ -199,7 +198,6 @@ def test_ones_like(self): requires_grad=requires_grad, device=device, ) - self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) @@ -228,7 +226,6 @@ def test_zeros_like(self): requires_grad=requires_grad, device=device, ) - self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) @@ -259,7 +256,6 @@ def test_full_like(self): requires_grad=requires_grad, device=device, ) - self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) @@ -288,7 +284,6 @@ def test_empty_like(self): requires_grad=requires_grad, device=device, ) - self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) diff --git a/test/legacy_test/test_empty_like_op.py b/test/legacy_test/test_empty_like_op.py index 51b68758ea28fe..c197caa24a7442 100644 --- a/test/legacy_test/test_empty_like_op.py +++ b/test/legacy_test/test_empty_like_op.py @@ -91,6 +91,7 @@ def init_config(self): self.dtype = self.x.dtype self.dst_shape = self.x.shape self.dst_dtype = self.dtype + self.x = paddle.to_tensor(self.x) class TestEmptyLikeAPI2(TestEmptyLikeAPI): @@ -99,6 +100,7 @@ def init_config(self): self.dtype = self.x.dtype self.dst_shape = self.x.shape self.dst_dtype = self.dtype + self.x = paddle.to_tensor(self.x) class TestEmptyLikeAPI3(TestEmptyLikeAPI): @@ -107,6 +109,7 @@ def init_config(self): self.dtype = self.x.dtype self.dst_shape = self.x.shape self.dst_dtype = self.dtype + self.x = paddle.to_tensor(self.x) class TestEmptyLikeAPI4(TestEmptyLikeAPI): @@ -115,6 +118,7 @@ def init_config(self): self.dtype = self.x.dtype self.dst_shape = self.x.shape self.dst_dtype = self.dtype + self.x = paddle.to_tensor(self.x) class TestEmptyLikeAPI5(TestEmptyLikeAPI): @@ -123,6 +127,7 @@ def init_config(self): self.dtype = self.x.dtype self.dst_shape = self.x.shape self.dst_dtype = self.dtype + self.x = paddle.to_tensor(self.x) class TestEmptyLikeAPI6(TestEmptyLikeAPI): @@ -131,6 +136,7 @@ def init_config(self): self.dtype = "float32" self.dst_shape = self.x.shape self.dst_dtype = self.dtype + self.x = paddle.to_tensor(self.x) class TestEmptyLikeAPI7(TestEmptyLikeAPI): @@ -139,6 +145,7 @@ def init_config(self): self.dtype = "float32" self.dst_shape = self.x.shape self.dst_dtype = self.dtype + self.x = paddle.to_tensor(self.x) class TestEmptyLikeAPI8(TestEmptyLikeAPI): @@ -147,6 +154,7 @@ def init_config(self): self.dtype = "float32" self.dst_shape = self.x.shape self.dst_dtype = self.dtype + self.x = paddle.to_tensor(self.x) class TestEmptyLikeAPI9(TestEmptyLikeAPI): @@ -155,6 +163,7 @@ def init_config(self): self.dtype = "float32" self.dst_shape = self.x.shape self.dst_dtype = self.dtype + self.x = paddle.to_tensor(self.x) class TestEmptyLikeAPI10(TestEmptyLikeAPI): @@ -163,6 +172,7 @@ def init_config(self): self.dtype = "bool" self.dst_shape = self.x.shape self.dst_dtype = self.dtype + self.x = paddle.to_tensor(self.x) class TestEmptyLikeAPI_Static(TestEmptyLikeAPICommon): From 22e216159107e6861ae1fda42e8e46f0c49a2c7d Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 5 Aug 2025 10:50:01 +0800 Subject: [PATCH 05/31] fix date --- test/legacy_test/test_creation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/legacy_test/test_creation.py b/test/legacy_test/test_creation.py index 47983740ee843c..fd24f1618a0c7f 100644 --- a/test/legacy_test/test_creation.py +++ b/test/legacy_test/test_creation.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 8ebcccec15f67a4620b1eb53de11bcb914f19b7d Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 5 Aug 2025 13:16:19 +0800 Subject: [PATCH 06/31] refine code --- python/paddle/tensor/creation.py | 76 ++++++++++++++------------------ 1 file changed, 34 insertions(+), 42 deletions(-) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 29b763adf15302..12716b2c80c34c 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1087,11 +1087,9 @@ def full_like( if in_dynamic_or_pir_mode(): if in_dynamic_mode(): - tensor = _C_ops.full_like(x, fill_value, dtype, x.place) + tensor = _C_ops.full_like(x, fill_value, dtype, device) else: tensor = _C_ops.full_like(x, fill_value, dtype, core.Place()) - if in_dynamic_mode(): - tensor = tensor.to(device=device) tensor.stop_gradient = not requires_grad return tensor else: @@ -1358,18 +1356,14 @@ def ones_like( [1 1 1] """ - if dtype is None: - dtype = x.dtype - if requires_grad is None: - requires_grad = not x.stop_gradient - if device is None: - device = x.place - - tensor = full_like(x=x, fill_value=1, dtype=dtype, name=name) - if in_dynamic_mode(): - tensor = tensor.to(device=device) - tensor.stop_gradient = not requires_grad - return tensor + return full_like( + x=x, + fill_value=1, + dtype=dtype, + name=name, + device=device, + requires_grad=requires_grad, + ) def zeros( @@ -1482,19 +1476,14 @@ def zeros_like( [0 0 0] """ - if dtype is None: - dtype = x.dtype - if requires_grad is None: - requires_grad = not x.stop_gradient - if device is None: - device = x.place - - tensor = full_like(x=x, fill_value=0, dtype=dtype, name=name) - - if in_dynamic_mode(): - tensor = tensor.to(device=device) - tensor.stop_gradient = not requires_grad - return tensor + return full_like( + x=x, + fill_value=0, + dtype=dtype, + name=name, + device=device, + requires_grad=requires_grad, + ) def eye( @@ -1563,11 +1552,15 @@ def _check_attr(attr, message): if in_dynamic_or_pir_mode(): tensor = _C_ops.eye( - num_rows, num_columns, dtype, _current_expected_place() + num_rows, + num_columns, + dtype, + ( + device + if (in_dynamic_mode() and device is not None) + else _current_expected_place() + ), ) - if device is not None and in_dynamic_mode(): - tensor = tensor.to(device=device) - tensor.stop_gradient = not requires_grad return tensor else: @@ -2650,11 +2643,14 @@ def empty( raise TypeError("Shape only supports Value, or list, or tuple.") tensor = _C_ops.empty( - shape, convert_np_dtype_to_dtype_(dtype), _current_expected_place() + shape, + convert_np_dtype_to_dtype_(dtype), + ( + device + if (in_dynamic_mode() and device is not None) + else _current_expected_place() + ), ) - if device is not None and in_dynamic_mode(): - tensor = tensor.to(device=device) - tensor.stop_gradient = not requires_grad return tensor else: @@ -2743,14 +2739,12 @@ def empty_like( [[1.8491974e+20 1.8037303e+28 1.7443726e+28] [4.9640171e+28 3.0186127e+32 5.6715899e-11]] """ - if dtype is None: dtype = x.dtype if requires_grad is None: requires_grad = not x.stop_gradient - if device is None: + if in_dynamic_mode() and device is None: device = x.place - dtype = convert_dtype(dtype) if in_dynamic_or_pir_mode(): @@ -2762,10 +2756,8 @@ def empty_like( tensor = _C_ops.empty( x_shape, convert_np_dtype_to_dtype_(dtype), - _current_expected_place(), + (device if in_dynamic_mode() else _current_expected_place()), ) - if in_dynamic_mode(): - tensor = tensor.to(device=device) tensor.stop_gradient = not requires_grad return tensor From 249d45e6e394e69fc9a967f2a2ba7b09f86c475e Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 5 Aug 2025 14:39:50 +0800 Subject: [PATCH 07/31] fix --- python/paddle/tensor/creation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 12716b2c80c34c..c207dff8e02134 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -2743,7 +2743,7 @@ def empty_like( dtype = x.dtype if requires_grad is None: requires_grad = not x.stop_gradient - if in_dynamic_mode() and device is None: + if device is None: device = x.place dtype = convert_dtype(dtype) From 3a758f573558e22f945659d256dc593c3f1d6ae2 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 5 Aug 2025 19:14:28 +0800 Subject: [PATCH 08/31] fix UT --- test/legacy_test/test_empty_like_op.py | 132 ++++++++++++------------- 1 file changed, 65 insertions(+), 67 deletions(-) diff --git a/test/legacy_test/test_empty_like_op.py b/test/legacy_test/test_empty_like_op.py index c197caa24a7442..fcf5335d2899c7 100644 --- a/test/legacy_test/test_empty_like_op.py +++ b/test/legacy_test/test_empty_like_op.py @@ -16,6 +16,7 @@ import numpy as np from op_test import convert_uint16_to_float +from utils import dygraph_guard, static_guard import paddle from paddle.base import core @@ -81,10 +82,9 @@ def setUp(self): self.init_config() def test_dygraph_api_out(self): - paddle.disable_static() - out = paddle.empty_like(self.x, self.dtype) - self.__check_out__(out.numpy()) - paddle.enable_static() + with dygraph_guard(): + out = paddle.empty_like(self.x, self.dtype) + self.__check_out__(out.numpy()) def init_config(self): self.x = np.random.random((200, 3)).astype("float32") @@ -180,31 +180,29 @@ def setUp(self): self.init_config() def test_static_graph(self): - paddle.enable_static() - train_program = paddle.static.Program() - startup_program = paddle.static.Program() - - with paddle.static.program_guard(train_program, startup_program): - x = np.random.random(self.x_shape).astype(self.dtype) - data_x = paddle.static.data( - 'x', shape=self.data_x_shape, dtype=self.dtype - ) + with static_guard(): + train_program = paddle.static.Program() + startup_program = paddle.static.Program() - out = paddle.empty_like(data_x) + with paddle.static.program_guard(train_program, startup_program): + x = np.random.random(self.x_shape).astype(self.dtype) + data_x = paddle.static.data( + 'x', shape=self.data_x_shape, dtype=self.dtype + ) - place = ( - paddle.CUDAPlace(0) - if core.is_compiled_with_cuda() - else paddle.CPUPlace() - ) - exe = paddle.static.Executor(place) - res = exe.run(train_program, feed={'x': x}, fetch_list=[out]) + out = paddle.empty_like(data_x) - self.dst_dtype = self.dtype - self.dst_shape = x.shape - self.__check_out__(res[0]) + place = ( + paddle.CUDAPlace(0) + if core.is_compiled_with_cuda() + else paddle.CPUPlace() + ) + exe = paddle.static.Executor(place) + res = exe.run(train_program, feed={'x': x}, fetch_list=[out]) - paddle.disable_static() + self.dst_dtype = self.dtype + self.dst_shape = x.shape + self.__check_out__(res[0]) def init_config(self): self.x_shape = (200, 3) @@ -229,27 +227,27 @@ def init_config(self): self.dtype = 'float16' def test_static_graph(self): - paddle.enable_static() - if paddle.base.core.is_compiled_with_cuda(): - place = paddle.CUDAPlace(0) - with paddle.static.program_guard( - paddle.static.Program(), paddle.static.Program() - ): - x = np.random.random([200, 3]).astype(self.dtype) - data_x = paddle.static.data( - name="x", shape=[200, 3], dtype=self.dtype - ) - out = paddle.empty_like(data_x) - exe = paddle.static.Executor(place) - res = exe.run( - paddle.static.default_main_program(), - feed={'x': x}, - fetch_list=[out], - ) - - self.dst_dtype = self.dtype - self.dst_shape = x.shape - self.__check_out__(res[0]) + with static_guard(): + if paddle.base.core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = np.random.random([200, 3]).astype(self.dtype) + data_x = paddle.static.data( + name="x", shape=[200, 3], dtype=self.dtype + ) + out = paddle.empty_like(data_x) + exe = paddle.static.Executor(place) + res = exe.run( + paddle.static.default_main_program(), + feed={'x': x}, + fetch_list=[out], + ) + + self.dst_dtype = self.dtype + self.dst_shape = x.shape + self.__check_out__(res[0]) class TestEmptyLikeAPI_StaticForBF16Op(TestEmptyLikeAPICommon): @@ -262,27 +260,27 @@ def init_config(self): self.dtype = 'uint16' def test_static_graph(self): - paddle.enable_static() - if paddle.base.core.is_compiled_with_cuda(): - place = paddle.CUDAPlace(0) - with paddle.static.program_guard( - paddle.static.Program(), paddle.static.Program() - ): - x = np.random.random([200, 3]).astype(np.uint16) - data_x = paddle.static.data( - name="x", shape=[200, 3], dtype=np.uint16 - ) - out = paddle.empty_like(data_x) - exe = paddle.static.Executor(place) - res = exe.run( - paddle.static.default_main_program(), - feed={'x': x}, - fetch_list=[out], - ) - - self.dst_dtype = self.dtype - self.dst_shape = x.shape - self.__check_out__(res[0]) + with static_guard(): + if paddle.base.core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = np.random.random([200, 3]).astype(np.uint16) + data_x = paddle.static.data( + name="x", shape=[200, 3], dtype=np.uint16 + ) + out = paddle.empty_like(data_x) + exe = paddle.static.Executor(place) + res = exe.run( + paddle.static.default_main_program(), + feed={'x': x}, + fetch_list=[out], + ) + + self.dst_dtype = self.dtype + self.dst_shape = x.shape + self.__check_out__(res[0]) if __name__ == '__main__': From 2c9f8db2e9b21cb105b1c9255633dd77a07d12ce Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 5 Aug 2025 20:29:05 +0800 Subject: [PATCH 09/31] fix --- python/paddle/tensor/creation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index c207dff8e02134..41d74079173842 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -2756,7 +2756,7 @@ def empty_like( tensor = _C_ops.empty( x_shape, convert_np_dtype_to_dtype_(dtype), - (device if in_dynamic_mode() else _current_expected_place()), + (device or _current_expected_place()), ) tensor.stop_gradient = not requires_grad return tensor From 4212a2e13152826c7e6aa40b0095f19f40301fd6 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 5 Aug 2025 22:44:42 +0800 Subject: [PATCH 10/31] fix BatchNormDoubleGradKernel --- .../phi/kernels/cpu/batch_norm_grad_kernel.cc | 28 +++++++++++++------ 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc index ecc3cc4df61b13..8b17add7077d02 100644 --- a/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc @@ -408,8 +408,6 @@ void BatchNormDoubleGradKernel( auto* dX = x_grad; auto* dScale = scale_grad; auto* ddY = y_grad_grad; - dev_ctx.template Alloc(dX); - dev_ctx.template Alloc(ddY); const auto& x_dims = X->dims(); const int C = static_cast( @@ -440,8 +438,14 @@ void BatchNormDoubleGradKernel( DenseTensor transformed_dy(dY->type()); DenseTensor transformed_ddx(ddX->type()); - DenseTensor transformed_dx(dX->type()); - DenseTensor transformed_ddy(ddY->type()); + DenseTensor transformed_dx; + if (dX) { + transformed_dx = DenseTensor(dX->type()); + } + DenseTensor transformed_ddy; + if (ddY) { + transformed_ddy = DenseTensor(ddY->type()); + } if (data_layout == DataLayout::kNCHW && x_dims.size() > 2) { VLOG(3) << "Transform batchnorm output from NCHW to NHWC"; // Input Tensor @@ -452,15 +456,23 @@ void BatchNormDoubleGradKernel( ResizeToChannelLast(dev_ctx, ddX, &transformed_ddx); TransToChannelLast(dev_ctx, ddX, &transformed_ddx); // Output Tensor - ResizeToChannelLast(dev_ctx, dX, &transformed_dx); - ResizeToChannelLast(dev_ctx, ddY, &transformed_ddy); + if (dX) { + ResizeToChannelLast(dev_ctx, dX, &transformed_dx); + } + if (ddY) { + ResizeToChannelLast(dev_ctx, ddY, &transformed_ddy); + } } else { transformed_x.ShareDataWith(*X); transformed_dy.ShareDataWith(*dY); transformed_ddx.ShareDataWith(*ddX); - transformed_dx.ShareDataWith(*dX); - transformed_ddy.ShareDataWith(*ddY); + if (dX) { + transformed_dx.ShareDataWith(*dX); + } + if (ddY) { + transformed_ddy.ShareDataWith(*ddY); + } } ConstEigenArrayMap x_arr(transformed_x.data(), C, sample_size); From dbc7caeafb5f2655202918dd306bb84386e9b1a6 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 6 Aug 2025 19:09:32 +0800 Subject: [PATCH 11/31] restore code --- .../phi/kernels/cpu/batch_norm_grad_kernel.cc | 28 ++++++------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc index 8b17add7077d02..ecc3cc4df61b13 100644 --- a/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc @@ -408,6 +408,8 @@ void BatchNormDoubleGradKernel( auto* dX = x_grad; auto* dScale = scale_grad; auto* ddY = y_grad_grad; + dev_ctx.template Alloc(dX); + dev_ctx.template Alloc(ddY); const auto& x_dims = X->dims(); const int C = static_cast( @@ -438,14 +440,8 @@ void BatchNormDoubleGradKernel( DenseTensor transformed_dy(dY->type()); DenseTensor transformed_ddx(ddX->type()); - DenseTensor transformed_dx; - if (dX) { - transformed_dx = DenseTensor(dX->type()); - } - DenseTensor transformed_ddy; - if (ddY) { - transformed_ddy = DenseTensor(ddY->type()); - } + DenseTensor transformed_dx(dX->type()); + DenseTensor transformed_ddy(ddY->type()); if (data_layout == DataLayout::kNCHW && x_dims.size() > 2) { VLOG(3) << "Transform batchnorm output from NCHW to NHWC"; // Input Tensor @@ -456,23 +452,15 @@ void BatchNormDoubleGradKernel( ResizeToChannelLast(dev_ctx, ddX, &transformed_ddx); TransToChannelLast(dev_ctx, ddX, &transformed_ddx); // Output Tensor - if (dX) { - ResizeToChannelLast(dev_ctx, dX, &transformed_dx); - } - if (ddY) { - ResizeToChannelLast(dev_ctx, ddY, &transformed_ddy); - } + ResizeToChannelLast(dev_ctx, dX, &transformed_dx); + ResizeToChannelLast(dev_ctx, ddY, &transformed_ddy); } else { transformed_x.ShareDataWith(*X); transformed_dy.ShareDataWith(*dY); transformed_ddx.ShareDataWith(*ddX); - if (dX) { - transformed_dx.ShareDataWith(*dX); - } - if (ddY) { - transformed_ddy.ShareDataWith(*ddY); - } + transformed_dx.ShareDataWith(*dX); + transformed_ddy.ShareDataWith(*ddY); } ConstEigenArrayMap x_arr(transformed_x.data(), C, sample_size); From db536cdaccc104b162b4c6e823594b0a679eba01 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 6 Aug 2025 20:10:12 +0800 Subject: [PATCH 12/31] fix --- python/paddle/tensor/creation.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 41d74079173842..030cbec8e271a9 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1080,8 +1080,6 @@ def full_like( else: if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)): dtype = convert_np_dtype_to_dtype_(dtype) - if requires_grad is None: - requires_grad = not x.stop_gradient if device is None: device = x.place @@ -1090,7 +1088,8 @@ def full_like( tensor = _C_ops.full_like(x, fill_value, dtype, device) else: tensor = _C_ops.full_like(x, fill_value, dtype, core.Place()) - tensor.stop_gradient = not requires_grad + if requires_grad: + tensor.stop_gradient = False return tensor else: helper = LayerHelper("full_like", **locals()) @@ -1308,8 +1307,8 @@ def ones( if device is not None and in_dynamic_mode(): tensor = tensor.to(device=device) - - tensor.stop_gradient = not requires_grad + if requires_grad: + tensor.stop_gradient = False return tensor @@ -1427,8 +1426,8 @@ def zeros( if device is not None and in_dynamic_mode(): tensor = tensor.to(device=device) - - tensor.stop_gradient = not requires_grad + if requires_grad: + tensor.stop_gradient = True return tensor @@ -1561,7 +1560,8 @@ def _check_attr(attr, message): else _current_expected_place() ), ) - tensor.stop_gradient = not requires_grad + if requires_grad: + tensor.stop_gradient = False return tensor else: helper = LayerHelper("eye", **locals()) @@ -1678,8 +1678,8 @@ def full( ) if device is not None and in_dynamic_mode(): tensor = tensor.to(device=device) - - tensor.stop_gradient = not requires_grad + if requires_grad: + tensor.stop_gradient = False return tensor @@ -2651,7 +2651,8 @@ def empty( else _current_expected_place() ), ) - tensor.stop_gradient = not requires_grad + if requires_grad: + tensor.stop_gradient = False return tensor else: helper = LayerHelper("empty", **locals()) @@ -2741,8 +2742,6 @@ def empty_like( """ if dtype is None: dtype = x.dtype - if requires_grad is None: - requires_grad = not x.stop_gradient if device is None: device = x.place dtype = convert_dtype(dtype) @@ -2758,7 +2757,8 @@ def empty_like( convert_np_dtype_to_dtype_(dtype), (device or _current_expected_place()), ) - tensor.stop_gradient = not requires_grad + if requires_grad: + tensor.stop_gradient = False return tensor else: From 6d20a1777cbffcdf6d2d409eb2ce6945a671de62 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 6 Aug 2025 20:10:45 +0800 Subject: [PATCH 13/31] fix --- python/paddle/tensor/creation.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 030cbec8e271a9..ba557d2ab9365c 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1088,7 +1088,7 @@ def full_like( tensor = _C_ops.full_like(x, fill_value, dtype, device) else: tensor = _C_ops.full_like(x, fill_value, dtype, core.Place()) - if requires_grad: + if requires_grad is True: tensor.stop_gradient = False return tensor else: @@ -1307,7 +1307,7 @@ def ones( if device is not None and in_dynamic_mode(): tensor = tensor.to(device=device) - if requires_grad: + if requires_grad is True: tensor.stop_gradient = False return tensor @@ -1426,7 +1426,7 @@ def zeros( if device is not None and in_dynamic_mode(): tensor = tensor.to(device=device) - if requires_grad: + if requires_grad is True: tensor.stop_gradient = True return tensor @@ -1560,7 +1560,7 @@ def _check_attr(attr, message): else _current_expected_place() ), ) - if requires_grad: + if requires_grad is True: tensor.stop_gradient = False return tensor else: @@ -1678,7 +1678,7 @@ def full( ) if device is not None and in_dynamic_mode(): tensor = tensor.to(device=device) - if requires_grad: + if requires_grad is True: tensor.stop_gradient = False return tensor @@ -2651,7 +2651,7 @@ def empty( else _current_expected_place() ), ) - if requires_grad: + if requires_grad is True: tensor.stop_gradient = False return tensor else: @@ -2757,7 +2757,7 @@ def empty_like( convert_np_dtype_to_dtype_(dtype), (device or _current_expected_place()), ) - if requires_grad: + if requires_grad is True: tensor.stop_gradient = False return tensor From 2394bc958852a1f44bff4f3db93b5dfb8b2c4237 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 6 Aug 2025 21:44:47 +0800 Subject: [PATCH 14/31] fix --- python/paddle/tensor/creation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index ba557d2ab9365c..c73a212099f8ce 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1427,7 +1427,7 @@ def zeros( if device is not None and in_dynamic_mode(): tensor = tensor.to(device=device) if requires_grad is True: - tensor.stop_gradient = True + tensor.stop_gradient = False return tensor From b75d2b2ff628e23ed3aa114d7f501d71d1b28e45 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Thu, 7 Aug 2025 11:11:57 +0800 Subject: [PATCH 15/31] fix for review --- python/paddle/tensor/creation.py | 27 +- test/legacy_test/test_creation.py | 461 +++++++++++++++--------------- 2 files changed, 233 insertions(+), 255 deletions(-) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index c73a212099f8ce..c25f9384f057a8 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1088,8 +1088,7 @@ def full_like( tensor = _C_ops.full_like(x, fill_value, dtype, device) else: tensor = _C_ops.full_like(x, fill_value, dtype, core.Place()) - if requires_grad is True: - tensor.stop_gradient = False + tensor.stop_gradient = not requires_grad return tensor else: helper = LayerHelper("full_like", **locals()) @@ -1305,10 +1304,9 @@ def ones( dtype = paddle.get_default_dtype() tensor = fill_constant(value=1.0, shape=shape, dtype=dtype, name=name) - if device is not None and in_dynamic_mode(): + if device is not None: tensor = tensor.to(device=device) - if requires_grad is True: - tensor.stop_gradient = False + tensor.stop_gradient = not requires_grad return tensor @@ -1424,10 +1422,9 @@ def zeros( dtype = paddle.get_default_dtype() tensor = fill_constant(value=0.0, shape=shape, dtype=dtype, name=name) - if device is not None and in_dynamic_mode(): + if device is not None: tensor = tensor.to(device=device) - if requires_grad is True: - tensor.stop_gradient = False + tensor.stop_gradient = not requires_grad return tensor @@ -1560,8 +1557,7 @@ def _check_attr(attr, message): else _current_expected_place() ), ) - if requires_grad is True: - tensor.stop_gradient = False + tensor.stop_gradient = not requires_grad return tensor else: helper = LayerHelper("eye", **locals()) @@ -1676,10 +1672,9 @@ def full( tensor = fill_constant( shape=shape, dtype=dtype, value=fill_value, name=name ) - if device is not None and in_dynamic_mode(): + if device is not None: tensor = tensor.to(device=device) - if requires_grad is True: - tensor.stop_gradient = False + tensor.stop_gradient = not requires_grad return tensor @@ -2651,8 +2646,7 @@ def empty( else _current_expected_place() ), ) - if requires_grad is True: - tensor.stop_gradient = False + tensor.stop_gradient = not requires_grad return tensor else: helper = LayerHelper("empty", **locals()) @@ -2757,8 +2751,7 @@ def empty_like( convert_np_dtype_to_dtype_(dtype), (device or _current_expected_place()), ) - if requires_grad is True: - tensor.stop_gradient = False + tensor.stop_gradient = not requires_grad return tensor else: diff --git a/test/legacy_test/test_creation.py b/test/legacy_test/test_creation.py index fd24f1618a0c7f..4da54f99ea8612 100644 --- a/test/legacy_test/test_creation.py +++ b/test/legacy_test/test_creation.py @@ -13,6 +13,7 @@ # limitations under the License. import unittest +from itertools import product from utils import dygraph_guard @@ -33,260 +34,244 @@ def setUp(self): self.dtypes = ["float32", paddle.float32, "int32", paddle.int32] def test_ones(self): - for device in self.devices: - for requires_grad in self.requires_grads: - for dtype in self.dtypes: - with dygraph_guard(): - x = paddle.ones( - [2], - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.place, device) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) - st_f = paddle.jit.to_static( - paddle.ones, full_graph=True - ) - x = st_f( - [2], - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) + for device, requires_grad, dtype in product( + self.devices, self.requires_grads, self.dtypes + ): + with dygraph_guard(): + x = paddle.ones( + [2], + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + st_f = paddle.jit.to_static(paddle.ones, full_graph=True) + x = st_f( + [2], + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) def test_zeros(self): - for device in self.devices: - for requires_grad in self.requires_grads: - for dtype in self.dtypes: - with dygraph_guard(): - x = paddle.zeros( - [2], - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.place, device) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) - st_f = paddle.jit.to_static( - paddle.zeros, full_graph=True - ) - x = st_f( - [2], - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) + for device, requires_grad, dtype in product( + self.devices, self.requires_grads, self.dtypes + ): + with dygraph_guard(): + x = paddle.zeros( + [2], + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + st_f = paddle.jit.to_static(paddle.zeros, full_graph=True) + x = st_f( + [2], + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) def test_full(self): - for device in self.devices: - for requires_grad in self.requires_grads: - for dtype in self.dtypes: - with dygraph_guard(): - x = paddle.full( - [2], - fill_value=3.14, - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.place, device) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) - st_f = paddle.jit.to_static( - paddle.full, full_graph=True - ) - x = st_f( - [2], - fill_value=3.14, - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) + for device, requires_grad, dtype in product( + self.devices, self.requires_grads, self.dtypes + ): + with dygraph_guard(): + x = paddle.full( + [2], + fill_value=3.14, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + st_f = paddle.jit.to_static(paddle.full, full_graph=True) + x = st_f( + [2], + fill_value=3.14, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) def test_empty(self): - for device in self.devices: - for requires_grad in self.requires_grads: - for dtype in self.dtypes: - with dygraph_guard(): - x = paddle.empty( - [2], - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.place, device) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) - st_f = paddle.jit.to_static( - paddle.empty, full_graph=True - ) - x = st_f( - [2], - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) + for device, requires_grad, dtype in product( + self.devices, self.requires_grads, self.dtypes + ): + with dygraph_guard(): + x = paddle.empty( + [2], + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + st_f = paddle.jit.to_static(paddle.empty, full_graph=True) + x = st_f( + [2], + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) def test_eye(self): - for device in self.devices: - for requires_grad in self.requires_grads: - for dtype in self.dtypes: - with dygraph_guard(): - x = paddle.eye( - 3, - 3, - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.place, device) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) - st_f = paddle.jit.to_static(paddle.eye, full_graph=True) - x = st_f( - 3, - 3, - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) + for device, requires_grad, dtype in product( + self.devices, self.requires_grads, self.dtypes + ): + with dygraph_guard(): + x = paddle.eye( + 3, + 3, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + st_f = paddle.jit.to_static(paddle.eye, full_graph=True) + x = st_f( + 3, + 3, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) def test_ones_like(self): - for device in self.devices: - for requires_grad in self.requires_grads: - for dtype in self.dtypes: - with dygraph_guard(): - x = paddle.ones_like( - paddle.randn([2, 2]), - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.place, device) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) - st_f = paddle.jit.to_static( - paddle.ones_like, full_graph=True - ) - x = st_f( - paddle.randn([2, 2]), - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) + for device, requires_grad, dtype in product( + self.devices, self.requires_grads, self.dtypes + ): + with dygraph_guard(): + x = paddle.ones_like( + paddle.randn([2, 2]), + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + st_f = paddle.jit.to_static(paddle.ones_like, full_graph=True) + x = st_f( + paddle.randn([2, 2]), + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) def test_zeros_like(self): - for device in self.devices: - for requires_grad in self.requires_grads: - for dtype in self.dtypes: - with dygraph_guard(): - x = paddle.zeros_like( - paddle.randn([2, 2]), - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.place, device) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) - st_f = paddle.jit.to_static( - paddle.zeros_like, full_graph=True - ) - x = st_f( - paddle.randn([2, 2]), - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) + for device, requires_grad, dtype in product( + self.devices, self.requires_grads, self.dtypes + ): + with dygraph_guard(): + x = paddle.zeros_like( + paddle.randn([2, 2]), + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + st_f = paddle.jit.to_static(paddle.zeros_like, full_graph=True) + x = st_f( + paddle.randn([2, 2]), + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) def test_full_like(self): - for device in self.devices: - for requires_grad in self.requires_grads: - for dtype in self.dtypes: - with dygraph_guard(): - x = paddle.full_like( - paddle.randn([2, 2]), - 3.14, - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.place, device) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) - st_f = paddle.jit.to_static( - paddle.full_like, full_graph=True - ) - x = st_f( - paddle.randn([2, 2]), - 3.14, - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) + for device, requires_grad, dtype in product( + self.devices, self.requires_grads, self.dtypes + ): + with dygraph_guard(): + x = paddle.full_like( + paddle.randn([2, 2]), + 3.14, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + st_f = paddle.jit.to_static(paddle.full_like, full_graph=True) + x = st_f( + paddle.randn([2, 2]), + 3.14, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) def test_empty_like(self): - for device in self.devices: - for requires_grad in self.requires_grads: - for dtype in self.dtypes: - with dygraph_guard(): - x = paddle.empty_like( - paddle.randn([2, 2]), - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.place, device) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) - st_f = paddle.jit.to_static( - paddle.empty_like, full_graph=True - ) - x = st_f( - paddle.randn([2, 2]), - dtype=dtype, - requires_grad=requires_grad, - device=device, - ) - self.assertEqual(x.stop_gradient, not requires_grad) - if isinstance(dtype, paddle.dtype): - self.assertEqual(x.dtype, dtype) + for device, requires_grad, dtype in product( + self.devices, self.requires_grads, self.dtypes + ): + with dygraph_guard(): + x = paddle.empty_like( + paddle.randn([2, 2]), + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + st_f = paddle.jit.to_static(paddle.empty_like, full_graph=True) + x = st_f( + paddle.randn([2, 2]), + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) if __name__ == '__main__': From 2dc4029596c82ef65cac5d5a3add4a4fcc4d76de Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Thu, 7 Aug 2025 13:08:21 +0800 Subject: [PATCH 16/31] restore requires_grad setting --- python/paddle/tensor/creation.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index c25f9384f057a8..915ef97999b902 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1088,7 +1088,8 @@ def full_like( tensor = _C_ops.full_like(x, fill_value, dtype, device) else: tensor = _C_ops.full_like(x, fill_value, dtype, core.Place()) - tensor.stop_gradient = not requires_grad + if requires_grad is True: + tensor.stop_gradient = False return tensor else: helper = LayerHelper("full_like", **locals()) @@ -1306,7 +1307,8 @@ def ones( if device is not None: tensor = tensor.to(device=device) - tensor.stop_gradient = not requires_grad + if requires_grad is True: + tensor.stop_gradient = False return tensor @@ -1424,7 +1426,8 @@ def zeros( if device is not None: tensor = tensor.to(device=device) - tensor.stop_gradient = not requires_grad + if requires_grad is True: + tensor.stop_gradient = False return tensor @@ -1557,7 +1560,8 @@ def _check_attr(attr, message): else _current_expected_place() ), ) - tensor.stop_gradient = not requires_grad + if requires_grad is True: + tensor.stop_gradient = False return tensor else: helper = LayerHelper("eye", **locals()) @@ -1674,7 +1678,8 @@ def full( ) if device is not None: tensor = tensor.to(device=device) - tensor.stop_gradient = not requires_grad + if requires_grad is True: + tensor.stop_gradient = False return tensor @@ -2646,7 +2651,8 @@ def empty( else _current_expected_place() ), ) - tensor.stop_gradient = not requires_grad + if requires_grad is True: + tensor.stop_gradient = False return tensor else: helper = LayerHelper("empty", **locals()) @@ -2751,7 +2757,8 @@ def empty_like( convert_np_dtype_to_dtype_(dtype), (device or _current_expected_place()), ) - tensor.stop_gradient = not requires_grad + if requires_grad is True: + tensor.stop_gradient = False return tensor else: From f534c70878189bf5ec53affa41dfe12032cd070e Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Thu, 7 Aug 2025 17:15:48 +0800 Subject: [PATCH 17/31] update 4 Tensor.new_xxx methods --- python/paddle/base/dygraph/math_op_patch.py | 83 ++++++++- python/paddle/pir/math_op_patch.py | 182 ++++++++++++++++++++ test/legacy_test/test_math_op_patch_pir.py | 78 +++++++++ 3 files changed, 342 insertions(+), 1 deletion(-) diff --git a/python/paddle/base/dygraph/math_op_patch.py b/python/paddle/base/dygraph/math_op_patch.py index bea78c7a528d9b..5930bd6ad311dc 100644 --- a/python/paddle/base/dygraph/math_op_patch.py +++ b/python/paddle/base/dygraph/math_op_patch.py @@ -25,7 +25,7 @@ if TYPE_CHECKING: from paddle import Tensor - from paddle._typing import DTypeLike + from paddle._typing import DTypeLike, PlaceLike _supported_int_dtype_ = [ core.VarDesc.VarType.UINT8, @@ -215,6 +215,83 @@ def _mT_(var: Tensor) -> Tensor: out = _C_ops.transpose(var, perm) return out + def _new_full_( + var: Tensor, + fill_value: bool | float | paddle.Tensor, + *, + dtype: DTypeLike | None = None, + device: PlaceLike | None = None, + requires_grad: bool = False, + ) -> Tensor: + if dtype is None: + dtype = var.dtype + if device is None: + device = var.device + + return paddle.full( + var.shape, + fill_value=fill_value, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + def _new_empty_( + var: Tensor, + *, + dtype: DTypeLike | None = None, + device: PlaceLike | None = None, + requires_grad: bool = False, + ) -> Tensor: + if dtype is None: + dtype = var.dtype + if device is None: + device = var.device + + return paddle.empty( + var.shape, dtype=dtype, device=device, requires_grad=requires_grad + ) + + def _new_ones_( + var: Tensor, + *, + dtype: DTypeLike | None = None, + device: PlaceLike | None = None, + requires_grad: bool = False, + ) -> Tensor: + if dtype is None: + dtype = var.dtype + if device is None: + device = var.device + + return paddle.full( + var.shape, + fill_value=1, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + def _new_zeros_( + var: Tensor, + *, + dtype: DTypeLike | None = None, + device: PlaceLike | None = None, + requires_grad: bool = False, + ) -> Tensor: + if dtype is None: + dtype = var.dtype + if device is None: + device = var.device + + return paddle.full( + var.shape, + fill_value=0, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + eager_methods = [ ('__neg__', _neg_), ('__abs__', _abs_), @@ -231,6 +308,10 @@ def _mT_(var: Tensor) -> Tensor: ('size', _size_), ('T', _T_), ('mT', _mT_), + ('new_full', _new_full_), + ('new_empty', _new_empty_), + ('new_ones', _new_ones_), + ('new_zeros', _new_zeros_), # for logical compare ('__array_ufunc__', None), ] diff --git a/python/paddle/pir/math_op_patch.py b/python/paddle/pir/math_op_patch.py index f9f8bfcf733616..401ebbffc66fc9 100644 --- a/python/paddle/pir/math_op_patch.py +++ b/python/paddle/pir/math_op_patch.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import inspect import textwrap import warnings from functools import reduce +from typing import TYPE_CHECKING import numpy as np @@ -26,6 +28,10 @@ from . import Value +if TYPE_CHECKING: + from paddle._typing import DTypeLike, PlaceLike + + _already_patch_value = False _supported_int_dtype_ = [ @@ -566,6 +572,178 @@ def _mT_(self): return _C_ops.transpose(self, perm) + def _new_full_( + self, + fill_value: bool | float | paddle.Tensor, + *, + dtype: DTypeLike | None = None, + device: PlaceLike | None = None, + requires_grad: bool = False, + ): + """ + + Returns a Tensor of size size filled with fill_value. + By default, the returned Tensor has the same dtype and place as this tensor. + + Examples: + .. code-block:: python + + >>> import paddle + >>> paddle.enable_static() + + >>> x = paddle.ones(shape=[2, 3, 5]) + >>> x_new = x.new_full(3.14, dtype="float64", device="cpu") + + >>> exe = paddle.static.Executor() + >>> x_new_np = exe.run(paddle.static.default_main_program(), fetch_list=[x_new])[0] + >>> print(x_new_np.shape) + (2, 5, 3) + >>> print(str(x_new_np.dtype)) + 'paddle.float64' + >>> print(x_new_np.place) + Place(cpu) + """ + if dtype is None: + dtype = self.dtype + if device is None: + device = self.device + + return paddle.full( + self.shape, + fill_value=fill_value, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + def _new_empty_( + self, + fill_value: bool | float | paddle.Tensor, + *, + dtype: DTypeLike | None = None, + device: PlaceLike | None = None, + requires_grad: bool = False, + ): + """ + + Returns a Tensor of size size filled with fill_value. + By default, the returned Tensor has the same dtype and place as this tensor. + + Examples: + .. code-block:: python + + >>> import paddle + >>> paddle.enable_static() + + >>> x = paddle.ones(shape=[2, 3, 5]) + >>> x_new = x.new_empty(dtype="float64", device="cpu") + + >>> exe = paddle.static.Executor() + >>> x_new_np = exe.run(paddle.static.default_main_program(), fetch_list=[x_new])[0] + >>> print(x_new_np.shape) + (2, 5, 3) + >>> print(str(x_new_np.dtype)) + 'paddle.float64' + >>> print(x_new_np.place) + Place(cpu) + """ + if dtype is None: + dtype = self.dtype + if device is None: + device = self.device + + return paddle.empty( + self.shape, dtype=dtype, device=device, requires_grad=requires_grad + ) + + def _new_ones_( + self, + fill_value: bool | float | paddle.Tensor, + *, + dtype: DTypeLike | None = None, + device: PlaceLike | None = None, + requires_grad: bool = False, + ): + """ + + Returns a Tensor of size size filled with fill_value. + By default, the returned Tensor has the same dtype and place as this tensor. + + Examples: + .. code-block:: python + + >>> import paddle + >>> paddle.enable_static() + + >>> x = paddle.ones(shape=[2, 3, 5]) + >>> x_new = x.new_ones(3.14, dtype="float64", device="cpu") + + >>> exe = paddle.static.Executor() + >>> x_new_np = exe.run(paddle.static.default_main_program(), fetch_list=[x_new])[0] + >>> print(x_new_np.shape) + (2, 5, 3) + >>> print(str(x_new_np.dtype)) + 'paddle.float64' + >>> print(x_new_np.place) + Place(cpu) + """ + if dtype is None: + dtype = self.dtype + if device is None: + device = self.device + + return paddle.full( + self.shape, + fill_value=1, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + def _new_zeros_( + self, + fill_value: bool | float | paddle.Tensor, + *, + dtype: DTypeLike | None = None, + device: PlaceLike | None = None, + requires_grad: bool = False, + ): + """ + + Returns a Tensor of size size filled with fill_value. + By default, the returned Tensor has the same dtype and place as this tensor. + + Examples: + .. code-block:: python + + >>> import paddle + >>> paddle.enable_static() + + >>> x = paddle.ones(shape=[2, 3, 5]) + >>> x_new = x.new_zeros(3.14, dtype="float64", device="cpu") + + >>> exe = paddle.static.Executor() + >>> x_new_np = exe.run(paddle.static.default_main_program(), fetch_list=[x_new])[0] + >>> print(x_new_np.shape) + (2, 5, 3) + >>> print(str(x_new_np.dtype)) + 'paddle.float64' + >>> print(x_new_np.place) + Place(cpu) + """ + if dtype is None: + dtype = self.dtype + if device is None: + device = self.device + + return paddle.full( + self.shape, + fill_value=0, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + def _int_(self): error_msg = """\ int(Tensor) is not supported in static graph mode. Because it's value is not available during the static mode. @@ -1112,6 +1290,10 @@ def register_hook(self, hook): ('size', _size_), ('T', _T_), ('mT', _mT_), + ('new_full', _new_full_), + ('new_empty', _new_empty_), + ('new_ones', _new_ones_), + ('new_zeros', _new_zeros_), ('clone', clone), ('clear_gradient', clear_gradient), ('append', append), diff --git a/test/legacy_test/test_math_op_patch_pir.py b/test/legacy_test/test_math_op_patch_pir.py index 3035ce03dbb551..679392969aee32 100644 --- a/test/legacy_test/test_math_op_patch_pir.py +++ b/test/legacy_test/test_math_op_patch_pir.py @@ -725,6 +725,84 @@ def test_mT(self): np.testing.assert_array_equal(y_mT_np.shape, (2, 4, 3)) np.testing.assert_array_equal(z_mT_np.shape, (100, 5, 13, 12)) + def test_new_xxx(self): + with paddle.pir_utils.IrGuard(): + shape = [1] + x = paddle.rand(shape, dtype="float32") + self.assertRaises(ValueError, getattr, x, 'mT') + + for ndim in range(2, 5): + # shape is [1, 2], [1, 2, 3], [1, 2, 3, 4] + shape = list(range(1, ndim + 1)) + out_shape = list(shape) + out_shape[-2], out_shape[-1] = out_shape[-1], out_shape[-2] + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.rand(shape, dtype="float32") + x_new = x.new_full([7], 1.0) + self.assertEqual(x_new.shape, out_shape) + (output_x,) = exe.run(main_program, fetch_list=[x_new]) + self.assertEqual(output_x.shape, (7)) + + shape = [1, 2, 3, 0, 1] + out_shape = list(shape) + out_shape[-2], out_shape[-1] = out_shape[-1], out_shape[-2] + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.rand(shape, dtype="float32") + x_new = x.new_full([3, 0], 4.0) + self.assertEqual(x_new.shape, out_shape) + (output_x,) = exe.run(main_program, fetch_list=[x_new]) + self.assertEqual(output_x.shape, (3, 0)) + + shape = [1, 2, 3, 1, 0] + out_shape = list(shape) + out_shape[-2], out_shape[-1] = out_shape[-1], out_shape[-2] + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.rand(shape, dtype="float32") + x_new = x.new_empty([2, 2]) + self.assertEqual(x_new.shape, out_shape) + (output_x,) = exe.run(main_program, fetch_list=[x_new]) + self.assertEqual(output_x.shape, (2, 2)) + + shape = [1, 2, 3, 0, 0] + out_shape = list(shape) + out_shape[-2], out_shape[-1] = out_shape[-1], out_shape[-2] + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.rand(shape, dtype="float32") + x_new = x.new_ones([2, 2]) + self.assertEqual(x_new.shape, out_shape) + (output_x,) = exe.run(main_program, fetch_list=[x_new]) + self.assertEqual(output_x.shape, (2, 2)) + + shape = [0, 2, 3, 0, 0] + out_shape = list(shape) + out_shape[-2], out_shape[-1] = out_shape[-1], out_shape[-2] + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.rand(shape, dtype="float32") + x_new = x.new_zeros([2, 3]) + self.assertEqual(x_new.shape, out_shape) + (output_x,) = exe.run(main_program, fetch_list=[x_new]) + self.assertEqual(output_x.shape, (2, 3)) + + # test mT with dynamic shape + with paddle.pir_utils.IrGuard(): + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.static.data(name="x", shape=[-1, 5], dtype='float32') + x_new = x.new_ones([2, 2]) + + x_np = np.random.randn(12, 5).astype('float32') + (x_new_np,) = exe.run( + main_program, + feed={"x": x_np}, + fetch_list=[x_new], + ) + np.testing.assert_array_equal(x_new_np.shape, (2, 2)) + def test_hash(self): with paddle.pir_utils.IrGuard(): _, _, program_guard = new_program() From d0abf82144ed0eb8ad43a2f62f0ec9ceb3710b73 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 8 Aug 2025 11:15:12 +0800 Subject: [PATCH 18/31] fix name --- python/paddle/tensor/creation.py | 34 ++++++++++++++++---------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 915ef97999b902..7ba5a1bc40dfc4 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1038,10 +1038,10 @@ def full_like( x: paddle.Tensor, fill_value: bool | float, dtype: DTypeLike | None = None, - name: str | None = None, *, device: PlaceLike | None = None, requires_grad: bool = False, + name: str | None = None, ) -> paddle.Tensor: """ @@ -1054,11 +1054,11 @@ def full_like( dtype(np.dtype|str, optional): The data type of output. The data type can be one of bool, float16, float32, float64, int32, int64. The default value is None, which means the output data type is the same as input. - name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. device(PlaceLike|None, optional): The desired device of returned tensor. if None, uses the current device for the default tensor type (see paddle.device.set_device()). device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. Default: None. requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: False. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor: Tensor which is created according to ``x``, ``fill_value`` and ``dtype``. @@ -1250,10 +1250,10 @@ def fill_constant( def ones( shape: ShapeLike, dtype: DTypeLike | None = None, - name: str | None = None, *, device: PlaceLike | None = None, requires_grad: bool = False, + name: str | None = None, ) -> paddle.Tensor: """ Create a Tensor of specified :attr:`shape` and :attr:`dtype` and fill it with 1. @@ -1264,11 +1264,11 @@ def ones( If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list. dtype (np.dtype|str, optional): Data type of output Tensor, it should be one of bool, float16, float32, float64, int32 and int64. If it is set to None, the data type will be float32. - name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. device(PlaceLike|None, optional): The desired device of returned tensor. if None, uses the current device for the default tensor type (see paddle.device.set_device()). device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. Default: None. requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: False. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor: A Tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements are 1. @@ -1315,10 +1315,10 @@ def ones( def ones_like( x: paddle.Tensor, dtype: DTypeLike | None = None, - name: str | None = None, *, device: PlaceLike | None = None, requires_grad: bool = False, + name: str | None = None, ) -> paddle.Tensor: """ Returns a Tensor filled with the value 1, with the same shape and @@ -1331,11 +1331,11 @@ def ones_like( output tensor. Supported data types: bool, float16, float32, float64, int32, int64. If ``dtype`` is None, the data type is the same as ``x``. Default is None. - name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. device(PlaceLike|None, optional): The desired device of returned tensor. if None, uses the current device for the default tensor type (see paddle.device.set_device()). device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. Default: None. requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: False. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor: A Tensor filled with the value 1, with the same shape and @@ -1368,10 +1368,10 @@ def ones_like( def zeros( shape: ShapeLike, dtype: DTypeLike | None = None, - name: str | None = None, *, device: PlaceLike | None = None, requires_grad: bool = False, + name: str | None = None, ) -> paddle.Tensor: """ Creates a tensor of specified :attr:`shape` and :attr:`dtype`, and fills it with 0. @@ -1434,10 +1434,10 @@ def zeros( def zeros_like( x: paddle.Tensor, dtype: DTypeLike | None = None, - name: str | None = None, *, device: PlaceLike | None = None, requires_grad: bool = False, + name: str | None = None, ) -> paddle.Tensor: """ Returns a Tensor filled with the value 0, with the same shape and @@ -1450,11 +1450,11 @@ def zeros_like( output tensor. Supported data types: bool, float16, float32, float64, int32, int64. If ``dtype`` is None, the data type is the same as ``x``. Default is None. - name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. device(PlaceLike|None, optional): The desired device of returned tensor. if None, uses the current device for the default tensor type (see paddle.device.set_device()). device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. Default: None. requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: False. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor: A Tensor filled with the value 0, with the same shape and @@ -1489,10 +1489,10 @@ def eye( num_rows: int, num_columns: int | None = None, dtype: DTypeLike | None = None, - name: str | None = None, *, device: PlaceLike | None = None, requires_grad: bool = False, + name: str | None = None, ) -> paddle.Tensor: """ @@ -1505,11 +1505,11 @@ def eye( dtype(np.dtype|str, optional): The data type of the returned Tensor. It should be int32, int64, float16, float32, float64, complex64, complex128. Default: if None, the data type is float32. - name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. device(PlaceLike|None, optional): The desired device of returned tensor. if None, uses the current device for the default tensor type (see paddle.device.set_device()). device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. Default: None. requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: False. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor: An identity Tensor or DenseTensor of shape [num_rows, num_columns]. @@ -1601,10 +1601,10 @@ def full( shape: ShapeLike, fill_value: bool | float | paddle.Tensor, dtype: DTypeLike | None = None, - name: str | None = None, *, device: PlaceLike | None = None, requires_grad: bool = False, + name: str | None = None, ) -> paddle.Tensor: """ @@ -1619,11 +1619,11 @@ def full( dtype(np.dtype|str, optional): Data type of the output Tensor which can be float16, float32, float64, int32, int64, if dtype is `None`, the data type of created Tensor is `float32`. - name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. device(PlaceLike|None, optional): The desired device of returned tensor. if None, uses the current device for the default tensor type (see paddle.device.set_device()). device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. Default: None. requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: False. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor: Tensor which is created according to ``shape``, ``fill_value`` and ``dtype``. @@ -2545,10 +2545,10 @@ def diag( def empty( shape: ShapeLike, dtype: DTypeLike | None = None, - name: str | None = None, *, device: PlaceLike | None = None, requires_grad: bool = False, + name: str | None = None, ) -> paddle.Tensor: """ Returns a Tensor with uninitialized data which size is same as ``shape``. @@ -2561,11 +2561,11 @@ def empty( which can be bool, float16, float32, float64, int32, int64, complex64, complex128 if dtype is `None`, the data type of created Tensor use global default dtype (see ``get_default_dtype`` for details). - name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. device(PlaceLike|None, optional): The desired device of returned tensor. if None, uses the current device for the default tensor type (see paddle.device.set_device()). device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. Default: None. requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: False. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor: Tensor which is created according to ``shape`` and ``dtype``, and is uninitialized. @@ -2703,10 +2703,10 @@ def empty( def empty_like( x: paddle.Tensor, dtype: DTypeLike | None = None, - name: str | None = None, *, device: PlaceLike | None = None, requires_grad: bool = False, + name: str | None = None, ) -> paddle.Tensor: """ Returns a Tensor with uninitialized data which has identical shape of ``x`` and ``dtype``. @@ -2717,11 +2717,11 @@ def empty_like( dtype(np.dtype|str, optional): The data type of output. The data type can be one of bool, float16, float32, float64, int32, int64. The default value is None, which means the output data type is the same as input. - name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. device(PlaceLike|None, optional): The desired device of returned tensor. if None, uses the current device for the default tensor type (see paddle.device.set_device()). device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. Default: None. requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: False. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor: Tensor which is created according to ``x`` and ``dtype``, and is uninitialized. From 16e5aa999d58f9bb62abae44af1a9d697c3bc154 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 8 Aug 2025 11:16:42 +0800 Subject: [PATCH 19/31] use full instead of fill_constant --- python/paddle/tensor/creation.py | 46 +++++++++++++++++--------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 7ba5a1bc40dfc4..c23dede56ec3f0 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -24,6 +24,7 @@ import paddle from paddle import _C_ops +from paddle.device import _convert_to_place from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only from ..base.data_feeder import ( @@ -1141,11 +1142,16 @@ def fill_constant( value: bool | float | paddle.Tensor, force_cpu: bool = False, out: paddle.Tensor | None = None, + place: PlaceLike | None = None, name: str | None = None, ) -> paddle.Tensor: shape = [shape] if isinstance(shape, int) else shape if in_dynamic_or_pir_mode(): - place = _current_expected_place() + if place is None: + place = _current_expected_place() + else: + place = _convert_to_place(place) + if force_cpu: place = core.CPUPlace() @@ -1301,15 +1307,14 @@ def ones( [1. 1.] [1. 1.]] """ - if dtype is None: - dtype = paddle.get_default_dtype() - tensor = fill_constant(value=1.0, shape=shape, dtype=dtype, name=name) - - if device is not None: - tensor = tensor.to(device=device) - if requires_grad is True: - tensor.stop_gradient = False - return tensor + return full( + shape, + 1, + dtype, + device=device, + requires_grad=requires_grad, + name=name, + ) def ones_like( @@ -1420,15 +1425,14 @@ def zeros( [0. 0.] [0. 0.]] """ - if dtype is None: - dtype = paddle.get_default_dtype() - tensor = fill_constant(value=0.0, shape=shape, dtype=dtype, name=name) - - if device is not None: - tensor = tensor.to(device=device) - if requires_grad is True: - tensor.stop_gradient = False - return tensor + return full( + shape, + 0, + dtype, + device=device, + requires_grad=requires_grad, + name=name, + ) def zeros_like( @@ -1674,10 +1678,8 @@ def full( dtype = paddle.get_default_dtype() tensor = fill_constant( - shape=shape, dtype=dtype, value=fill_value, name=name + shape=shape, dtype=dtype, value=fill_value, place=device, name=name ) - if device is not None: - tensor = tensor.to(device=device) if requires_grad is True: tensor.stop_gradient = False return tensor From 588dc00013a715d2c98186ca6a47e7f749f68b87 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 8 Aug 2025 11:42:39 +0800 Subject: [PATCH 20/31] refine device --- python/paddle/device/__init__.py | 5 ++++- python/paddle/tensor/creation.py | 8 ++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/python/paddle/device/__init__.py b/python/paddle/device/__init__.py index f50db1c25393bf..0fb299221b07fb 100644 --- a/python/paddle/device/__init__.py +++ b/python/paddle/device/__init__.py @@ -213,7 +213,10 @@ def get_cudnn_version() -> int | None: return _cudnn_version -def _convert_to_place(device: str) -> PlaceLike: +def _convert_to_place(device: PlaceLike) -> PlaceLike: + if not isinstance(device, str): + return device # return directly if not a string + lower_device = device.lower() if device in core.get_all_custom_device_type(): selected_devices = os.getenv(f"FLAGS_selected_{device}s", "0").split( diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index c23dede56ec3f0..115c2092fadf5a 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1559,8 +1559,8 @@ def _check_attr(attr, message): num_columns, dtype, ( - device - if (in_dynamic_mode() and device is not None) + _convert_to_place(device) + if device is not None else _current_expected_place() ), ) @@ -2648,8 +2648,8 @@ def empty( shape, convert_np_dtype_to_dtype_(dtype), ( - device - if (in_dynamic_mode() and device is not None) + _convert_to_place(device) + if device is not None else _current_expected_place() ), ) From 2139a685cff0d6b83c8173000a8af5818cf3b4cc Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 8 Aug 2025 11:16:42 +0800 Subject: [PATCH 21/31] use full instead of fill_constant --- python/paddle/device/__init__.py | 5 +++- python/paddle/tensor/creation.py | 46 +++++++++++++++++--------------- 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/python/paddle/device/__init__.py b/python/paddle/device/__init__.py index f50db1c25393bf..0fb299221b07fb 100644 --- a/python/paddle/device/__init__.py +++ b/python/paddle/device/__init__.py @@ -213,7 +213,10 @@ def get_cudnn_version() -> int | None: return _cudnn_version -def _convert_to_place(device: str) -> PlaceLike: +def _convert_to_place(device: PlaceLike) -> PlaceLike: + if not isinstance(device, str): + return device # return directly if not a string + lower_device = device.lower() if device in core.get_all_custom_device_type(): selected_devices = os.getenv(f"FLAGS_selected_{device}s", "0").split( diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 7ba5a1bc40dfc4..c23dede56ec3f0 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -24,6 +24,7 @@ import paddle from paddle import _C_ops +from paddle.device import _convert_to_place from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only from ..base.data_feeder import ( @@ -1141,11 +1142,16 @@ def fill_constant( value: bool | float | paddle.Tensor, force_cpu: bool = False, out: paddle.Tensor | None = None, + place: PlaceLike | None = None, name: str | None = None, ) -> paddle.Tensor: shape = [shape] if isinstance(shape, int) else shape if in_dynamic_or_pir_mode(): - place = _current_expected_place() + if place is None: + place = _current_expected_place() + else: + place = _convert_to_place(place) + if force_cpu: place = core.CPUPlace() @@ -1301,15 +1307,14 @@ def ones( [1. 1.] [1. 1.]] """ - if dtype is None: - dtype = paddle.get_default_dtype() - tensor = fill_constant(value=1.0, shape=shape, dtype=dtype, name=name) - - if device is not None: - tensor = tensor.to(device=device) - if requires_grad is True: - tensor.stop_gradient = False - return tensor + return full( + shape, + 1, + dtype, + device=device, + requires_grad=requires_grad, + name=name, + ) def ones_like( @@ -1420,15 +1425,14 @@ def zeros( [0. 0.] [0. 0.]] """ - if dtype is None: - dtype = paddle.get_default_dtype() - tensor = fill_constant(value=0.0, shape=shape, dtype=dtype, name=name) - - if device is not None: - tensor = tensor.to(device=device) - if requires_grad is True: - tensor.stop_gradient = False - return tensor + return full( + shape, + 0, + dtype, + device=device, + requires_grad=requires_grad, + name=name, + ) def zeros_like( @@ -1674,10 +1678,8 @@ def full( dtype = paddle.get_default_dtype() tensor = fill_constant( - shape=shape, dtype=dtype, value=fill_value, name=name + shape=shape, dtype=dtype, value=fill_value, place=device, name=name ) - if device is not None: - tensor = tensor.to(device=device) if requires_grad is True: tensor.stop_gradient = False return tensor From 8eeced4409f6bcecf7f0fb0cb824dec7ef1be241 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 8 Aug 2025 13:01:56 +0800 Subject: [PATCH 22/31] fix --- test/legacy_test/test_fill_any_like_op.py | 2 +- test/legacy_test/test_full_like_op.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/legacy_test/test_fill_any_like_op.py b/test/legacy_test/test_fill_any_like_op.py index a60ab183e36cd8..e9a23036594345 100644 --- a/test/legacy_test/test_fill_any_like_op.py +++ b/test/legacy_test/test_fill_any_like_op.py @@ -37,7 +37,7 @@ def fill_any_like_wrapper(x, value, out_dtype=None, name=None): out_dtype, paddle.framework.core.VarDesc.VarType ): tmp_dtype = paddle.pir.core.vartype_to_datatype[tmp_dtype] - return paddle.full_like(x, value, tmp_dtype, name) + return paddle.full_like(x, value, tmp_dtype, name=name) class TestFillAnyLikeOp(OpTest): diff --git a/test/legacy_test/test_full_like_op.py b/test/legacy_test/test_full_like_op.py index 77b7941240051b..72019d8b5caea6 100644 --- a/test/legacy_test/test_full_like_op.py +++ b/test/legacy_test/test_full_like_op.py @@ -38,7 +38,7 @@ def fill_any_like_wrapper(x, value, out_dtype=None, name=None): out_dtype, paddle.framework.core.VarDesc.VarType ): tmp_dtype = paddle.pir.core.vartype_to_datatype[tmp_dtype] - return paddle.full_like(x, value, tmp_dtype, name) + return paddle.full_like(x, value, tmp_dtype, name=name) class TestFullOp(unittest.TestCase): From 7c15715d8a67183b84adcfd099c8a1efa252c471 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 8 Aug 2025 13:01:56 +0800 Subject: [PATCH 23/31] fix --- test/legacy_test/test_fill_any_like_op.py | 2 +- test/legacy_test/test_full_like_op.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/legacy_test/test_fill_any_like_op.py b/test/legacy_test/test_fill_any_like_op.py index a60ab183e36cd8..e9a23036594345 100644 --- a/test/legacy_test/test_fill_any_like_op.py +++ b/test/legacy_test/test_fill_any_like_op.py @@ -37,7 +37,7 @@ def fill_any_like_wrapper(x, value, out_dtype=None, name=None): out_dtype, paddle.framework.core.VarDesc.VarType ): tmp_dtype = paddle.pir.core.vartype_to_datatype[tmp_dtype] - return paddle.full_like(x, value, tmp_dtype, name) + return paddle.full_like(x, value, tmp_dtype, name=name) class TestFillAnyLikeOp(OpTest): diff --git a/test/legacy_test/test_full_like_op.py b/test/legacy_test/test_full_like_op.py index 77b7941240051b..72019d8b5caea6 100644 --- a/test/legacy_test/test_full_like_op.py +++ b/test/legacy_test/test_full_like_op.py @@ -38,7 +38,7 @@ def fill_any_like_wrapper(x, value, out_dtype=None, name=None): out_dtype, paddle.framework.core.VarDesc.VarType ): tmp_dtype = paddle.pir.core.vartype_to_datatype[tmp_dtype] - return paddle.full_like(x, value, tmp_dtype, name) + return paddle.full_like(x, value, tmp_dtype, name=name) class TestFullOp(unittest.TestCase): From 7f50ef26516e208d0ef5138024dd07c373540d8b Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 8 Aug 2025 22:50:19 +0800 Subject: [PATCH 24/31] fix string device --- python/paddle/tensor/creation.py | 10 ++++- test/legacy_test/test_creation.py | 67 ++++++++++++++++++++++--------- 2 files changed, 56 insertions(+), 21 deletions(-) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 115c2092fadf5a..4d9041f666575d 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1086,7 +1086,9 @@ def full_like( if in_dynamic_or_pir_mode(): if in_dynamic_mode(): - tensor = _C_ops.full_like(x, fill_value, dtype, device) + tensor = _C_ops.full_like( + x, fill_value, dtype, _convert_to_place(device) + ) else: tensor = _C_ops.full_like(x, fill_value, dtype, core.Place()) if requires_grad is True: @@ -2757,7 +2759,11 @@ def empty_like( tensor = _C_ops.empty( x_shape, convert_np_dtype_to_dtype_(dtype), - (device or _current_expected_place()), + ( + _convert_to_place(device) + if device is not None + else _current_expected_place() + ), ) if requires_grad is True: tensor.stop_gradient = False diff --git a/test/legacy_test/test_creation.py b/test/legacy_test/test_creation.py index 4da54f99ea8612..965fe145aa8a7f 100644 --- a/test/legacy_test/test_creation.py +++ b/test/legacy_test/test_creation.py @@ -22,9 +22,11 @@ class TestTensorCreation(unittest.TestCase): def setUp(self): - self.devices = [paddle.CPUPlace()] + self.devices = [paddle.CPUPlace(), "cpu"] if paddle.device.is_compiled_with_cuda(): self.devices.append(paddle.CUDAPlace(0)) + self.devices.append("gpu") + self.devices.append("gpu:0") if paddle.device.is_compiled_with_xpu(): self.devices.append(paddle.device.XPUPlace(0)) if paddle.device.is_compiled_with_ipu(): @@ -44,11 +46,14 @@ def test_ones(self): requires_grad=requires_grad, device=device, ) - self.assertEqual(x.place, device) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) - st_f = paddle.jit.to_static(paddle.ones, full_graph=True) + st_f = paddle.jit.to_static( + paddle.ones, full_graph=True, backend=None + ) x = st_f( [2], dtype=dtype, @@ -70,11 +75,14 @@ def test_zeros(self): requires_grad=requires_grad, device=device, ) - self.assertEqual(x.place, device) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) - st_f = paddle.jit.to_static(paddle.zeros, full_graph=True) + st_f = paddle.jit.to_static( + paddle.zeros, full_graph=True, backend=None + ) x = st_f( [2], dtype=dtype, @@ -97,11 +105,14 @@ def test_full(self): requires_grad=requires_grad, device=device, ) - self.assertEqual(x.place, device) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) - st_f = paddle.jit.to_static(paddle.full, full_graph=True) + st_f = paddle.jit.to_static( + paddle.full, full_graph=True, backend=None + ) x = st_f( [2], fill_value=3.14, @@ -124,11 +135,14 @@ def test_empty(self): requires_grad=requires_grad, device=device, ) - self.assertEqual(x.place, device) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) - st_f = paddle.jit.to_static(paddle.empty, full_graph=True) + st_f = paddle.jit.to_static( + paddle.empty, full_graph=True, backend=None + ) x = st_f( [2], dtype=dtype, @@ -151,11 +165,14 @@ def test_eye(self): requires_grad=requires_grad, device=device, ) - self.assertEqual(x.place, device) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) - st_f = paddle.jit.to_static(paddle.eye, full_graph=True) + st_f = paddle.jit.to_static( + paddle.eye, full_graph=True, backend=None + ) x = st_f( 3, 3, @@ -178,11 +195,14 @@ def test_ones_like(self): requires_grad=requires_grad, device=device, ) - self.assertEqual(x.place, device) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) - st_f = paddle.jit.to_static(paddle.ones_like, full_graph=True) + st_f = paddle.jit.to_static( + paddle.ones_like, full_graph=True, backend=None + ) x = st_f( paddle.randn([2, 2]), dtype=dtype, @@ -204,11 +224,14 @@ def test_zeros_like(self): requires_grad=requires_grad, device=device, ) - self.assertEqual(x.place, device) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) - st_f = paddle.jit.to_static(paddle.zeros_like, full_graph=True) + st_f = paddle.jit.to_static( + paddle.zeros_like, full_graph=True, backend=None + ) x = st_f( paddle.randn([2, 2]), dtype=dtype, @@ -231,11 +254,14 @@ def test_full_like(self): requires_grad=requires_grad, device=device, ) - self.assertEqual(x.place, device) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) - st_f = paddle.jit.to_static(paddle.full_like, full_graph=True) + st_f = paddle.jit.to_static( + paddle.full_like, full_graph=True, backend=None + ) x = st_f( paddle.randn([2, 2]), 3.14, @@ -258,11 +284,14 @@ def test_empty_like(self): requires_grad=requires_grad, device=device, ) - self.assertEqual(x.place, device) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) - st_f = paddle.jit.to_static(paddle.empty_like, full_graph=True) + st_f = paddle.jit.to_static( + paddle.empty_like, full_graph=True, backend=None + ) x = st_f( paddle.randn([2, 2]), dtype=dtype, From 518b9b79f50257407dac56d4851f927c7c960d43 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Sat, 9 Aug 2025 18:21:29 +0800 Subject: [PATCH 25/31] add pir mothods --- python/paddle/base/dygraph/math_op_patch.py | 29 ++- python/paddle/pir/math_op_patch.py | 23 +-- test/legacy_test/test_creation.py | 196 ++++++++++++++++++++ 3 files changed, 227 insertions(+), 21 deletions(-) diff --git a/python/paddle/base/dygraph/math_op_patch.py b/python/paddle/base/dygraph/math_op_patch.py index 5930bd6ad311dc..322671f0abbb3b 100644 --- a/python/paddle/base/dygraph/math_op_patch.py +++ b/python/paddle/base/dygraph/math_op_patch.py @@ -18,6 +18,7 @@ import numpy as np +import paddle from paddle import _C_ops from .. import core @@ -25,7 +26,7 @@ if TYPE_CHECKING: from paddle import Tensor - from paddle._typing import DTypeLike, PlaceLike + from paddle._typing import DTypeLike, PlaceLike, ShapeLike _supported_int_dtype_ = [ core.VarDesc.VarType.UINT8, @@ -70,6 +71,7 @@ def monkey_patch_math_tensor(): Similar to monkey_patch_variable. The difference is, in dygraph mode, use auto-generated op functions for better performance. """ + global paddle def astype(self: Tensor, dtype: DTypeLike) -> Tensor: """ @@ -217,6 +219,7 @@ def _mT_(var: Tensor) -> Tensor: def _new_full_( var: Tensor, + size: ShapeLike, fill_value: bool | float | paddle.Tensor, *, dtype: DTypeLike | None = None, @@ -229,8 +232,8 @@ def _new_full_( device = var.device return paddle.full( - var.shape, - fill_value=fill_value, + size, + fill_value, dtype=dtype, device=device, requires_grad=requires_grad, @@ -238,6 +241,7 @@ def _new_full_( def _new_empty_( var: Tensor, + size: ShapeLike, *, dtype: DTypeLike | None = None, device: PlaceLike | None = None, @@ -249,11 +253,15 @@ def _new_empty_( device = var.device return paddle.empty( - var.shape, dtype=dtype, device=device, requires_grad=requires_grad + size, + dtype, + device=device, + requires_grad=requires_grad, ) def _new_ones_( var: Tensor, + size: ShapeLike, *, dtype: DTypeLike | None = None, device: PlaceLike | None = None, @@ -265,15 +273,16 @@ def _new_ones_( device = var.device return paddle.full( - var.shape, - fill_value=1, - dtype=dtype, + size, + 1, + dtype, device=device, requires_grad=requires_grad, ) def _new_zeros_( var: Tensor, + size: ShapeLike, *, dtype: DTypeLike | None = None, device: PlaceLike | None = None, @@ -285,9 +294,9 @@ def _new_zeros_( device = var.device return paddle.full( - var.shape, - fill_value=0, - dtype=dtype, + size, + 0, + dtype, device=device, requires_grad=requires_grad, ) diff --git a/python/paddle/pir/math_op_patch.py b/python/paddle/pir/math_op_patch.py index 401ebbffc66fc9..a8a56079fa52ca 100644 --- a/python/paddle/pir/math_op_patch.py +++ b/python/paddle/pir/math_op_patch.py @@ -29,7 +29,7 @@ from . import Value if TYPE_CHECKING: - from paddle._typing import DTypeLike, PlaceLike + from paddle._typing import DTypeLike, PlaceLike, ShapeLike _already_patch_value = False @@ -574,6 +574,7 @@ def _mT_(self): def _new_full_( self, + size: ShapeLike, fill_value: bool | float | paddle.Tensor, *, dtype: DTypeLike | None = None, @@ -609,8 +610,8 @@ def _new_full_( device = self.device return paddle.full( - self.shape, - fill_value=fill_value, + size, + fill_value, dtype=dtype, device=device, requires_grad=requires_grad, @@ -618,7 +619,7 @@ def _new_full_( def _new_empty_( self, - fill_value: bool | float | paddle.Tensor, + size: ShapeLike, *, dtype: DTypeLike | None = None, device: PlaceLike | None = None, @@ -653,12 +654,12 @@ def _new_empty_( device = self.device return paddle.empty( - self.shape, dtype=dtype, device=device, requires_grad=requires_grad + size, dtype=dtype, device=device, requires_grad=requires_grad ) def _new_ones_( self, - fill_value: bool | float | paddle.Tensor, + size: ShapeLike, *, dtype: DTypeLike | None = None, device: PlaceLike | None = None, @@ -693,8 +694,8 @@ def _new_ones_( device = self.device return paddle.full( - self.shape, - fill_value=1, + size, + 1, dtype=dtype, device=device, requires_grad=requires_grad, @@ -702,7 +703,7 @@ def _new_ones_( def _new_zeros_( self, - fill_value: bool | float | paddle.Tensor, + size: ShapeLike, *, dtype: DTypeLike | None = None, device: PlaceLike | None = None, @@ -737,8 +738,8 @@ def _new_zeros_( device = self.device return paddle.full( - self.shape, - fill_value=0, + size, + 0, dtype=dtype, device=device, requires_grad=requires_grad, diff --git a/test/legacy_test/test_creation.py b/test/legacy_test/test_creation.py index 965fe145aa8a7f..001868a18c682e 100644 --- a/test/legacy_test/test_creation.py +++ b/test/legacy_test/test_creation.py @@ -15,6 +15,7 @@ import unittest from itertools import product +import numpy as np from utils import dygraph_guard import paddle @@ -303,5 +304,200 @@ def test_empty_like(self): self.assertEqual(x.dtype, dtype) +class TestTensorKPatchMethod(unittest.TestCase): + def setUp(self): + self.devices = [paddle.CPUPlace(), "cpu"] + if paddle.device.is_compiled_with_cuda(): + self.devices.append(paddle.CUDAPlace(0)) + self.devices.append("gpu") + self.devices.append("gpu:0") + if paddle.device.is_compiled_with_xpu(): + self.devices.append(paddle.device.XPUPlace(0)) + if paddle.device.is_compiled_with_ipu(): + self.devices.append(paddle.device.IPUPlace()) + + self.requires_grads = [True, False] + self.shapes = [ + [4, 4], + ] + self.dtypes = ["float32", paddle.float32, "int32", paddle.int32] + + def test_Tensor_new_ones(self): + for shape, device, requires_grad, dtype in product( + self.shapes, self.devices, self.requires_grads, self.dtypes + ): + with dygraph_guard(): + x = paddle.ones( + [1], + ).new_ones( + shape, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + + def new_ones(x, shape, dtype, requires_grad, device): + return x.new_ones( + shape, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + + st_f = paddle.jit.to_static( + new_ones, full_graph=True, backend=None + ) + x = st_f( + paddle.randn([1]), + shape, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + + def test_Tensor_new_zeros(self): + for shape, device, requires_grad, dtype in product( + self.shapes, self.devices, self.requires_grads, self.dtypes + ): + with dygraph_guard(): + x = paddle.zeros( + [1], + ).new_zeros( + shape, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + + def new_zeros(x, shape, dtype, requires_grad, device): + return x.new_zeros( + shape, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + + st_f = paddle.jit.to_static( + new_zeros, full_graph=True, backend=None + ) + x = st_f( + paddle.randn([1]), + shape, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + + def test_Tensor_new_full(self): + for shape, device, requires_grad, dtype in product( + self.shapes, self.devices, self.requires_grads, self.dtypes + ): + with dygraph_guard(): + x = paddle.full( + [1], + 3.14, + ).new_full( + shape, + 2.0, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + np.testing.assert_allclose( + x.numpy(), paddle.full(shape, 2.0).numpy(), 1e-6, 1e-6 + ) + + def new_full( + x, shape, fill_value, dtype, requires_grad, device + ): + return x.new_full( + shape, + fill_value, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + + st_f = paddle.jit.to_static( + new_full, full_graph=True, backend=None + ) + x = st_f( + paddle.randn([1]), + shape, + 2.0, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + np.testing.assert_allclose( + x.numpy(), paddle.full(shape, 2.0).numpy(), 1e-6, 1e-6 + ) + + def test_Tensor_new_empty(self): + for shape, device, requires_grad, dtype in product( + self.shapes, self.devices, self.requires_grads, self.dtypes + ): + with dygraph_guard(): + x = paddle.empty( + [1], + ).new_empty( + shape, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + + def new_empty(x, shape, dtype, requires_grad, device): + return x.new_empty( + shape, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + + st_f = paddle.jit.to_static( + new_empty, full_graph=True, backend=None + ) + x = st_f( + paddle.randn([1]), + shape, + dtype=dtype, + requires_grad=requires_grad, + device=device, + ) + self.assertEqual(x.stop_gradient, not requires_grad) + if isinstance(dtype, paddle.dtype): + self.assertEqual(x.dtype, dtype) + + if __name__ == '__main__': unittest.main() From 51f4d6562aacdc705b3ce5a1f79b537b75670c40 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Mon, 11 Aug 2025 13:59:13 +0800 Subject: [PATCH 26/31] update code --- python/paddle/pir/math_op_patch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/pir/math_op_patch.py b/python/paddle/pir/math_op_patch.py index e9febd2f05e3ba..e97d1339f851ad 100644 --- a/python/paddle/pir/math_op_patch.py +++ b/python/paddle/pir/math_op_patch.py @@ -671,7 +671,7 @@ def _new_full_( if dtype is None: dtype = self.dtype if device is None: - device = self.device + device = self.place return paddle.full( size, @@ -715,7 +715,7 @@ def _new_empty_( if dtype is None: dtype = self.dtype if device is None: - device = self.device + device = self.place return paddle.empty( size, dtype=dtype, device=device, requires_grad=requires_grad @@ -755,7 +755,7 @@ def _new_ones_( if dtype is None: dtype = self.dtype if device is None: - device = self.device + device = self.place return paddle.full( size, @@ -799,7 +799,7 @@ def _new_zeros_( if dtype is None: dtype = self.dtype if device is None: - device = self.device + device = self.place return paddle.full( size, From bbc4505da8f8e9dde8afbcf4103b66f1e1d4b2c5 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Mon, 11 Aug 2025 14:01:20 +0800 Subject: [PATCH 27/31] add more UT --- test/legacy_test/test_creation.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/test/legacy_test/test_creation.py b/test/legacy_test/test_creation.py index 001868a18c682e..afaae182e9d39c 100644 --- a/test/legacy_test/test_creation.py +++ b/test/legacy_test/test_creation.py @@ -61,6 +61,8 @@ def test_ones(self): requires_grad=requires_grad, device=device, ) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) @@ -90,6 +92,8 @@ def test_zeros(self): requires_grad=requires_grad, device=device, ) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) @@ -121,6 +125,8 @@ def test_full(self): requires_grad=requires_grad, device=device, ) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) @@ -150,6 +156,8 @@ def test_empty(self): requires_grad=requires_grad, device=device, ) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) @@ -181,6 +189,8 @@ def test_eye(self): requires_grad=requires_grad, device=device, ) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) @@ -210,6 +220,8 @@ def test_ones_like(self): requires_grad=requires_grad, device=device, ) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) @@ -239,6 +251,8 @@ def test_zeros_like(self): requires_grad=requires_grad, device=device, ) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) @@ -270,6 +284,8 @@ def test_full_like(self): requires_grad=requires_grad, device=device, ) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) @@ -299,6 +315,8 @@ def test_empty_like(self): requires_grad=requires_grad, device=device, ) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) @@ -306,7 +324,7 @@ def test_empty_like(self): class TestTensorKPatchMethod(unittest.TestCase): def setUp(self): - self.devices = [paddle.CPUPlace(), "cpu"] + self.devices = [None, paddle.CPUPlace(), "cpu"] if paddle.device.is_compiled_with_cuda(): self.devices.append(paddle.CUDAPlace(0)) self.devices.append("gpu") @@ -359,6 +377,8 @@ def new_ones(x, shape, dtype, requires_grad, device): requires_grad=requires_grad, device=device, ) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) @@ -400,6 +420,8 @@ def new_zeros(x, shape, dtype, requires_grad, device): requires_grad=requires_grad, device=device, ) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) @@ -450,6 +472,8 @@ def new_full( requires_grad=requires_grad, device=device, ) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) @@ -494,6 +518,8 @@ def new_empty(x, shape, dtype, requires_grad, device): requires_grad=requires_grad, device=device, ) + if isinstance(device, paddle.framework.core.Place): + self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): self.assertEqual(x.dtype, dtype) From 6447ac9a9bd697f157a195814aa82f362348b966 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 12 Aug 2025 11:23:50 +0800 Subject: [PATCH 28/31] fix --- python/paddle/base/dygraph/math_op_patch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/base/dygraph/math_op_patch.py b/python/paddle/base/dygraph/math_op_patch.py index db1ade3df91515..b2b81831565807 100644 --- a/python/paddle/base/dygraph/math_op_patch.py +++ b/python/paddle/base/dygraph/math_op_patch.py @@ -297,7 +297,7 @@ def _new_full_( if dtype is None: dtype = var.dtype if device is None: - device = var.device + device = var.place return paddle.full( size, @@ -318,7 +318,7 @@ def _new_empty_( if dtype is None: dtype = var.dtype if device is None: - device = var.device + device = var.place return paddle.empty( size, @@ -338,7 +338,7 @@ def _new_ones_( if dtype is None: dtype = var.dtype if device is None: - device = var.device + device = var.place return paddle.full( size, @@ -359,7 +359,7 @@ def _new_zeros_( if dtype is None: dtype = var.dtype if device is None: - device = var.device + device = var.place return paddle.full( size, From a118f49b5fdf88e69f382ecf629c44bfac17f89b Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 12 Aug 2025 14:35:17 +0800 Subject: [PATCH 29/31] fix UT --- python/paddle/tensor/creation.py | 7 ++++++- test/legacy_test/test_math_op_patch_pir.py | 12 ++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index adaa1a54989d58..1e08c090daa780 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1096,7 +1096,12 @@ def full_like( x, fill_value, dtype, _get_paddle_place(device) ) else: - tensor = _C_ops.full_like(x, fill_value, dtype, core.Place()) + tensor = _C_ops.full_like( + x, + fill_value, + dtype, + core.Place() if device is None else _get_paddle_place(device), + ) if requires_grad is True: tensor.stop_gradient = False return tensor diff --git a/test/legacy_test/test_math_op_patch_pir.py b/test/legacy_test/test_math_op_patch_pir.py index 679392969aee32..2ea1b8798f179a 100644 --- a/test/legacy_test/test_math_op_patch_pir.py +++ b/test/legacy_test/test_math_op_patch_pir.py @@ -740,9 +740,9 @@ def test_new_xxx(self): with program_guard: x = paddle.rand(shape, dtype="float32") x_new = x.new_full([7], 1.0) - self.assertEqual(x_new.shape, out_shape) + self.assertEqual(x_new.shape, [7]) (output_x,) = exe.run(main_program, fetch_list=[x_new]) - self.assertEqual(output_x.shape, (7)) + self.assertEqual(output_x.shape, (7,)) shape = [1, 2, 3, 0, 1] out_shape = list(shape) @@ -751,7 +751,7 @@ def test_new_xxx(self): with program_guard: x = paddle.rand(shape, dtype="float32") x_new = x.new_full([3, 0], 4.0) - self.assertEqual(x_new.shape, out_shape) + self.assertEqual(x_new.shape, [3, 0]) (output_x,) = exe.run(main_program, fetch_list=[x_new]) self.assertEqual(output_x.shape, (3, 0)) @@ -762,7 +762,7 @@ def test_new_xxx(self): with program_guard: x = paddle.rand(shape, dtype="float32") x_new = x.new_empty([2, 2]) - self.assertEqual(x_new.shape, out_shape) + self.assertEqual(x_new.shape, [2, 2]) (output_x,) = exe.run(main_program, fetch_list=[x_new]) self.assertEqual(output_x.shape, (2, 2)) @@ -773,7 +773,7 @@ def test_new_xxx(self): with program_guard: x = paddle.rand(shape, dtype="float32") x_new = x.new_ones([2, 2]) - self.assertEqual(x_new.shape, out_shape) + self.assertEqual(x_new.shape, [2, 2]) (output_x,) = exe.run(main_program, fetch_list=[x_new]) self.assertEqual(output_x.shape, (2, 2)) @@ -784,7 +784,7 @@ def test_new_xxx(self): with program_guard: x = paddle.rand(shape, dtype="float32") x_new = x.new_zeros([2, 3]) - self.assertEqual(x_new.shape, out_shape) + self.assertEqual(x_new.shape, [2, 3]) (output_x,) = exe.run(main_program, fetch_list=[x_new]) self.assertEqual(output_x.shape, (2, 3)) From 22d507e79e70e41017dc5a86c2757e111fbe1899 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 12 Aug 2025 15:21:49 +0800 Subject: [PATCH 30/31] update docstring --- python/paddle/pir/math_op_patch.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/python/paddle/pir/math_op_patch.py b/python/paddle/pir/math_op_patch.py index e97d1339f851ad..a529fa04612057 100644 --- a/python/paddle/pir/math_op_patch.py +++ b/python/paddle/pir/math_op_patch.py @@ -647,7 +647,7 @@ def _new_full_( ): """ - Returns a Tensor of size size filled with fill_value. + Returns a Tensor of size ``size`` filled with ``fill_value``. By default, the returned Tensor has the same dtype and place as this tensor. Examples: @@ -657,7 +657,7 @@ def _new_full_( >>> paddle.enable_static() >>> x = paddle.ones(shape=[2, 3, 5]) - >>> x_new = x.new_full(3.14, dtype="float64", device="cpu") + >>> x_new = x.new_full([2, 3], 3.14, dtype="float64", device="cpu") >>> exe = paddle.static.Executor() >>> x_new_np = exe.run(paddle.static.default_main_program(), fetch_list=[x_new])[0] @@ -691,7 +691,7 @@ def _new_empty_( ): """ - Returns a Tensor of size size filled with fill_value. + Returns a Tensor of size ``size`` filled with uninitialized data. By default, the returned Tensor has the same dtype and place as this tensor. Examples: @@ -701,12 +701,12 @@ def _new_empty_( >>> paddle.enable_static() >>> x = paddle.ones(shape=[2, 3, 5]) - >>> x_new = x.new_empty(dtype="float64", device="cpu") + >>> x_new = x.new_empty([2, 3], dtype="float64", device="cpu") >>> exe = paddle.static.Executor() >>> x_new_np = exe.run(paddle.static.default_main_program(), fetch_list=[x_new])[0] >>> print(x_new_np.shape) - (2, 5, 3) + (2, 3) >>> print(str(x_new_np.dtype)) 'paddle.float64' >>> print(x_new_np.place) @@ -731,7 +731,7 @@ def _new_ones_( ): """ - Returns a Tensor of size size filled with fill_value. + Returns a Tensor of size ``size`` filled with ``1``. By default, the returned Tensor has the same dtype and place as this tensor. Examples: @@ -741,12 +741,12 @@ def _new_ones_( >>> paddle.enable_static() >>> x = paddle.ones(shape=[2, 3, 5]) - >>> x_new = x.new_ones(3.14, dtype="float64", device="cpu") + >>> x_new = x.new_ones([2, 3], dtype="float64", device="cpu") >>> exe = paddle.static.Executor() >>> x_new_np = exe.run(paddle.static.default_main_program(), fetch_list=[x_new])[0] >>> print(x_new_np.shape) - (2, 5, 3) + (2, 3) >>> print(str(x_new_np.dtype)) 'paddle.float64' >>> print(x_new_np.place) @@ -775,7 +775,7 @@ def _new_zeros_( ): """ - Returns a Tensor of size size filled with fill_value. + Returns a Tensor of size ``size`` filled with ``0``. By default, the returned Tensor has the same dtype and place as this tensor. Examples: @@ -785,12 +785,12 @@ def _new_zeros_( >>> paddle.enable_static() >>> x = paddle.ones(shape=[2, 3, 5]) - >>> x_new = x.new_zeros(3.14, dtype="float64", device="cpu") + >>> x_new = x.new_zeros([2, 3], dtype="float64", device="cpu") >>> exe = paddle.static.Executor() >>> x_new_np = exe.run(paddle.static.default_main_program(), fetch_list=[x_new])[0] >>> print(x_new_np.shape) - (2, 5, 3) + (2, 3) >>> print(str(x_new_np.dtype)) 'paddle.float64' >>> print(x_new_np.place) From f1da718ba57be2b3fd144ba107f3b0d9f3801433 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 12 Aug 2025 17:54:41 +0800 Subject: [PATCH 31/31] skip xpu test --- test/legacy_test/test_creation.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/legacy_test/test_creation.py b/test/legacy_test/test_creation.py index afaae182e9d39c..243be8366f1a4e 100644 --- a/test/legacy_test/test_creation.py +++ b/test/legacy_test/test_creation.py @@ -189,7 +189,11 @@ def test_eye(self): requires_grad=requires_grad, device=device, ) - if isinstance(device, paddle.framework.core.Place): + if ( + isinstance(device, paddle.framework.core.Place) + # skip xpu for unknown reason + and not isinstance(device, paddle.framework.core.XPUPlace) + ): self.assertEqual(x.place, device) self.assertEqual(x.stop_gradient, not requires_grad) if isinstance(dtype, paddle.dtype): @@ -322,7 +326,7 @@ def test_empty_like(self): self.assertEqual(x.dtype, dtype) -class TestTensorKPatchMethod(unittest.TestCase): +class TestTensorPatchMethod(unittest.TestCase): def setUp(self): self.devices = [None, paddle.CPUPlace(), "cpu"] if paddle.device.is_compiled_with_cuda():