Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@
atleast_1d,
atleast_2d,
atleast_3d,
block_diag,
broadcast_tensors,
broadcast_to,
cast,
Expand Down Expand Up @@ -610,6 +611,7 @@
ir_guard._switch_to_pir()

__all__ = [
'block_diag',
'iinfo',
'finfo',
'dtype',
Expand Down
4 changes: 4 additions & 0 deletions python/paddle/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@
Upsample,
UpsamplingBilinear2D,
UpsamplingNearest2D,
ZeroPad1D,
ZeroPad2D,
ZeroPad3D,
)

# TODO: import all neural network related api under this directory,
Expand Down Expand Up @@ -298,4 +300,6 @@
'Unflatten',
'FractionalMaxPool2D',
'FractionalMaxPool3D',
'ZeroPad1D',
'ZeroPad3D',
]
125 changes: 125 additions & 0 deletions python/paddle/nn/layer/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,67 @@ def extra_repr(self):
return f'padding={self._pad}, mode={self._mode}, value={self._value}, data_format={self._data_format}{name_str}'


class ZeroPad1D(Layer):
"""
This interface is used to construct a callable object of the ``ZeroPad1D`` class.
Pads the input tensor boundaries with zero.

Parameters:
padding (Tensor | List[int] | int): The padding size with data type int. If is int, use the
same padding in all dimensions. Else [len(padding)/2] dimensions of input will be padded.
The pad has the form (pad_left, pad_right).
data_format (str): An string from: "NCL", "NLC". Specify the data format of the input data.
Default is "NCL"
name (str, optional) : The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.

Shape:
- x(Tensor): The input tensor of zeropad1d operator, which is a 3-D tensor.
The data type can be float32, float64.
- output(Tensor): The output tensor of zeropad1d operator, which is a 3-D tensor.
The data type is same as input x.

Examples:

.. code-block:: python

>>> import paddle
>>> import paddle.nn as nn

>>> input_shape = (1, 2, 3)
>>> pad = [1, 2]
>>> data = paddle.arange(paddle.prod(paddle.to_tensor(input_shape)), dtype="float32").reshape(input_shape) + 1
>>> my_pad = nn.ZeroPad1D(padding=pad)
>>> result = my_pad(data)
>>> print(result)
Tensor(shape=[1, 2, 6], dtype=float32, place=Place(cpu), stop_gradient=True,
[[[0., 1., 2., 3., 0., 0.],
[0., 4., 5., 6., 0., 0.]]])
"""

def __init__(self, padding, data_format="NCL", name=None):
super().__init__()
self._pad = _npairs(padding, 1)
self._mode = 'constant'
self._value = 0.0
self._data_format = data_format
self._name = name

def forward(self, x):
return F.pad(
x,
pad=self._pad,
mode=self._mode,
value=self._value,
data_format=self._data_format,
name=self._name,
)

def extra_repr(self):
name_str = f', name={self._name}' if self._name else ''
return f'padding={self._pad}, data_format={self._data_format}{name_str}'


class Pad2D(Layer):
"""
This interface is used to construct a callable object of the ``Pad2D`` class.
Expand Down Expand Up @@ -1290,6 +1351,70 @@ def extra_repr(self):
return f'padding={self._pad}, mode={self._mode}, value={self._value}, data_format={self._data_format}{name_str}'


class ZeroPad3D(Layer):
"""
This interface is used to construct a callable object of the ``ZeroPad3D`` class.
Pads the input tensor boundaries with zero.

Parameters:
padding (Tensor | List[int] | int): The padding size with data type int. If is int, use the
same padding in all dimensions. Else [len(padding)/2] dimensions of input will be padded.
The pad has the form (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back).
data_format (str): An string from: "NCDHW", "NDHWC". Specify the data format of the input data.
Default is "NCDHW"
name (str, optional) : The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.

Shape:
- x(Tensor): The input tensor of zeropad3d operator, which is a 5-D tensor.
The data type can be float32, float64.
- output(Tensor): The output tensor of zeropad3d operator, which is a 5-D tensor.
The data type is same as input x.

Examples:

.. code-block:: python

>>> import paddle
>>> import paddle.nn as nn

>>> input_shape = (1, 1, 1, 2, 3)
>>> pad = [1, 0, 1, 2, 0, 0]
>>> data = paddle.arange(paddle.prod(paddle.to_tensor(input_shape)), dtype="float32").reshape(input_shape) + 1
>>> my_pad = nn.ZeroPad3D(padding=pad)
>>> result = my_pad(data)
>>> print(result)
Tensor(shape=[1, 1, 1, 5, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
[[[[[0., 0., 0., 0.],
[0., 1., 2., 3.],
[0., 4., 5., 6.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]]]]])
"""

def __init__(self, padding, data_format="NCDHW", name=None):
super().__init__()
self._pad = _npairs(padding, 3)
self._mode = 'constant'
self._value = 0.0
self._data_format = data_format
self._name = name

def forward(self, x):
return F.pad(
x,
pad=self._pad,
mode=self._mode,
value=self._value,
data_format=self._data_format,
name=self._name,
)

def extra_repr(self):
name_str = f', name={self._name}' if self._name else ''
return f'padding={self._pad}, data_format={self._data_format}{name_str}'


class CosineSimilarity(Layer):
"""
This interface is used to compute cosine similarity between x1 and x2 along axis.
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@
atleast_1d,
atleast_2d,
atleast_3d,
block_diag,
broadcast_tensors,
broadcast_to,
cast,
Expand Down Expand Up @@ -543,6 +544,7 @@
'hypot_',
'nansum',
'nanmean',
'block_diag',
'count_nonzero',
'tanh',
'tanh_',
Expand Down
64 changes: 64 additions & 0 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6861,3 +6861,67 @@ def slice_scatter(x, value, axes, starts, ends, strides, name=None):
)

