# Layers added to python/paddle/nn/layer/common.py


class ZeroPad1D(Layer):
    """
    Callable layer that pads the boundaries of a 3-D input tensor with zeros.

    Parameters:
        padding (Tensor | List[int] | int): The padding size with data type int.
            A single int is used as the padding on every padded side; otherwise
            the value has the form (pad_left, pad_right).
        data_format (str): A string from: "NCL", "NLC". Specifies the data format
            of the input data. Default is "NCL".
        name (str, optional): The default value is None. Normally there is no need
            for user to set this property. For more information, please refer to
            :ref:`api_guide_Name`.

    Shape:
        - x(Tensor): The input tensor of zeropad1d operator, which is a 3-D tensor.
          The data type can be float32, float64.
        - output(Tensor): The output tensor of zeropad1d operator, which is a 3-D
          tensor. The data type is same as input x.

    Examples:

        .. code-block:: python

            >>> import paddle
            >>> import paddle.nn as nn

            >>> input_shape = (1, 2, 3)
            >>> pad = [1, 2]
            >>> data = paddle.arange(paddle.prod(paddle.to_tensor(input_shape)), dtype="float32").reshape(input_shape) + 1
            >>> my_pad = nn.ZeroPad1D(padding=pad)
            >>> result = my_pad(data)
            >>> print(result)
            Tensor(shape=[1, 2, 6], dtype=float32, place=Place(cpu), stop_gradient=True,
            [[[0., 1., 2., 3., 0., 0.],
              [0., 4., 5., 6., 0., 0.]]])
    """

    def __init__(self, padding, data_format="NCL", name=None):
        super().__init__()
        # One (before, after) pair: only a single axis is padded.
        self._pad = _npairs(padding, 1)
        # Zero padding is constant-mode padding with a fill value of 0.
        self._mode, self._value = 'constant', 0.0
        self._data_format = data_format
        self._name = name

    def forward(self, x):
        pad_options = {
            'pad': self._pad,
            'mode': self._mode,
            'value': self._value,
            'data_format': self._data_format,
            'name': self._name,
        }
        return F.pad(x, **pad_options)

    def extra_repr(self):
        # The name is reported only when it is set to a truthy value.
        parts = [f'padding={self._pad}', f'data_format={self._data_format}']
        if self._name:
            parts.append(f'name={self._name}')
        return ', '.join(parts)


class ZeroPad3D(Layer):
    """
    Callable layer that pads the boundaries of a 5-D input tensor with zeros.

    Parameters:
        padding (Tensor | List[int] | int): The padding size with data type int.
            A single int is used as the padding on every padded side; otherwise
            the value has the form
            (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back).
        data_format (str): A string from: "NCDHW", "NDHWC". Specifies the data
            format of the input data. Default is "NCDHW".
        name (str, optional): The default value is None. Normally there is no need
            for user to set this property. For more information, please refer to
            :ref:`api_guide_Name`.

    Shape:
        - x(Tensor): The input tensor of zeropad3d operator, which is a 5-D tensor.
          The data type can be float32, float64.
        - output(Tensor): The output tensor of zeropad3d operator, which is a 5-D
          tensor. The data type is same as input x.

    Examples:

        .. code-block:: python

            >>> import paddle
            >>> import paddle.nn as nn

            >>> input_shape = (1, 1, 1, 2, 3)
            >>> pad = [1, 0, 1, 2, 0, 0]
            >>> data = paddle.arange(paddle.prod(paddle.to_tensor(input_shape)), dtype="float32").reshape(input_shape) + 1
            >>> my_pad = nn.ZeroPad3D(padding=pad)
            >>> result = my_pad(data)
            >>> print(result)
            Tensor(shape=[1, 1, 1, 5, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
            [[[[[0., 0., 0., 0.],
                [0., 1., 2., 3.],
                [0., 4., 5., 6.],
                [0., 0., 0., 0.],
                [0., 0., 0., 0.]]]]])
    """

    def __init__(self, padding, data_format="NCDHW", name=None):
        super().__init__()
        # Three (before, after) pairs: one per padded axis.
        self._pad = _npairs(padding, 3)
        # Zero padding is constant-mode padding with a fill value of 0.
        self._mode, self._value = 'constant', 0.0
        self._data_format = data_format
        self._name = name

    def forward(self, x):
        pad_options = {
            'pad': self._pad,
            'mode': self._mode,
            'value': self._value,
            'data_format': self._data_format,
            'name': self._name,
        }
        return F.pad(x, **pad_options)

    def extra_repr(self):
        # The name is reported only when it is set to a truthy value.
        parts = [f'padding={self._pad}', f'data_format={self._data_format}']
        if self._name:
            parts.append(f'name={self._name}')
        return ', '.join(parts)
