diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index df41a00202f2b3..0bf2475e5036d5 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -156,6 +156,7 @@ cholesky, bmm, histogram, + histogramdd, bincount, mv, eigvalsh, @@ -694,6 +695,7 @@ 'rot90', 'bincount', 'histogram', + 'histogramdd', 'multiplex', 'CUDAPlace', 'empty', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index dc69a4e58bb9a8..baa122d5008965 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -79,6 +79,7 @@ eigvals, eigvalsh, histogram, + histogramdd, householder_product, lstsq, lu, @@ -434,6 +435,7 @@ 'cholesky', 'bmm', 'histogram', + 'histogramdd', 'bincount', 'mv', 'matrix_power', diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 212125825d9b13..8625c5ae1ecdef 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -29,6 +29,7 @@ __all__ = [] + # Consistent with kDefaultDim from C++ Backend K_DEFAULT_DIM = 9 @@ -3869,3 +3870,227 @@ def _householder_product(x, tau): ) out = out.reshape(org_x_shape) return out + + +def histogramdd( + x, bins=10, ranges=None, density=False, weights=None, name=None +): + r""" + Computes a multi-dimensional histogram of the values in a tensor. + + Interprets the elements of an input tensor whose innermost dimension has size `N` as a collection of N-dimensional points. Maps each of the points into a set of N-dimensional bins and returns the number of points (or total weight) in each bin. + + input `x` must be a tensor with at least 2 dimensions. If input has shape `(M, N)`, each of its `M` rows defines a point in N-dimensional space. If input has three or more dimensions, all but the last dimension are flattened. + + Each dimension is independently associated with its own strictly increasing sequence of bin edges. Bin edges may be specified explicitly by passing a sequence of 1D tensors. Alternatively, bin edges may be constructed automatically by passing a sequence of integers specifying the number of equal-width bins in each dimension. + + Args: + x (Tensor): The input tensor. + bins (Tensor[], int[], or int): If Tensor[], defines the sequences of bin edges. If int[], defines the number of equal-width bins in each dimension. If int, defines the number of equal-width bins for all dimensions. + ranges (sequence of float, optional): Defines the leftmost and rightmost bin edges in each dimension. If is None, set the minimum and maximum as leftmost and rightmost edges for each dimension. + density (bool, optional): If False (default), the result will contain the count (or total weight) in each bin. If True, each count (weight) is divided by the total count (total weight), then divided by the volume of its associated bin. + weights (Tensor, optional): By default, each value in the input has weight 1. If a weight tensor is passed, each N-dimensional coordinate in input contributes its associated weight towards its bin’s result. The weight tensor should have the same shape as the input tensor excluding its innermost dimension N. + name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + + Returns: + N-dimensional Tensor containing the values of the histogram. ``bin_edges(Tensor[])``, sequence of N 1D Tensors containing the bin edges. + + Examples: + .. code-block:: python + :name: exampl + + >>> import paddle + >>> x = paddle.to_tensor([[0., 1.], [1., 0.], [2.,0.], [2., 2.]]) + >>> bins = [3,3] + >>> weights = paddle.to_tensor([1., 2., 4., 8.]) + >>> paddle.histogramdd(x, bins=bins, weights=weights) + (Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[0., 1., 0.], + [2., 0., 0.], + [4., 0., 8.]]), [Tensor(shape=[4], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [0. , 0.66666669, 1.33333337, 2. ]), Tensor(shape=[4], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [0. , 0.66666669, 1.33333337, 2. ])]) + + .. code-block:: python + :name: examp2 + + >>> import paddle + >>> y = paddle.to_tensor([[0., 0.], [1., 1.], [2., 2.]]) + >>> bins = [2,2] + >>> ranges = [0., 1., 0., 1.] + >>> density = True + >>> paddle.histogramdd(y, bins=bins, ranges=ranges, density=density) + (Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[2., 0.], + [0., 2.]]), [Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [0. , 0.50000000, 1. ]), Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [0. , 0.50000000, 1. ])]) + + + """ + + def __check_x(x): + assert ( + len(x.shape) >= 2 + ), "input x must be a tensor with at least 2 dimensions." + check_variable_and_dtype( + x, + 'x', + [ + 'float32', + 'float64', + ], + 'histogramdd', + ) + + def __check_bins(bins, x): # when Tensor[], check dtype + for bins_tensor in bins: + bins_tensor = paddle.to_tensor(bins_tensor) + check_variable_and_dtype( + bins_tensor, + 'bins', + [ + 'float32', + 'float64', + ], + 'histogramdd', + ) + assert ( + bins_tensor.dtype == x.dtype + ), "When bins is Tensor[], the dtype of bins must be the same as x.\n" + + def __check_weights(x, weights): + if weights is None: + return + x_shape, weights_shape = x.shape, weights.shape + assert len(x_shape) == len(weights_shape) + 1, ( + "if weight tensor is provided," + "it should have the same shape as the input tensor excluding its innermost dimension.\n" + ) + for i, _ in enumerate(weights_shape): + assert weights_shape[i] == x_shape[i], ( + "if weight tensor is provided," + "it should have the same shape as the input tensor excluding its innermost dimension.\n" + ) + check_variable_and_dtype( + weights, + 'weights', + [ + 'float32', + 'float64', + ], + 'histogramdd', + ) + assert ( + weights.dtype == x.dtype + ), "The dtype of weights must be the same as x.\n" + + def __check_ranges(D, ranges): + if ranges is None: + return + check_type(ranges, 'ranges', (list, tuple), 'histogramdd') + assert D * 2 == len( + ranges + ), "The length of ranges list must be %d\n" % (D * 2) + + check_type(density, 'density', bool, 'histogramdd') + + __check_x(x) + # weights + __check_weights(x, weights) + D = x.shape[-1] + reshaped_input = x.reshape([-1, D]) + N = reshaped_input.shape[0] + reshaped_weights = None + if weights is not None: + weights = weights.astype(x.dtype) + reshaped_weights = weights.reshape([N]) + assert reshaped_weights.shape[0] == N, ( + "The size of weight must be %d" % N + ) + # ranges + __check_ranges(D, ranges) + if ranges is None: + ranges = paddle.zeros([D, 2], dtype=x.dtype) + maxv = paddle.max(reshaped_input, axis=0).reshape([-1]) + minv = paddle.min(reshaped_input, axis=0).reshape([-1]) + + if paddle.in_dynamic_mode(): + ranges[:, 0] = minv + ranges[:, 1] = maxv + else: + ranges = paddle.static.setitem(ranges, (slice(None), 0), minv) + ranges = paddle.static.setitem(ranges, (slice(None), 1), maxv) + else: + ranges = paddle.to_tensor(ranges, dtype=x.dtype).reshape([D, 2]) + # bins to edges + edges = [] + hist_shape = [] + dedges = [] + if isinstance(bins, (int, list)): # int or int[] + if isinstance(bins, int): + bins = [bins] * D + assert len(bins) == D, ( + "The length of bins must be %d when bins is a list.\n" % D + ) + for idx, r in enumerate(ranges): + if not isinstance(bins[idx], int): + raise ValueError( + "The type of %d-th element in bins list must be int." % idx + ) + e = paddle.linspace(r[0], r[1], bins[idx] + 1, x.dtype) + edges.append(e) + dedges.append(e.diff()) + elif isinstance( + bins, tuple + ): # tuple with D tensors for each innermost dimension + __check_bins(bins, x) + for bin in bins: + bin = paddle.to_tensor(bin) + edges.append(bin) + dedges.append(bin.diff()) + else: + raise ValueError("Input bins must be Tensor[], int[], or int.") + hist_shape = [edge.shape[0] + 1 for edge in edges] + index_list = [] + # edges shape: [D, linspaced] + # index_list shape: [D, N] + for idx, edge in enumerate(edges): + edge = paddle.to_tensor(edge) + index_list.append( + paddle.searchsorted(edge, reshaped_input[:, idx], right=True) + ) + index_list = paddle.to_tensor(index_list) + for i in range(D): + on_edge = reshaped_input[:, i] == edges[i][-1] + if paddle.in_dynamic_mode(): + index_list[i][on_edge] -= 1 + else: + index_list = paddle.static.setitem( + index_list, (i, on_edge), index_list[i][on_edge] - 1 + ) + index_list = tuple(index_list) + lut = paddle.arange( + paddle.to_tensor(hist_shape).prod(), + ).reshape(hist_shape) + flattened_index = lut[index_list] + hist = paddle.bincount( + flattened_index, + reshaped_weights, + minlength=paddle.to_tensor(hist_shape).prod(), + ) + hist = hist.reshape(hist_shape) + hist = hist.astype(x.dtype) + + core = D * (slice(1, -1),) + hist = hist[core] + + if density: + s = hist.sum() + for i in range(D): + shape = D * [1] + shape[i] = hist_shape[i] - 2 + hist = hist / dedges[i].reshape(shape) + hist /= s + + return (hist, edges) diff --git a/test/legacy_test/test_histogramdd_op.py b/test/legacy_test/test_histogramdd_op.py new file mode 100644 index 00000000000000..482788f61e8ca1 --- /dev/null +++ b/test/legacy_test/test_histogramdd_op.py @@ -0,0 +1,488 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle + + +def ref_histogramdd(x, bins, ranges, weights, density): + D = x.shape[-1] + x = x.reshape(-1, D) + if ranges is not None: + ranges = np.array(ranges, dtype=x.dtype).reshape(D, 2).tolist() + if weights is not None: + weights = weights.reshape(-1) + ref_hist, ref_edges = np.histogramdd(x, bins, ranges, density, weights) + return ref_hist, ref_edges + + +# inputs, bins, ranges, weights, density +class TestHistogramddAPI(unittest.TestCase): + def setUp(self): + self.ranges = None + self.weights = None + self.density = False + + self.init_input() + self.set_expect_output() + self.place = ( + paddle.CUDAPlace(0) + if paddle.is_compiled_with_cuda() + else paddle.CPUPlace() + ) + + def init_input(self): + # self.sample = np.array([[0.0, 1.0], [1.0, 0.0], [2.0, 0.0], [2.0, 2.0]]) + self.sample = np.random.randn( + 4, + 2, + ).astype(np.float64) + self.bins = [3, 3] + self.weights = np.array([1.0, 2.0, 4.0, 8.0], dtype=self.sample.dtype) + + def set_expect_output(self): + self.expect_hist, self.expect_edges = ref_histogramdd( + self.sample, self.bins, self.ranges, self.weights, self.density + ) + + def test_static_api(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data( + 'x', self.sample.shape, dtype=self.sample.dtype + ) + if self.weights is not None: + weights = paddle.static.data( + 'weights', self.weights.shape, dtype=self.weights.dtype + ) + out_0, out_1 = paddle.histogramdd( + x, + bins=self.bins, + weights=weights, + ranges=self.ranges, + density=self.density, + ) + else: + out_0, out_1 = paddle.histogramdd( + x, bins=self.bins, ranges=self.ranges, density=self.density + ) + exe = paddle.static.Executor(self.place) + if self.weights is not None: + res = exe.run( + feed={'x': self.sample, 'weights': self.weights}, + fetch_list=[out_0, out_1], + ) + else: + res = exe.run( + feed={'x': self.sample}, fetch_list=[out_0, out_1] + ) + + hist_out, edges_out = res[0], res[1:] + np.testing.assert_allclose( + hist_out, + self.expect_hist, + ) + for idx, edge_out in enumerate(edges_out): + expect_edge = np.array(self.expect_edges[idx]) + np.testing.assert_allclose( + edge_out, + expect_edge, + ) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + self.sample_dy = paddle.to_tensor(self.sample, dtype=self.sample.dtype) + self.weights_dy = None + if self.weights is not None: + self.weights_dy = paddle.to_tensor(self.weights) + if isinstance(self.bins, tuple): + self.bins = tuple([paddle.to_tensor(bin) for bin in self.bins]) + hist, edges = paddle.histogramdd( + self.sample_dy, + bins=self.bins, + weights=self.weights_dy, + ranges=self.ranges, + density=self.density, + ) + + np.testing.assert_allclose( + hist.numpy(), + self.expect_hist, + ) + for idx, edge in enumerate(edges): + edge = edge.numpy() + expect_edge = np.array(self.expect_edges[idx]) + np.testing.assert_allclose( + edge, + expect_edge, + ) + + paddle.enable_static() + + def test_error(self): + pass + + +class TestHistogramddAPICase1ForDensity(TestHistogramddAPI): + def init_input(self): + # self.sample = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]) + self.sample = np.random.randn(4, 2).astype(np.float64) + self.bins = [2, 2] + self.ranges = [0.0, 1.0, 0.0, 1.0] + self.density = True + + def set_expect_output(self): + self.expect_hist, self.expect_edges = ref_histogramdd( + self.sample, self.bins, self.ranges, self.weights, self.density + ) + + +class TestHistogramddAPICase2ForMultiDimsAndDensity(TestHistogramddAPI): + def init_input(self): + # shape: [4,2,2] + self.sample = np.random.randn(4, 2, 2).astype(np.float64) + self.bins = [3, 4] + self.density = True + + def set_expect_output(self): + self.expect_hist, self.expect_edges = ref_histogramdd( + self.sample, self.bins, self.ranges, self.weights, self.density + ) + + +class TestHistogramddAPICase3ForMultiDimsNotDensity(TestHistogramddAPI): + def init_input(self): + # shape: [4,2,2] + self.sample = np.random.randn(4, 2, 2).astype(np.float64) + self.bins = [3, 4] + # self.density = True + + def set_expect_output(self): + self.expect_hist, self.expect_edges = ref_histogramdd( + self.sample, self.bins, self.ranges, self.weights, self.density + ) + + +class TestHistogramddAPICase4ForRangesAndDensity(TestHistogramddAPI): + def init_input(self): + # shape: [4,2,2] + self.sample = np.random.randn(4, 2, 2).astype(np.float64) + self.bins = [3, 4] + # [leftmost_1, rightmost_1, leftmost_2, rightmost_2,..., leftmost_D, rightmost_D] + self.ranges = [1.0, 10.0, 1.0, 100.0] + self.density = True + + def set_expect_output(self): + self.expect_hist, self.expect_edges = ref_histogramdd( + self.sample, self.bins, self.ranges, self.weights, self.density + ) + + +class TestHistogramddAPICase5ForRangesNotDensity(TestHistogramddAPI): + def init_input(self): + # shape: [4,2,2] + self.sample = np.random.randn(4, 2, 2).astype(np.float64) + self.bins = [3, 4] + # [leftmost_1, rightmost_1, leftmost_2, rightmost_2,..., leftmost_D, rightmost_D] + self.ranges = [1.0, 10.0, 1.0, 100.0] + # self.density = True + + def set_expect_output(self): + self.expect_hist, self.expect_edges = ref_histogramdd( + self.sample, self.bins, self.ranges, self.weights, self.density + ) + + +class TestHistogramddAPICase6NotRangesAndDensityAndWeights(TestHistogramddAPI): + def init_input(self): + # shape: [4,2,2] + self.sample = np.random.randn(4, 2, 2).astype(np.float64) + self.bins = [3, 4] + # [leftmost_1, rightmost_1, leftmost_2, rightmost_2,..., leftmost_D, rightmost_D] + # self.ranges = [1., 10., 1., 100.] + self.density = True + self.weights = np.array( + [ + [1.0, 2.0], + [3.0, 4.0], + [5.0, 6.0], + [7.0, 8.0], + ], + dtype=self.sample.dtype, + ) + + def set_expect_output(self): + self.expect_hist, self.expect_edges = ref_histogramdd( + self.sample, self.bins, self.ranges, self.weights, self.density + ) + + +class TestHistogramddAPICase7ForRangesAndDensityAndWeights(TestHistogramddAPI): + def init_input(self): + # shape: [4,2,2] + self.sample = np.random.randn(4, 2, 2).astype(np.float64) + self.bins = [3, 4] + # [leftmost_1, rightmost_1, leftmost_2, rightmost_2,..., leftmost_D, rightmost_D] + self.ranges = [1.0, 10.0, 1.0, 100.0] + self.density = True + self.weights = np.array( + [ + [1.0, 2.0], + [3.0, 4.0], + [5.0, 6.0], + [7.0, 8.0], + ], + dtype=self.sample.dtype, + ) + + def set_expect_output(self): + self.expect_hist, self.expect_edges = ref_histogramdd( + self.sample, self.bins, self.ranges, self.weights, self.density + ) + + +class TestHistogramddAPICase8MoreInnermostDim(TestHistogramddAPI): + def init_input(self): + # shape: [4,2,4] + self.sample = np.random.randn(4, 2, 4).astype(np.float64) + self.bins = [1, 2, 3, 4] + # [leftmost_1, rightmost_1, leftmost_2, rightmost_2,..., leftmost_D, rightmost_D] + self.density = False + self.weights = np.array( + [ + [1.0, 2.0], + [3.0, 4.0], + [5.0, 6.0], + [7.0, 8.0], + ], + dtype=self.sample.dtype, + ) + + def set_expect_output(self): + self.expect_hist, self.expect_edges = ref_histogramdd( + self.sample, self.bins, self.ranges, self.weights, self.density + ) + + +class TestHistogramddAPICase8MoreInnermostDimAndDensity(TestHistogramddAPI): + def init_input(self): + # shape: [4,2,4] + self.sample = np.random.randn(4, 2, 4).astype(np.float64) + self.bins = [1, 2, 3, 4] + # [leftmost_1, rightmost_1, leftmost_2, rightmost_2,..., leftmost_D, rightmost_D] + self.density = True + self.weights = np.array( + [ + [1.0, 2.0], + [3.0, 4.0], + [5.0, 6.0], + [7.0, 8.0], + ], + dtype=self.sample.dtype, + ) + + def set_expect_output(self): + self.expect_hist, self.expect_edges = ref_histogramdd( + self.sample, self.bins, self.ranges, self.weights, self.density + ) + + +class TestHistogramddAPICase9ForIntBin(TestHistogramddAPI): + def init_input(self): + # shape: [4,2,2] + self.sample = np.random.randn(4, 2, 2).astype(np.float64) + self.weights = np.array( + [ + [1.0, 2.0], + [3.0, 4.0], + [5.0, 6.0], + [7.0, 8.0], + ], + dtype=self.sample.dtype, + ) + self.bins = 5 + self.density = True + self.ranges = [1.0, 10.0, 1.0, 100.0] + + def set_expect_output(self): + self.expect_hist, self.expect_edges = ref_histogramdd( + self.sample, self.bins, self.ranges, self.weights, self.density + ) + + +class TestHistogramddAPICase10ForTensorBin(TestHistogramddAPI): + def init_input(self): + # shape: [4,2,2] + self.sample = np.random.randn(4, 2, 2).astype(np.float64) + self.weights = np.array( + [ + [1.0, 2.0], + [3.0, 4.0], + [5.0, 6.0], + [7.0, 8.0], + ], + dtype=self.sample.dtype, + ) + self.bins = ( + np.array([1.0, 2.0, 10.0, 15.0, 20.0]), + np.array([0.0, 20.0, 100.0]), + ) + self.density = True + + def set_expect_output(self): + self.expect_hist, self.expect_edges = ref_histogramdd( + self.sample, self.bins, self.ranges, self.weights, self.density + ) + + +class TestHistogramddAPICase10ForFloat32(TestHistogramddAPI): + def init_input(self): + # self.sample = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]) + self.sample = np.random.randn(4, 2).astype(np.float32) + self.bins = [2, 2] + self.ranges = [0.0, 1.0, 0.0, 1.0] + self.density = True + + def set_expect_output(self): + self.expect_hist, self.expect_edges = ref_histogramdd( + self.sample, self.bins, self.ranges, self.weights, self.density + ) + + +# histogramdd(sample, bins=10, ranges=None, density=False, weights=None, name=None): +class TestHistogramddAPI_check_sample_type_error(TestHistogramddAPI): + def test_error(self): + sample = paddle.to_tensor([[False, True], [True, False]]) + with self.assertRaises(TypeError): + paddle.histogramdd(sample) + + +class TestHistogramddAPI_check_bins_element_error(TestHistogramddAPI): + def test_error(self): + sample = paddle.to_tensor( + [ + [[1.0, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [7.0, 8.0]], + [[9.0, 10.0], [11.0, 12.0]], + [[13.0, 14.0], [15.0, 16.0]], + ] + ) + bins = [3.4, 4.5] + with self.assertRaises(ValueError): + paddle.histogramdd(sample, bins=bins) + + +class TestHistogramddAPI_check_ranges_type_error(TestHistogramddAPI): + def test_error(self): + sample = paddle.to_tensor( + [ + [[1.0, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [7.0, 8.0]], + [[9.0, 10.0], [11.0, 12.0]], + [[13.0, 14.0], [15.0, 16.0]], + ] + ) + ranges = 10 + with self.assertRaises(TypeError): + paddle.histogramdd(sample, ranges=ranges) + + +class TestHistogramddAPI_check_density_type_error(TestHistogramddAPI): + def test_error(self): + sample = paddle.to_tensor( + [ + [[1.0, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [7.0, 8.0]], + [[9.0, 10.0], [11.0, 12.0]], + [[13.0, 14.0], [15.0, 16.0]], + ] + ) + density = 10 + with self.assertRaises(TypeError): + paddle.histogramdd(sample, density=density) + + +class TestHistogramddAPI_check_weights_type_error(TestHistogramddAPI): + def test_error(self): + sample = paddle.to_tensor( + [ + [[1.0, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [7.0, 8.0]], + [[9.0, 10.0], [11.0, 12.0]], + [[13.0, 14.0], [15.0, 16.0]], + ] + ) + weights = 10 + with self.assertRaises(AttributeError): + paddle.histogramdd(sample, weights=weights) + + +class TestHistogramddAPI_sample_weights_shape_dismatch_error( + TestHistogramddAPI +): + def test_error(self): + sample = paddle.to_tensor( + [ # shape: [4,2] + [[1.0, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [7.0, 8.0]], + [[9.0, 10.0], [11.0, 12.0]], + [[13.0, 14.0], [15.0, 16.0]], + ] + ) + weights = paddle.to_tensor( + [2.0, 3.0, 4.0], dtype=self.sample.dtype + ) # shape: [3,] + with self.assertRaises(AssertionError): + paddle.histogramdd(sample, weights=weights) + + +class TestHistogramddAPI_sample_weights_type_dismatch_error(TestHistogramddAPI): + def test_error(self): + sample = paddle.to_tensor( + [ # float32 + [[1.0, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [7.0, 8.0]], + [[9.0, 10.0], [11.0, 12.0]], + [[13.0, 14.0], [15.0, 16.0]], + ], + dtype=paddle.float32, + ) + weights = paddle.to_tensor( + [2.0, 3.0, 4.0], dtype=paddle.float64 + ) # float64 + with self.assertRaises(AssertionError): + paddle.histogramdd(sample, weights=weights) + + +class TestHistogramddAPI_check_bins_type_error(TestHistogramddAPI): + def test_error(self): + sample = paddle.to_tensor( + [ + [[1.0, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [7.0, 8.0]], + [[9.0, 10.0], [11.0, 12.0]], + [[13.0, 14.0], [15.0, 16.0]], + ] + ) + bins = 2.0 + with self.assertRaises(ValueError): + paddle.histogramdd(sample, bins=bins) + + +if __name__ == '__main__': + paddle.enable_static() + unittest.main()