return output


def block_diag(inputs, name=None):
"""
Create a block diagonal matrix from provided tensors.

Args:
inputs (list|tuple): ``inputs`` is a Tensor list or Tensor tuple, one or more tensors with 0, 1, or 2 dimensions.
name (str, optional): Name for the operation (optional, default is None).

Returns:
Tensor, A ``Tensor``. The data type is same as ``inputs``.

Examples:
.. code-block:: python

>>> import paddle

>>> A = paddle.to_tensor([[4], [3], [2]])
>>> B = paddle.to_tensor([7, 6, 5])
>>> C = paddle.to_tensor(1)
>>> D = paddle.to_tensor([[5, 4, 3], [2, 1, 0]])
>>> E = paddle.to_tensor([[8, 7], [7, 8]])
>>> out = paddle.block_diag([A, B, C, D, E])
>>> print(out)
Tensor(shape=[9, 10], dtype=int64, place=Place(gpu:0), stop_gradient=True,
[[4, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[3, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[2, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 7, 6, 5, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 5, 4, 3, 0, 0],
[0, 0, 0, 0, 0, 2, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 8, 7],
[0, 0, 0, 0, 0, 0, 0, 0, 7, 8]])
"""

def to_col_block(arys, i, a):
return [
a
if idx == i
else paddle.zeros([ary.shape[0], a.shape[1]], dtype=a.dtype)
for idx, ary in enumerate(arys)
]

def to_2d(ary):
if ary.ndim == 0:
return ary.unsqueeze(axis=0).unsqueeze(axis=0)
if ary.ndim == 1:
return ary.unsqueeze(axis=0)
if ary.ndim == 2:
return ary
raise ValueError(
"For 'block_diag', the dimension of each elements in 'inputs' must be 0, 1, or 2, but got "
f"{ary.ndim}"
)

arys = [to_2d(ary) for ary in inputs]

matrix = [
paddle.concat(to_col_block(arys, idx, ary), axis=0)
for idx, ary in enumerate(arys)
]
return paddle.concat(matrix, axis=1)
90 changes: 90 additions & 0 deletions test/legacy_test/test_ZeroPad1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle import to_tensor
from paddle.nn import ZeroPad1D


class TestZeroPad1dAPI(unittest.TestCase):
def setUp(self):
if paddle.is_compiled_with_cuda():
paddle.device.set_device('gpu:0')
else:
paddle.device.set_device('cpu')
self.shape = [4, 6, 6]
self.support_dtypes = ['float32', 'float64', 'int32', 'int64']

def test_support_dtypes(self):
for dtype in self.support_dtypes:
pad = 2
x = np.random.randint(-255, 255, size=self.shape).astype(dtype)
expect_res = np.pad(
x,
[[0, 0], [0, 0], [pad, pad]],
mode='constant',
constant_values=0,
)

x_tensor = to_tensor(x).astype(dtype)
zeropad1d = ZeroPad1D(padding=pad)
ret_res = zeropad1d(x_tensor).numpy()
np.testing.assert_allclose(expect_res, ret_res, rtol=1e-05)

def test_support_pad2(self):
pad = [1, 2]
x = np.random.randint(-255, 255, size=self.shape)
expect_res = np.pad(
x, [[0, 0], [0, 0], pad], mode='constant', constant_values=0
)

x_tensor = to_tensor(x)
zeropad1d = ZeroPad1D(padding=pad)
ret_res = zeropad1d(x_tensor).numpy()
np.testing.assert_allclose(expect_res, ret_res, rtol=1e-05)

def test_support_pad3(self):
pad = (1, 2)
x = np.random.randint(-255, 255, size=self.shape)
expect_res = np.pad(x, [[0, 0], [0, 0], [pad[0], pad[1]]])

x_tensor = to_tensor(x)
zeropad1d = ZeroPad1D(padding=pad)
ret_res = zeropad1d(x_tensor).numpy()
np.testing.assert_allclose(expect_res, ret_res, rtol=1e-05)

def test_support_pad4(self):
pad = [1, 2]
x = np.random.randint(-255, 255, size=self.shape)
expect_res = np.pad(x, [[0, 0], [0, 0], [pad[0], pad[1]]])

x_tensor = to_tensor(x)
pad_tensor = to_tensor(pad, dtype='int32')
zeropad1d = ZeroPad1D(padding=pad_tensor)
ret_res = zeropad1d(x_tensor).numpy()
np.testing.assert_allclose(expect_res, ret_res, rtol=1e-05)

def test_repr(self):
pad = [1, 2]
zeropad1d = ZeroPad1D(padding=pad)
name_str = zeropad1d.extra_repr()
assert name_str == 'padding=[1, 2], data_format=NCL'


if __name__ == '__main__':
unittest.main()
Loading