diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index b68f8e48df26d7..00eea73f6665dc 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1381,6 +1381,7 @@ def ones_like( ) +@SizeArgsDecorator() def zeros( shape: ShapeLike, dtype: DTypeLike | None = None, diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index ec94963095696b..79276301a5c610 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -27,6 +27,7 @@ from paddle.utils.decorator_utils import ( ParamAliasDecorator, param_one_alias, + view_decorator, ) from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only @@ -7334,6 +7335,7 @@ def as_strided( @dygraph_only +@view_decorator() def view( x: Tensor, shape_or_dtype: Sequence[int] | DTypeLike, diff --git a/python/paddle/utils/decorator_utils.py b/python/paddle/utils/decorator_utils.py index 831f1e73313cec..35152f365f2125 100644 --- a/python/paddle/utils/decorator_utils.py +++ b/python/paddle/utils/decorator_utils.py @@ -131,3 +131,37 @@ def process( args = () return args, kwargs + + +""" + Usage Example: + paddle.view(x=tensor_x, shape_or_dtype=[-1, 1, 3], name=None) + + tensor_x.view(paddle.float32) -> paddle.view(tensor_x, paddle.float32) + tensor_x.view(dtype=paddle.float32) -> paddle.view(tensor_x, dtype=paddle.float32) + + tensor_x.view([-1, 1, 3]) -> paddle.view(tensor_x, [-1, 1, 3]) + tensor_x.view(-1, 1, 3) -> paddle.view(tensor_x, -1, 1, 3) + tensor_x.view(size=[-1, 1, 3]) -> paddle.view(tensor_x, size=[-1, 1, 3]) +""" + + +def view_decorator(): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if ("dtype" in kwargs) and ("shape_or_dtype" not in kwargs): + kwargs["shape_or_dtype"] = kwargs.pop("dtype") + elif ("size" in kwargs) and ("shape_or_dtype" not in kwargs): + kwargs["shape_or_dtype"] = kwargs.pop("size") + elif len(args) >= 2 and type(args[1]) is int: + if all(type(arg) is int for arg in args[1:]): + kwargs["x"] = args[0] + kwargs['shape_or_dtype'] = list(args[1:]) + args = () + return func(*args, **kwargs) + + wrapper.__signature__ = inspect.signature(func) + return wrapper + + return decorator diff --git a/test/legacy_test/test_stride.py b/test/legacy_test/test_stride.py index c6f8a6f315faba..db52416e887722 100644 --- a/test/legacy_test/test_stride.py +++ b/test/legacy_test/test_stride.py @@ -890,6 +890,60 @@ def call_view16(self): self.assertTrue(out_c._is_shared_buffer_with(out)) + def call_view_alias1(self): + x_np = np.random.random(size=[10, 10, 10, 20]).astype('float32') + x = paddle.to_tensor(x_np) + np.testing.assert_allclose(x.numpy(), x_np) + + np_out = x_np.reshape(10, 100, 20) + + out1 = x.view([10, 100, 20]) + np.testing.assert_allclose(out1.numpy(), np_out) + self.assertTrue(out1.is_contiguous()) + self.assertTrue(x._is_shared_buffer_with(out1)) + out_c1 = out1.contiguous() + np.testing.assert_allclose(out_c1.numpy(), np_out) + self.assertTrue(out_c1._is_shared_buffer_with(out1)) + + out2 = x.view(10, 100, 20) + np.testing.assert_allclose(out2.numpy(), np_out) + self.assertTrue(out2.is_contiguous()) + self.assertTrue(x._is_shared_buffer_with(out2)) + out_c2 = out2.contiguous() + np.testing.assert_allclose(out_c2.numpy(), np_out) + self.assertTrue(out_c2._is_shared_buffer_with(out2)) + + out3 = x.view(size=[10, 100, 20]) + np.testing.assert_allclose(out3.numpy(), np_out) + self.assertTrue(out3.is_contiguous()) + self.assertTrue(x._is_shared_buffer_with(out3)) + out_c1 = out3.contiguous() + np.testing.assert_allclose(out_c1.numpy(), np_out) + self.assertTrue(out_c1._is_shared_buffer_with(out3)) + + def call_view_alias2(self): + x_np = np.random.random(size=[10, 10, 10, 20]).astype('float32') + x = paddle.to_tensor(x_np) + np.testing.assert_allclose(x.numpy(), x_np) + + np_out = x_np.view(np.uint8) + + out1 = paddle.view(x, dtype="uint8") + np.testing.assert_allclose(out1.numpy(), np_out) + self.assertTrue(out1.is_contiguous()) + self.assertTrue(x._is_shared_buffer_with(out1)) + out_c1 = out1.contiguous() + np.testing.assert_allclose(out_c1.numpy(), np_out) + self.assertTrue(out_c1._is_shared_buffer_with(out1)) + + out2 = x.view(dtype="uint8") + np.testing.assert_allclose(out2.numpy(), np_out) + self.assertTrue(out2.is_contiguous()) + self.assertTrue(x._is_shared_buffer_with(out2)) + out_c1 = out2.contiguous() + np.testing.assert_allclose(out_c1.numpy(), np_out) + self.assertTrue(out_c1._is_shared_buffer_with(out2)) + def call_stride(self): self.call_transpose() self.call_diagonal() @@ -926,6 +980,8 @@ def call_stride(self): self.call_view14() self.call_view15() self.call_view16() + self.call_view_alias1() + self.call_view_alias2() self.call_view_as() self.call_unfold() diff --git a/test/legacy_test/test_zeros_op.py b/test/legacy_test/test_zeros_op.py index fa5529e66df992..60ef6bf74ad894 100644 --- a/test/legacy_test/test_zeros_op.py +++ b/test/legacy_test/test_zeros_op.py @@ -23,6 +23,7 @@ class ApiZerosTest(unittest.TestCase): def test_out(self): + paddle.enable_static() with program_guard(Program()): zeros = paddle.zeros(shape=[10], dtype='float64') place = paddle.CPUPlace() @@ -58,6 +59,7 @@ def test_out(self): exe = paddle.static.Executor(place) result = exe.run(fetch_list=[out]) self.assertEqual((result == out_np).all(), True) + paddle.disable_static() class ApiZerosError(unittest.TestCase): @@ -79,5 +81,67 @@ def test_dynamic_shape(self): self.assertEqual(out.shape, [101, -1]) +class ZerosAliasTest(unittest.TestCase): + def test_out(self): + paddle.enable_static() + with program_guard(Program()): + zeros = paddle.zeros(3, 3, dtype='float64') + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + (result,) = exe.run(fetch_list=[zeros]) + expected_result = np.zeros((3, 3), dtype='float64') + self.assertEqual((result == expected_result).all(), True) + + with program_guard(Program()): + zeros = paddle.zeros((3, 3), dtype='float64') + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + (result,) = exe.run(fetch_list=[zeros]) + expected_result = np.zeros((3, 3), dtype='float64') + self.assertEqual((result == expected_result).all(), True) + + with program_guard(Program()): + zeros = paddle.zeros([3, 3], dtype='float64') + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + (result,) = exe.run(fetch_list=[zeros]) + expected_result = np.zeros((3, 3), dtype='float64') + self.assertEqual((result == expected_result).all(), True) + + with program_guard(Program()): + zeros = paddle.zeros(size=(3, 3), dtype='float64') + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + (result,) = exe.run(fetch_list=[zeros]) + expected_result = np.zeros((3, 3), dtype='float64') + self.assertEqual((result == expected_result).all(), True) + paddle.disable_static() + + def test_dygraph_ones(self): + paddle.disable_static() + result = paddle.zeros(10, dtype=paddle.float32) + expect = np.zeros([10], dtype="float32") + np.testing.assert_equal(result, expect) + + result = paddle.zeros(10, 2, 3, dtype=paddle.float32) + expect = np.zeros([10, 2, 3], dtype="float32") + np.testing.assert_equal(result, expect) + + result = paddle.zeros([10, 2, 3], dtype=paddle.float32) + np.testing.assert_equal(result, expect) + + result = paddle.zeros(size=[10, 2, 3], dtype=paddle.float32) + np.testing.assert_equal(result, expect) + + result = paddle.zeros([10, 2, 3], paddle.float32) + np.testing.assert_equal(result, expect) + + result = paddle.zeros([10, 2, 3], "float32") + np.testing.assert_equal(result, expect) + + result = paddle.zeros(shape=[10, 2, 3], dtype=paddle.float32) + np.testing.assert_equal(result, expect) + + if __name__ == '__main__': unittest.main()