# Function added to python/paddle/tensor/manipulation.py


def block_diag(inputs, name=None):
    """
    Create a block diagonal matrix from provided tensors.

    Each element of ``inputs`` is first promoted to a 2-D block: a 0-D tensor is
    unsqueezed twice and a 1-D tensor is unsqueezed once (becoming a single row).
    The blocks are then placed along the diagonal of the result in the order
    given, and every other position is filled with zeros.

    Args:
        inputs (list|tuple): ``inputs`` is a Tensor list or Tensor tuple, one or more tensors with 0, 1, or 2 dimensions.
        name (str, optional): Name for the operation (optional, default is None).
            It is accepted for API consistency; this implementation does not read it.

    Returns:
        Tensor, A ``Tensor``. The data type is same as ``inputs``.

    Raises:
        ValueError: If an element of ``inputs`` reports an ``ndim`` other than
            0, 1, or 2.

    Examples:
        .. code-block:: python

            >>> import paddle

            >>> A = paddle.to_tensor([[4], [3], [2]])
            >>> B = paddle.to_tensor([7, 6, 5])
            >>> C = paddle.to_tensor(1)
            >>> D = paddle.to_tensor([[5, 4, 3], [2, 1, 0]])
            >>> E = paddle.to_tensor([[8, 7], [7, 8]])
            >>> out = paddle.block_diag([A, B, C, D, E])
            >>> print(out)
            Tensor(shape=[9, 10], dtype=int64, place=Place(cpu), stop_gradient=True,
            [[4, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [3, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [2, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 7, 6, 5, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 5, 4, 3, 0, 0],
             [0, 0, 0, 0, 0, 2, 1, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 8, 7],
             [0, 0, 0, 0, 0, 0, 0, 0, 7, 8]])
    """

    def to_col_block(arys, i, a):
        # Pieces of the column strip that holds block ``a``: ``a`` itself at
        # position ``i``, and for every other block a zero tensor with that
        # block's row count and ``a``'s column count (and ``a``'s dtype).
        return [
            a
            if idx == i
            else paddle.zeros([ary.shape[0], a.shape[1]], dtype=a.dtype)
            for idx, ary in enumerate(arys)
        ]

    def to_2d(ary):
        # Promote 0-D and 1-D inputs to 2-D; reject anything else.
        if ary.ndim == 0:
            return ary.unsqueeze(axis=0).unsqueeze(axis=0)
        if ary.ndim == 1:
            return ary.unsqueeze(axis=0)
        if ary.ndim == 2:
            return ary
        raise ValueError(
            "For 'block_diag', the dimension of each elements in 'inputs' must be 0, 1, or 2, but got "
            f"{ary.ndim}"
        )

    arys = [to_2d(ary) for ary in inputs]

    # Build one full-height column strip per block, then join the strips
    # side by side.
    matrix = [
        paddle.concat(to_col_block(arys, idx, ary), axis=0)
        for idx, ary in enumerate(arys)
    ]
    return paddle.concat(matrix, axis=1)


# Test case added in test/legacy_test/test_ZeroPad1d.py


