Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,8 @@
unstack,
view,
view_as,
view_as_complex,
view_as_real,
vsplit,
vstack,
)
Expand Down Expand Up @@ -1159,7 +1161,9 @@
'acosh',
'atanh',
'as_complex',
'view_as_complex',
'as_real',
'view_as_real',
'diff',
'angle',
'fmax',
Expand Down
4 changes: 4 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@
unstack,
view,
view_as,
view_as_complex,
view_as_real,
vsplit,
vstack,
)
Expand Down Expand Up @@ -779,7 +781,9 @@
'lu_unpack',
'cdist',
'as_complex',
'view_as_complex',
'as_real',
'view_as_real',
'rad2deg',
'deg2rad',
'gcd',
Expand Down
78 changes: 77 additions & 1 deletion python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6289,7 +6289,83 @@ def as_real(x: Tensor, name: str | None = None) -> Tensor:
return out


@ParamAliasDecorator({"x": ["input"], "axis": ["dim"]})
def view_as_complex(input: Tensor) -> Tensor:
"""Return a complex tensor that is a view of the input real tensor .

The data type of the input tensor is 'float32' or 'float64', and the data
type of the returned tensor is 'complex64' or 'complex128', respectively.

The shape of the input tensor is ``(* ,2)``, (``*`` means arbitrary shape), i.e.
the size of the last axis should be 2, which represent the real and imag part
of a complex number. The shape of the returned tensor is ``(*,)``.

The complex tensor is a view of the input real tensor, meaning that it shares the same memory with real tensor.

The image below demonstrates the case that a real 3D-tensor with shape [2, 3, 2] is transformed into a complex 2D-tensor with shape [2, 3].

.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/as_complex.png
:width: 500
:alt: Illustration of as_complex
:align: center

Args:
input (Tensor): The input tensor. Data type is 'float32' or 'float64'.

Returns:
Tensor, The output. Data type is 'complex64' or 'complex128', sharing the same memory with input.

Examples:
.. code-block:: python

>>> import paddle
>>> x = paddle.arange(12, dtype=paddle.float32).reshape([2, 3, 2])
>>> y = paddle.as_complex(x)
>>> print(y)
Tensor(shape=[2, 3], dtype=complex64, place=Place(cpu), stop_gradient=True,
[[1j , (2+3j) , (4+5j) ],
[(6+7j) , (8+9j) , (10+11j)]])
"""

return as_complex(x=input)


def view_as_real(input: Tensor) -> Tensor:
"""Return a real tensor that is a view of the input complex tensor.

The data type of the input tensor is 'complex64' or 'complex128', and the data
type of the returned tensor is 'float32' or 'float64', respectively.

When the shape of the input tensor is ``(*, )``, (``*`` means arbitrary shape),
the shape of the output tensor is ``(*, 2)``, i.e. the shape of the output is
the shape of the input appended by an extra ``2``.

The real tensor is a view of the input complex tensor, meaning that it shares the same memory with complex tensor.

Args:
input (Tensor): The input tensor. Data type is 'complex64' or 'complex128'.

Returns:
Tensor, The output. Data type is 'float32' or 'float64', sharing the same memory with input.

Examples:
.. code-block:: python

>>> import paddle
>>> x = paddle.arange(12, dtype=paddle.float32).reshape([2, 3, 2])
>>> y = paddle.as_complex(x)
>>> z = paddle.as_real(y)
>>> print(z)
Tensor(shape=[2, 3, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
[[[0. , 1. ],
[2. , 3. ],
[4. , 5. ]],
[[6. , 7. ],
[8. , 9. ],
[10., 11.]]])
"""
return as_real(x=input)


def repeat_interleave(
x: Tensor,
repeats: int | Tensor,
Expand Down
55 changes: 50 additions & 5 deletions test/legacy_test/test_complex_view_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def ref_view_as_real(x):
return np.stack([x.real, x.imag], -1)


class TestViewAsComplexOp(OpTest):
class TestAsComplexOp(OpTest):
def setUp(self):
self.op_type = "as_complex"
self.python_api = paddle.as_complex
Expand All @@ -53,7 +53,7 @@ def test_check_grad(self):
)


class TestViewAsRealOp(OpTest):
class TestAsRealOp(OpTest):
def setUp(self):
self.op_type = "as_real"
real = np.random.randn(10, 10).astype("float64")
Expand All @@ -75,7 +75,7 @@ def test_check_grad(self):
)


class TestViewAsComplexAPI(unittest.TestCase):
class TestAsComplexAPI(unittest.TestCase):
def setUp(self):
self.x = np.random.randn(10, 10, 2)
self.out = ref_view_as_complex(self.x)
Expand All @@ -98,7 +98,7 @@ def test_static(self):
np.testing.assert_allclose(self.out, out_np, rtol=1e-05)


class TestViewAsRealAPI(unittest.TestCase):
class TestAsRealAPI(unittest.TestCase):
def setUp(self):
self.x = np.random.randn(10, 10) + 1j * np.random.randn(10, 10)
self.out = ref_view_as_real(self.x)
Expand All @@ -121,7 +121,7 @@ def test_static(self):
np.testing.assert_allclose(self.out, out_np, rtol=1e-05)


class TestViewAsRealAPI_ZeroSize(unittest.TestCase):
class TestAsRealAPI_ZeroSize(unittest.TestCase):
def setUp(self):
self.x = np.random.randn(10, 0) + 1j * np.random.randn(10, 0)
self.out = ref_view_as_real(self.x)
Expand All @@ -137,5 +137,50 @@ def test_dygraph(self):
np.testing.assert_allclose(x_tensor.grad.shape, x_tensor.shape)


class TestViewAsComplexAPI(unittest.TestCase):
def setUp(self):
self.x = np.random.randn(10, 10, 2)
self.out = ref_view_as_complex(self.x)

def test_dygraph(self):
with dygraph.guard():
x = paddle.to_tensor(self.x)
out = paddle.view_as_complex(x)
out_np = out.numpy()
self.assertEqual(out.data_ptr(), x.data_ptr())
np.testing.assert_allclose(self.out, out_np, rtol=1e-05)


class TestViewAsRealAPI(unittest.TestCase):
def setUp(self):
self.x = np.random.randn(10, 10) + 1j * np.random.randn(10, 10)
self.out = ref_view_as_real(self.x)

def test_dygraph(self):
with dygraph.guard():
x = paddle.to_tensor(self.x)
out = paddle.view_as_real(x)
out_np = out.numpy()
self.assertEqual(out.data_ptr(), x.data_ptr())
np.testing.assert_allclose(self.out, out_np, rtol=1e-05)


class TestViewAsRealAPI_ZeroSize(unittest.TestCase):
def setUp(self):
self.x = np.random.randn(10, 0) + 1j * np.random.randn(10, 0)
self.out = ref_view_as_real(self.x)

def test_dygraph(self):
for place in get_places():
with dygraph.guard(place):
x_tensor = paddle.to_tensor(self.x)
x_tensor.stop_gradient = False
out = paddle.view_as_real(x_tensor)
np.testing.assert_allclose(self.out, out.numpy(), rtol=1e-05)
self.assertEqual(out.data_ptr(), x_tensor.data_ptr())
out.sum().backward()
np.testing.assert_allclose(x_tensor.grad.shape, x_tensor.shape)


if __name__ == "__main__":
unittest.main()