# This span comes from a patch touching several files. The package-level
# re-exports that accompany the function below are recorded here:
#
#   python/paddle/__init__.py
#       from .nn.functional.distance import (  # noqa: F401
#           pdist,
#       )
#       ... and 'pdist' is added to the list of exported names, between
#       'cdist' and 'unbind' (the list itself is outside this view).
#
#   python/paddle/nn/functional/__init__.py
#       from .distance import pairwise_distance, pdist
#
# python/paddle/nn/functional/distance.py -- appended after pairwise_distance:


def pdist(x, p=2.0, name=None):
    r'''
    Computes the p-norm distance between every pair of row vectors in the input.

    Only pairs :math:`(i, j)` with :math:`i < j` are kept, so each unordered pair
    of rows contributes exactly one entry and the distance of a row to itself is
    left out.

    Args:
        x (Tensor): The input tensor with shape :math:`N \times M`.
        p (float, optional): The value for the p-norm distance to calculate between each vector pair. Default: :math:`2.0`.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        Tensor with shape :math:`N(N-1)/2` , the dtype is same as input tensor.

    Raises:
        AssertionError: If ``x`` is not 2-dimensional.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> paddle.seed(2023)
            >>> a = paddle.randn([4, 5])
            >>> print(a)
            Tensor(shape=[4, 5], dtype=float32, place=Place(cpu), stop_gradient=True,
                   [[ 0.06132207,  1.11349595,  0.41906244, -0.24858207, -1.85169315],
                    [-1.50370061,  1.73954511,  0.13331604,  1.66359663, -0.55764782],
                    [-0.59911072, -0.57773495, -1.03176904, -0.33741450, -0.29695082],
                    [-1.50258386,  0.67233968, -1.07747352,  0.80170447, -0.06695852]])
            >>> pdist_out = paddle.pdist(a)
            >>> print(pdist_out)
            Tensor(shape=[6], dtype=float32, place=Place(cpu), stop_gradient=True,
                   [2.87295413, 2.79758120, 3.02793980, 3.40844536, 1.89435327, 1.93171620])
    '''

    assert len(list(x.shape)) == 2, "The x must be 2-dimensional"

    # For a 2-D x of shape [N, M], x[..., None, :] is [N, 1, M] and
    # x[..., None, :, :] is [1, N, M]; subtracting broadcasts them to
    # [N, N, M], i.e. the difference of every ordered pair of rows.
    row_diffs = x[..., None, :] - x[..., None, :, :]
    full_dist = paddle.linalg.norm(row_diffs, p=p, axis=-1)

    # tril marks the diagonal and everything below it; negating the mask
    # leaves only the entries with row index < column index.
    on_or_below_diag = paddle.tril(paddle.ones(full_dist.shape, dtype='bool'))
    return paddle.masked_select(full_dist, ~on_or_below_diag)


# test/legacy_test/test_pdist.py (new file) -- license header:
#
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np

import paddle


def ref_pdist(x, p=2.0):
    """NumPy reference for ``paddle.pdist``.

    Args:
        x (numpy.ndarray): 2-D array whose rows are the vectors to compare.
        p (float): Order of the norm, forwarded to ``np.linalg.norm`` as ``ord``.

    Returns:
        numpy.ndarray: 1-D array holding the distance of every row pair
        ``(i, j)`` with ``i < j``, in row-major order.
    """
    dist = np.linalg.norm(x[..., None, :] - x[None, :, :], ord=p, axis=-1)
    rows, cols = dist.shape
    # Strict upper triangle (k=1 skips the diagonal); the indices come back in
    # row-major order, matching a nested ``for i: for j: if i < j`` walk.
    return dist[np.triu_indices(rows, k=1, m=cols)]


class TestPdistAPI(unittest.TestCase):
    """Checks ``paddle.pdist`` against ``ref_pdist`` in static and dygraph mode.

    Subclasses override ``init_input`` to change the input array or ``p``.
    """

    def setUp(self):
        self.x = np.random.rand(10, 20).astype('float32')
        self.p = 2.0
        # Let the subclass replace the defaults before the place is picked.
        self.init_input()
        self.place = (
            paddle.CUDAPlace(0)
            if paddle.is_compiled_with_cuda()
            else paddle.CPUPlace()
        )

    def init_input(self):
        """Hook for subclasses; the base case keeps the defaults from setUp."""

    def test_static_api(self):
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.static.data('x', self.x.shape, dtype=self.x.dtype)
            out = paddle.pdist(
                x,
                self.p,
            )
            exe = paddle.static.Executor(self.place)
            res = exe.run(feed={'x': self.x}, fetch_list=[out])
        out_ref = ref_pdist(self.x, self.p)
        np.testing.assert_allclose(out_ref, res[0], rtol=1e-5, atol=1e-5)

    def test_dygraph_api(self):
        paddle.disable_static(self.place)
        try:
            x = paddle.to_tensor(self.x)
            out = paddle.pdist(
                x,
                self.p,
            )
            out_ref = ref_pdist(self.x, self.p)
            np.testing.assert_allclose(
                out_ref, out.numpy(), rtol=1e-5, atol=1e-5
            )
        finally:
            # Restore static mode even when the comparison fails, so a failure
            # here cannot leave later tests running in dygraph mode.
            paddle.enable_static()


class TestPdistAPICase1_param_p1(TestPdistAPI):
    def init_input(self):
        self.p = 0


class TestPdistAPICase2_param_p2(TestPdistAPI):
    def init_input(self):
        self.p = 1.0


class TestPdistAPICase3_param_p3(TestPdistAPI):
    def init_input(self):
        self.p = 3.0


class TestPdistAPICase4_param_p4(TestPdistAPI):
    def init_input(self):
        self.p = 1.5


class TestPdistAPICase5_param_p5(TestPdistAPI):
    def init_input(self):
        self.p = 2.5


class TestPdistAPICase6_param_p6(TestPdistAPI):
    def init_input(self):
        self.p = float('inf')


class TestPdistAPICase7_input_x1(TestPdistAPI):
    def init_input(self):
        # Larger input in double precision; p stays at the setUp default.
        self.x = np.random.rand(50, 20).astype('float64')
class TestPdistShapeError(unittest.TestCase):
    """``paddle.pdist`` must reject input that is not 2-dimensional."""

    def test_error(self):
        # Build the 3-D input outside the assertRaises context so that only
        # the pdist call itself can satisfy the expected AssertionError.
        x = paddle.to_tensor(np.random.rand(50, 10, 20).astype('float64'))
        with self.assertRaises(AssertionError):
            paddle.pdist(
                x,
                2.0,
            )


if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()