class TestZeroPad1dAPI(unittest.TestCase):
    """Check ``ZeroPad1D`` against ``numpy.pad`` on 3-D inputs."""

    def setUp(self):
        # Run on the first GPU when paddle is built with CUDA, else on CPU.
        if paddle.is_compiled_with_cuda():
            paddle.device.set_device('gpu:0')
        else:
            paddle.device.set_device('cpu')
        self.shape = [4, 6, 6]
        self.support_dtypes = ['float32', 'float64', 'int32', 'int64']

    def test_support_dtypes(self):
        """An int padding pads both sides of the last axis, for each dtype."""
        for dtype in self.support_dtypes:
            # subTest reports which dtype failed instead of stopping silently
            # at the first one.
            with self.subTest(dtype=dtype):
                pad = 2
                x = np.random.randint(-255, 255, size=self.shape).astype(dtype)
                expect_res = np.pad(
                    x,
                    [[0, 0], [0, 0], [pad, pad]],
                    mode='constant',
                    constant_values=0,
                )

                x_tensor = to_tensor(x).astype(dtype)
                zeropad1d = ZeroPad1D(padding=pad)
                ret_res = zeropad1d(x_tensor).numpy()
                np.testing.assert_allclose(expect_res, ret_res, rtol=1e-05)

    def test_support_pad2(self):
        """Padding given as a list ``[pad_left, pad_right]``."""
        pad = [1, 2]
        x = np.random.randint(-255, 255, size=self.shape)
        expect_res = np.pad(
            x, [[0, 0], [0, 0], pad], mode='constant', constant_values=0
        )

        x_tensor = to_tensor(x)
        zeropad1d = ZeroPad1D(padding=pad)
        ret_res = zeropad1d(x_tensor).numpy()
        np.testing.assert_allclose(expect_res, ret_res, rtol=1e-05)

    def test_support_pad3(self):
        """Padding given as a tuple ``(pad_left, pad_right)``."""
        pad = (1, 2)
        x = np.random.randint(-255, 255, size=self.shape)
        expect_res = np.pad(x, [[0, 0], [0, 0], [pad[0], pad[1]]])

        x_tensor = to_tensor(x)
        zeropad1d = ZeroPad1D(padding=pad)
        ret_res = zeropad1d(x_tensor).numpy()
        np.testing.assert_allclose(expect_res, ret_res, rtol=1e-05)

    def test_support_pad4(self):
        """Padding given as an int32 Tensor."""
        pad = [1, 2]
        x = np.random.randint(-255, 255, size=self.shape)
        expect_res = np.pad(x, [[0, 0], [0, 0], [pad[0], pad[1]]])

        x_tensor = to_tensor(x)
        pad_tensor = to_tensor(pad, dtype='int32')
        zeropad1d = ZeroPad1D(padding=pad_tensor)
        ret_res = zeropad1d(x_tensor).numpy()
        np.testing.assert_allclose(expect_res, ret_res, rtol=1e-05)

    def test_repr(self):
        """``extra_repr`` reports the padding and the default data format."""
        pad = [1, 2]
        zeropad1d = ZeroPad1D(padding=pad)
        name_str = zeropad1d.extra_repr()
        # assertEqual (not a bare assert) still runs under ``python -O`` and
        # prints both values on failure.
        self.assertEqual(name_str, 'padding=[1, 2], data_format=NCL')
# Test case added in test/legacy_test/test_ZeroPad3d.py


class TestZeroPad3DAPI(unittest.TestCase):
    """Check ``ZeroPad3D`` against ``numpy.pad`` on 5-D inputs."""

    def setUp(self):
        # Run on the first GPU when paddle is built with CUDA, else on CPU.
        if paddle.is_compiled_with_cuda():
            paddle.device.set_device('gpu:0')
        else:
            paddle.device.set_device('cpu')
        self.shape = [4, 3, 6, 6, 6]
        self.support_dtypes = ['float32', 'float64', 'int32', 'int64']

    @staticmethod
    def _np_pad_width(pad):
        """Convert a 6-element pad spec into ``numpy.pad`` widths for a 5-D array.

        The first two axes are left unpadded. ``pad[0:2]`` is applied to the
        last axis, ``pad[2:4]`` to the one before it and ``pad[4:6]`` to the
        third axis.
        """
        return [
            [0, 0],
            [0, 0],
            [pad[4], pad[5]],
            [pad[2], pad[3]],
            [pad[0], pad[1]],
        ]

    def test_support_dtypes(self):
        """An int padding pads both sides of the last three axes, per dtype."""
        for dtype in self.support_dtypes:
            # subTest reports which dtype failed instead of stopping silently
            # at the first one.
            with self.subTest(dtype=dtype):
                pad = 2
                x = np.random.randint(-255, 255, size=self.shape).astype(dtype)
                expect_res = np.pad(
                    x,
                    [[0, 0], [0, 0], [pad, pad], [pad, pad], [pad, pad]],
                    mode='constant',
                    constant_values=0,
                )

                x_tensor = to_tensor(x).astype(dtype)
                zeropad3d = ZeroPad3D(padding=pad)
                ret_res = zeropad3d(x_tensor).numpy()
                np.testing.assert_allclose(expect_res, ret_res, rtol=1e-05)

    def test_support_pad2(self):
        """Padding given as a 6-element list."""
        pad = [1, 2, 3, 4, 5, 6]
        x = np.random.randint(-255, 255, size=self.shape)
        expect_res = np.pad(
            x,
            self._np_pad_width(pad),
            mode='constant',
            constant_values=0,
        )

        x_tensor = to_tensor(x)
        zeropad3d = ZeroPad3D(padding=pad)
        ret_res = zeropad3d(x_tensor).numpy()
        np.testing.assert_allclose(expect_res, ret_res, rtol=1e-05)

    def test_support_pad3(self):
        """Padding given as a 6-element tuple."""
        pad = (1, 2, 3, 4, 5, 6)
        x = np.random.randint(-255, 255, size=self.shape)
        expect_res = np.pad(x, self._np_pad_width(pad))

        x_tensor = to_tensor(x)
        zeropad3d = ZeroPad3D(padding=pad)
        ret_res = zeropad3d(x_tensor).numpy()
        np.testing.assert_allclose(expect_res, ret_res, rtol=1e-05)

    def test_support_pad4(self):
        """Padding given as an int32 Tensor."""
        pad = [1, 2, 3, 4, 5, 6]
        x = np.random.randint(-255, 255, size=self.shape)
        expect_res = np.pad(x, self._np_pad_width(pad))

        x_tensor = to_tensor(x)
        pad_tensor = to_tensor(pad, dtype='int32')
        zeropad3d = ZeroPad3D(padding=pad_tensor)
        ret_res = zeropad3d(x_tensor).numpy()
        np.testing.assert_allclose(expect_res, ret_res, rtol=1e-05)

    def test_repr(self):
        """``extra_repr`` reports the padding and the default data format."""
        pad = [1, 2, 3, 4, 5, 6]
        zeropad3d = ZeroPad3D(padding=pad)
        name_str = zeropad3d.extra_repr()
        # assertEqual (not a bare assert) still runs under ``python -O`` and
        # prints both values on failure.
        self.assertEqual(
            name_str, 'padding=[1, 2, 3, 4, 5, 6], data_format=NCDHW'
        )
