2 changes: 2 additions & 0 deletions python/paddle/__init__.py
@@ -328,6 +328,7 @@
masked_scatter_,
moveaxis,
put_along_axis,
ravel,
repeat_interleave,
reshape,
reshape_,
@@ -1090,6 +1091,7 @@
'set_rng_state',
'set_printoptions',
'std',
'ravel',
'flatten',
'flatten_',
'asin',
2 changes: 2 additions & 0 deletions python/paddle/tensor/__init__.py
@@ -193,6 +193,7 @@
moveaxis,
put_along_axis,
put_along_axis_,
ravel,
repeat_interleave,
reshape,
reshape_,
@@ -675,6 +676,7 @@
'expand',
'broadcast_to',
'expand_as',
'ravel',
'flatten',
'flatten_',
'gather',
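Both registrations expose the new API: the top-level import makes ``paddle.ravel`` available, and the entry in the tensor method list should bind it as a ``Tensor`` method as well (the docstring's ``x.ravel()`` relies on this). A minimal doctest-style sketch of the two call forms, assuming a dygraph session:

>>> import paddle
>>> x = paddle.zeros([2, 3])
>>> paddle.ravel(x).shape
[6]
>>> x.ravel().shape
[6]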
39 changes: 39 additions & 0 deletions python/paddle/tensor/manipulation.py
@@ -1854,6 +1854,45 @@ def rot90(
return flip(transpose(x, axes_list), axes[1])


def ravel(x: Tensor) -> Tensor:
"""
Flattens the input tensor across all axes into a 1-D tensor.

Note:
The output Tensor shares data with the input Tensor, and no copy is made in ``dygraph`` mode.
If you want a copy instead, use ``Tensor.clone``, e.g. ``ravel_clone_x = x.ravel().clone()``.

Args:
x (Tensor): A tensor with data type float16, float32, float64, int8, int32, int64, uint8.

Returns:
Tensor, A 1-D tensor with the same contents as the input tensor, all axes flattened into one, and the same data type as the input :attr:`x`.

Examples:

.. code-block:: python

>>> import paddle

>>> image_shape = (2, 3, 4, 4)

>>> x = paddle.arange(end=image_shape[0] * image_shape[1] * image_shape[2] * image_shape[3])
>>> img = paddle.reshape(x, image_shape)

>>> out = paddle.ravel(img)
>>> print(out.shape)
[96]

>>> # out shares data with img in dygraph mode
>>> img[0, 0, 0, 0] = -1
>>> print(out[0])
Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True,
-1)
"""

# ravel is equivalent to flatten with its default arguments
# (start_axis=0, stop_axis=-1), which collapse all axes into one.
return flatten(x)


def flatten(
x: Tensor, start_axis: int = 0, stop_axis: int = -1, name: str | None = None
) -> Tensor:
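Because ``ravel`` returns a view of its input in dygraph mode, the copy pattern recommended in the docstring note looks like this (a minimal sketch with illustrative values, CPU assumed):

>>> import paddle
>>> x = paddle.arange(6)
>>> y = x.ravel().clone()  # clone() makes y independent of x's storage
>>> x[0] = -1
>>> print(y.numpy())  # the in-place write to x does not affect y
[0 1 2 3 4 5]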
237 changes: 237 additions & 0 deletions test/legacy_test/test_ravel_op.py
@@ -0,0 +1,237 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from op_test import OpTest, convert_float_to_uint16

import paddle
from paddle.base import core


class TestRavelOp(OpTest):
def setUp(self):
self.python_api = paddle.ravel
self.public_python_api = paddle.ravel
self.python_out_sig = ["Out"]
self.op_type = "flatten_contiguous_range"
self.prim_op_type = "comp"
self.start_axis = 0
self.stop_axis = -1
self.if_enable_cinn()
self.init_test_case()
self.init_test_dtype()
self.init_input_data()
self.init_attrs()
self.outputs = {
"Out": self.inputs["X"].reshape(self.new_shape),
"XShape": np.random.random(self.in_shape).astype("float32"),
}

def if_enable_cinn(self):
pass

def test_check_output(self):
if str(self.dtype) in {"float16", "uint16"}:
self.check_output_with_place(
core.CUDAPlace(0),
no_check_set=["XShape"],
check_prim=True,
check_pir=True,
check_prim_pir=True,
)
else:
self.check_output(
no_check_set=["XShape"],
check_prim=True,
check_pir=True,
check_prim_pir=True,
)

def test_check_grad(self):
if str(self.dtype) in {"float16", "uint16"}:
self.check_grad_with_place(
core.CUDAPlace(0),
["X"],
"Out",
check_prim=True,
check_pir=True,
)
else:
self.check_grad(["X"], "Out", check_prim=True, check_pir=True)

def init_test_case(self):
self.in_shape = (3, 2, 5, 4)
self.start_axis = 0
self.stop_axis = -1
self.new_shape = 120

def init_attrs(self):
self.attrs = {
"start_axis": self.start_axis,
"stop_axis": self.stop_axis,
}

def init_test_dtype(self):
self.dtype = "float64"

def init_input_data(self):
if str(self.dtype) != "uint16":
x = np.random.random(self.in_shape).astype(self.dtype)
else:
x = np.random.random(self.in_shape).astype("float32")
x = convert_float_to_uint16(x)

self.inputs = {"X": x}


class TestRavelFP32Op(TestRavelOp):
def init_test_dtype(self):
self.dtype = "float32"


@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not compiled with CUDA",
)
class TestRavelFP16Op(TestRavelOp):
def init_test_dtype(self):
self.dtype = "float16"


@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and not support the bfloat16",
)
class TestRavelBF16Op(TestRavelOp):
def if_enable_cinn(self):
pass

def init_test_dtype(self):
self.dtype = "uint16"


class TestRavelOp_ZeroDim(TestRavelOp):
def init_test_case(self):
self.in_shape = ()
self.start_axis = 0
self.stop_axis = -1
self.new_shape = (1,)

def if_enable_cinn(self):
self.enable_cinn = False

def init_attrs(self):
self.attrs = {
"start_axis": self.start_axis,
"stop_axis": self.stop_axis,
}


class TestRavelFP32Op_ZeroDim(TestRavelOp_ZeroDim):
def init_test_dtype(self):
self.dtype = "float32"


@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not compiled with CUDA",
)
class TestRavelFP16Op_ZeroDim(TestRavelOp_ZeroDim):
def init_test_dtype(self):
self.dtype = "float16"


class TestRavelOpError(unittest.TestCase):
def test_errors(self):
image_shape = (2, 3, 4, 4)
x = (
np.arange(
image_shape[0]
* image_shape[1]
* image_shape[2]
* image_shape[3]
).reshape(image_shape)
/ 100.0
)
x = x.astype('float32')

def test_InputError():
# a numpy.ndarray input (not a paddle Tensor) should raise ValueError
paddle.ravel(x)

self.assertRaises(ValueError, test_InputError)


class TestStaticRavelPythonAPI(unittest.TestCase):
def execute_api(self, x):
return paddle.ravel(x)

def test_static_api(self):
paddle.enable_static()
np_x = np.random.rand(2, 3, 4, 4).astype('float32')

main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, paddle.static.Program()):
x = paddle.static.data(
name="x", shape=[2, 3, 4, 4], dtype='float32'
)
out = self.execute_api(x)

exe = paddle.static.Executor(place=paddle.CPUPlace())
fetch_out = exe.run(main_prog, feed={"x": np_x}, fetch_list=[out])
self.assertEqual((96,), fetch_out[0].shape)


class TestStaticRavelInferShapePythonAPI(unittest.TestCase):
def execute_api(self, x):
return paddle.ravel(x)

def test_static_api(self):
paddle.enable_static()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, paddle.static.Program()):
x = paddle.static.data(
name="x", shape=[-1, 3, -1, -1], dtype='float32'
)
out = self.execute_api(x)
self.assertEqual((-1,), tuple(out.shape))


class TestRavelZeroSizedTensorAPI(unittest.TestCase):
def test_dygraph(self):
paddle.disable_static()
data = np.random.randn(2, 3, 0)
x = paddle.to_tensor(data)
out = paddle.ravel(x)
out_np = data.flatten()
np.testing.assert_equal(out.numpy(), out_np)

def test_static(self):
paddle.enable_static()
data = np.random.randn(2, 3, 0)
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, paddle.static.Program()):
x = paddle.static.data(name="x", shape=[2, 3, 0], dtype='float64')
out = paddle.ravel(x)

exe = paddle.static.Executor(place=paddle.CPUPlace())
fetch_out = exe.run(main_prog, feed={"x": data}, fetch_list=[out])[0]
out_np = data.flatten()
np.testing.assert_equal(fetch_out, out_np)


if __name__ == "__main__":
unittest.main()
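As a quick manual check mirroring ``TestRavelZeroSizedTensorAPI``, a dygraph sketch (assuming a CPU build of Paddle):

>>> import numpy as np
>>> import paddle
>>> out = paddle.ravel(paddle.to_tensor(np.random.randn(2, 3, 0)))
>>> print(out.shape)
[0]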