# Test cases added in test/legacy_test/test_block_diag.py


class TestBlockDiagError(unittest.TestCase):
    """Invalid inputs to ``paddle.block_diag`` must raise."""

    def test_errors(self):
        # NumPy arrays instead of Tensors are expected to raise TypeError.
        arrays = [
            np.array([[1, 2], [3, 4]]),
            np.array([[5, 6], [7, 8]]),
            np.array([[9, 10], [11, 12]]),
        ]
        with self.assertRaises(TypeError):
            with paddle.static.program_guard(base.Program()):
                paddle.block_diag(arrays)

        # Inputs with more than 2 dimensions are expected to raise ValueError.
        tensors = [
            paddle.to_tensor([[[1, 2], [3, 4]]]),
            paddle.to_tensor([[[5, 6], [7, 8]]]),
            paddle.to_tensor([[[9, 10], [11, 12]]]),
        ]
        with self.assertRaises(ValueError):
            with paddle.static.program_guard(base.Program()):
                paddle.block_diag(tensors)


class TestBlockDiag(unittest.TestCase):
    """Compare ``paddle.block_diag`` with ``scipy.linalg.block_diag``."""

    def setUp(self):
        paddle.seed(2024)
        # The inputs below come from NumPy's global RNG, which paddle.seed
        # does not touch; seed it too so the test data is reproducible.
        np.random.seed(2024)
        self.type_list = ['int32', 'int64', 'float32', 'float64']
        self.place = [('cpu', paddle.CPUPlace())] + (
            [('gpu', paddle.CUDAPlace(0))]
            if paddle.is_compiled_with_cuda()
            else []
        )

    @staticmethod
    def _make_case(dtype):
        """Return random ``(A, B, C)`` inputs and the SciPy reference result."""
        A = np.random.randn(2, 3).astype(dtype)
        B = np.random.randn(2).astype(dtype)
        C = np.random.randn(4, 1).astype(dtype)
        return A, B, C, scipy.linalg.block_diag(A, B, C)

    def test_dygraph(self):
        paddle.disable_static()
        for device, _ in self.place:
            paddle.set_device(device)
            for dtype in self.type_list:
                A, B, C, s_out = self._make_case(dtype)

                A_tensor = paddle.to_tensor(A)
                B_tensor = paddle.to_tensor(B)
                C_tensor = paddle.to_tensor(C)
                out = paddle.block_diag([A_tensor, B_tensor, C_tensor])
                np.testing.assert_allclose(out.numpy(), s_out)

    def test_static(self):
        paddle.enable_static()
        for device, place in self.place:
            paddle.set_device(device)
            for dtype in self.type_list:
                A, B, C, s_out = self._make_case(dtype)

                with paddle.static.program_guard(paddle.static.Program()):
                    A_tensor = paddle.static.data('A', [2, 3], dtype)
                    B_tensor = paddle.static.data('B', [2], dtype)
                    C_tensor = paddle.static.data('C', [4, 1], dtype)
                    out = paddle.block_diag([A_tensor, B_tensor, C_tensor])
                    exe = paddle.static.Executor(place)
                    res = exe.run(
                        feed={'A': A, 'B': B, 'C': C},
                        fetch_list=[out],
                    )
                    np.testing.assert_allclose(res[0], s_out)