From a5c540ec17e7aaae2c4d2e5d7ae65d262b714ddb Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Mon, 20 May 2024 11:42:48 +0800 Subject: [PATCH 01/11] fix median min dtype --- python/paddle/tensor/stat.py | 50 ++++++++++++++++++++++++++---- test/legacy_test/test_median.py | 54 +++++++++++++++++++++++++++------ 2 files changed, 89 insertions(+), 15 deletions(-) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 545b6bf21ca9e4..0a8365c3ec12e2 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -391,7 +391,7 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): Compute the median along the specified axis. Args: - x (Tensor): The input Tensor, it's data type can be bool, float16, float32, float64, int32, int64. + x (Tensor): The input Tensor, it's data type can be float32, float64, int32, int64. axis (int, optional): The axis along which to perform median calculations ``axis`` should be int. ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` . If ``axis`` is less than 0, it works the same way as :math:`axis + D`. @@ -463,6 +463,28 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): >>> print(median_indices) Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True, [1, 1, 1]) + + >>> # cases containing nan values + >>> x = paddle.to_tensor(np.array([[1,2,3,float('nan')],[1,2,3,4],[float('nan'),1,2,3]]) + + >>> y6 = paddle.median(x, axis=-1, keepdim=True) + >>> print(76) + Tensor(shape=[3, 1], dtype=float64, place=Place(cpu), stop_gradient=True, + [[nan ], + [2.50000000], + [nan ]]) + + >>> median_value, median_indices = paddle.median(x, axis=1, keepdim=True, mode='min') + >>> print(median_value) + Tensor(shape=[3, 1], dtype=float64, place=Place(cpu), stop_gradient=True, + [[nan], + [2. ], + [nan]]) + >>> print(median_indices) + Tensor(shape=[3, 1], dtype=int64, place=Place(cpu), stop_gradient=True, + [[3], + [1], + [0]])) """ if not isinstance(x, (Variable, paddle.pir.Value)): raise TypeError("In median, the input x should be a Tensor.") @@ -521,6 +543,11 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): ), dtype=dtype, ) + out_tensor = out_tensor + paddle.sum( + paddle.cast(paddle.isnan(x), dtype=dtype) * x.astype(dtype), + axis=axis, + keepdim=True, + ) else: # mode == 'min' if sz & 1 == 0: out_tensor = paddle.slice( @@ -538,12 +565,23 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): out_idx = paddle.slice( idx, axes=[axis], starts=[kth], ends=[kth + 1] ) + # if contain nan on axis, return nan for that axis + out_tensor = out_tensor + paddle.sum( + paddle.cast(paddle.isnan(x), dtype=x.dtype) * x, + axis=axis, + keepdim=True, + ) + if need_idx: + # replace index using the first nan value's index on axis for out_idx + x_isnan = paddle.isnan(x) + contain_nan, nan_index = paddle.topk( + x_isnan.astype("int64"), k=1, axis=axis + ) + out_idx = ( + out_idx * (~contain_nan.astype("bool")).astype("int64") + + nan_index + ) - out_tensor = out_tensor + paddle.sum( - paddle.cast(paddle.isnan(x), dtype=dtype) * x.astype(dtype), - axis=axis, - keepdim=True, - ) if is_flatten: if keepdim: out_tensor = out_tensor.reshape([1] * dims) diff --git a/test/legacy_test/test_median.py b/test/legacy_test/test_median.py index ee38ef57f79c9e..f13b8e5f84a341 100644 --- a/test/legacy_test/test_median.py +++ b/test/legacy_test/test_median.py @@ -44,9 +44,10 @@ def np_medain_min(data, keepdims=False): np_res = data_sort[i] if keepdims: new_shape = [1] * len(shape) - return np_res.reshape(new_shape) - else: - return np_res + np_res = np_res.reshape(new_shape) + return np_res + np.sum( + np.isnan(data).astype(data.dtype) * data, axis=-1, keepdims=keepdims + ) def np_medain_min_axis(data, axis=None, keepdims=False): @@ -95,9 +96,12 @@ def np_medain_min_axis(data, axis=None, keepdims=False): if keepdims: shape = list(data.shape) shape[axis] = 1 - return np.reshape(np_res, shape) + np_res = np.reshape(np_res, shape) else: - return np.reshape(np_res, reshape[:-1]) + np_res = np.reshape(np_res, reshape[:-1]) + return np_res + np.sum( + np.isnan(data).astype(data.dtype) * data, axis=axis, keepdims=keepdims + ) class TestMedianAvg(unittest.TestCase): @@ -133,9 +137,10 @@ def test_median_static(self): l = 2 x = np.arange(h * w * l).reshape([h, w, l]) lis_tests = [ - [x, axis, keepdims] + [x.astype(dtype), axis, keepdims] for axis in [-1, 0, 1, 2, None] for keepdims in [False, True] + for dtype in ['float32', 'float64', 'int32', 'int64'] ] for lis_test in lis_tests: self.static_single_test_median(lis_test) @@ -147,9 +152,10 @@ def test_median_dygraph(self): l = 2 x = np.arange(h * w * l).reshape([h, w, l]) lis_tests = [ - [x, axis, keepdims] + [x.astype(dtype), axis, keepdims] for axis in [-1, 0, 1, 2, None] for keepdims in [False, True] + for dtype in ['float32', 'float64', 'int32', 'int64'] ] for lis_test in lis_tests: self.dygraph_single_test_median(lis_test) @@ -164,6 +170,20 @@ def test_median_exception(self): self.assertRaises(ValueError, paddle.median, x, 2, False, 'max') self.assertRaises(ValueError, paddle.median, paddle.to_tensor([])) + def test_nan(self): + paddle.disable_static() + x = np.array( + [[1, 2, 3, float('nan')], [1, 2, 3, 4], [float('nan'), 1, 2, 3]] + ) + lis_tests = [ + [x.astype(dtype), axis, keepdims] + for axis in [-1, 0, 1, None] + for keepdims in [False, True] + for dtype in ['float32', 'float64'] + ] + for lis_test in lis_tests: + self.dygraph_single_test_median(lis_test) + class TestMedianMin(unittest.TestCase): def static_single_test_median(self, lis_test): @@ -195,9 +215,10 @@ def test_median_static(self): l = 2 x = np.arange(h * w * l).reshape([h, w, l]).astype("float32") lis_tests = [ - [x, axis, keepdims] + [x.astype(dtype), axis, keepdims] for axis in [-1, 0, 1, 2] for keepdims in [False, True] + for dtype in ['float32', 'float64', 'int32', 'int64'] ] for lis_test in lis_tests: self.static_single_test_median(lis_test) @@ -209,9 +230,10 @@ def test_median_dygraph(self): l = 2 x = np.arange(h * w * l).reshape([h, w, l]).astype("float32") lis_tests = [ - [x, axis, keepdims] + [x.astype(dtype), axis, keepdims] for axis in [-1, 0, 1, 2] for keepdims in [False, True] + for dtype in ['float32', 'float64', 'int32', 'int64'] ] for lis_test in lis_tests: self.dygraph_single_test_median(lis_test) @@ -230,6 +252,20 @@ def test_index_odd_case(self): np.testing.assert_allclose(out.numpy(), [4.0, 14.0, 24.0]) np.testing.assert_equal(index.numpy(), [4, 4, 4]) + def test_nan(self): + paddle.disable_static() + x = np.array( + [[1, 2, 3, float('nan')], [1, 2, 3, 4], [float('nan'), 1, 2, 3]] + ) + lis_tests = [ + [x.astype(dtype), axis, keepdims] + for axis in [-1, 0, 1, None] + for keepdims in [False, True] + for dtype in ['float32', 'float64'] + ] + for lis_test in lis_tests: + self.dygraph_single_test_median(lis_test) + if __name__ == '__main__': unittest.main() From e8e93cda608a13428d04ea224018d745c7c33742 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Mon, 20 May 2024 12:00:43 +0800 Subject: [PATCH 02/11] fix median min dtype --- python/paddle/tensor/stat.py | 2 +- test/legacy_test/test_median.py | 43 +++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 0a8365c3ec12e2..5142cbf94a6d2a 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -391,7 +391,7 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): Compute the median along the specified axis. Args: - x (Tensor): The input Tensor, it's data type can be float32, float64, int32, int64. + x (Tensor): The input Tensor, it's data type can be float16, float32, float64, int32, int64. axis (int, optional): The axis along which to perform median calculations ``axis`` should be int. ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` . If ``axis`` is less than 0, it works the same way as :math:`axis + D`. diff --git a/test/legacy_test/test_median.py b/test/legacy_test/test_median.py index f13b8e5f84a341..0d5272fbb3eac6 100644 --- a/test/legacy_test/test_median.py +++ b/test/legacy_test/test_median.py @@ -18,6 +18,7 @@ import numpy as np import paddle +from paddle.base import core from paddle.pir_utils import test_with_pir_api DELTA = 1e-6 @@ -184,6 +185,26 @@ def test_nan(self): for lis_test in lis_tests: self.dygraph_single_test_median(lis_test) + @unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_float16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and do not support float16", + ) + def test_float16(self): + paddle.disable_static(core.CUDAPlace(0)) + x = np.array( + [[1, 2, 3, float('nan')], [1, 2, 3, 4], [float('nan'), 1, 2, 3]] + ).astype('float16') + lis_tests = [ + [axis, keepdims] + for axis in [-1, 0, 1, None] + for keepdims in [False, True] + ] + for axis, keepdims in lis_tests: + res_np = np.median(x, axis=axis, keepdims=keepdims) + res_pd = paddle.median(paddle.to_tensor(x), axis, keepdims) + self.check_numpy_res(res_pd.numpy(False), res_np.astype('float64')) + class TestMedianMin(unittest.TestCase): def static_single_test_median(self, lis_test): @@ -266,6 +287,28 @@ def test_nan(self): for lis_test in lis_tests: self.dygraph_single_test_median(lis_test) + @unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_float16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and do not support float16", + ) + def test_float16(self): + paddle.disable_static(core.CUDAPlace(0)) + x = np.array( + [[1, 2, 3, float('nan')], [1, 2, 3, 4], [float('nan'), 1, 2, 3]] + ).astype('float16') + lis_tests = [ + [axis, keepdims] + for axis in [-1, 0, 1, None] + for keepdims in [False, True] + ] + for axis, keepdims in lis_tests: + res_np = np_medain_min_axis(x, axis=axis, keepdims=keepdims) + res_pd, _ = paddle.median( + paddle.to_tensor(x), axis, keepdims, mode='min' + ) + np.testing.assert_allclose(res_pd.numpy(False), res_np) + if __name__ == '__main__': unittest.main() From df584125f90479ed75bc6e95f3ea2e4fde5aad53 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Mon, 20 May 2024 13:42:59 +0800 Subject: [PATCH 03/11] fix test --- test/legacy_test/test_median.py | 34 +++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/test/legacy_test/test_median.py b/test/legacy_test/test_median.py index 0d5272fbb3eac6..8b7708768a0dc2 100644 --- a/test/legacy_test/test_median.py +++ b/test/legacy_test/test_median.py @@ -29,9 +29,11 @@ def np_medain_min(data, keepdims=False): data_flat = data.flatten() data_cnt = len(data_flat) - data_flat[np.isnan(data_flat)] = np.inf + if data.dtype != 'int32' and data.dtype != 'int64': + data_flat[np.isnan(data_flat)] = np.inf data_sort = np.sort(data_flat) - data_sort[np.isinf(data_sort)] = np.nan + if data.dtype != 'int32' and data.dtype != 'int64': + data_sort[np.isinf(data_sort)] = np.nan if data_cnt % 2: is_odd = False @@ -75,9 +77,11 @@ def np_medain_min_axis(data, axis=None, keepdims=False): shape=data_flat.shape[:-1], fill_value=data_flat.shape[-1] ) - data_flat[np.isnan(data_flat)] = np.inf + if data.dtype != 'int32' and data.dtype != 'int64': + data_flat[np.isnan(data_flat)] = np.inf data_sort = np.sort(data_flat, axis=-1) - data_sort[np.isinf(data_sort)] = np.nan + if data.dtype != 'int32' and data.dtype != 'int64': + data_sort[np.isinf(data_sort)] = np.nan is_odd = data_cnt % 2 @@ -108,8 +112,17 @@ def np_medain_min_axis(data, axis=None, keepdims=False): class TestMedianAvg(unittest.TestCase): def check_numpy_res(self, np1, np2): self.assertEqual(np1.shape, np2.shape) + np1_isnan = np.isnan(np1) + np2_isnan = np.isnan(np2) + nan_mismatch = np.sum( + (np1_isnan.astype('int32') - np2_isnan.astype('int32')) + * (np1_isnan.astype('int32') - np2_isnan.astype('int32')) + ) + self.assertEqual(nan_mismatch, 0) + np1[np.isnan(np1)] = 0.0 + np2[np.isnan(np2)] = 0.0 mismatch = np.sum((np1 - np2) * (np1 - np2)) - self.assertAlmostEqual(mismatch, 0, DELTA) + self.assertAlmostEqual(mismatch, 0, delta=DELTA) def static_single_test_median(self, lis_test): paddle.enable_static() @@ -224,9 +237,14 @@ def static_single_test_median(self, lis_test): def dygraph_single_test_median(self, lis_test): x, axis, keepdims = lis_test res_np = np_medain_min_axis(x, axis=axis, keepdims=keepdims) - res_pd, _ = paddle.median( - paddle.to_tensor(x), axis, keepdims, mode='min' - ) + if axis is None: + res_pd = paddle.median( + paddle.to_tensor(x), axis, keepdims, mode='min' + ) + else: + res_pd, _ = paddle.median( + paddle.to_tensor(x), axis, keepdims, mode='min' + ) np.testing.assert_allclose(res_pd.numpy(False), res_np) @test_with_pir_api From ac87e05364b160c7986e7bfeba25e05553472694 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Mon, 20 May 2024 17:07:04 +0800 Subject: [PATCH 04/11] fix test --- python/paddle/tensor/stat.py | 2 +- test/legacy_test/test_median.py | 19 +++++++++++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 5142cbf94a6d2a..c2f23b5c21e11c 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -570,7 +570,7 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): paddle.cast(paddle.isnan(x), dtype=x.dtype) * x, axis=axis, keepdim=True, - ) + ).astype(x.dtype) if need_idx: # replace index using the first nan value's index on axis for out_idx x_isnan = paddle.isnan(x) diff --git a/test/legacy_test/test_median.py b/test/legacy_test/test_median.py index 8b7708768a0dc2..fb9c3645cd1a89 100644 --- a/test/legacy_test/test_median.py +++ b/test/legacy_test/test_median.py @@ -48,9 +48,7 @@ def np_medain_min(data, keepdims=False): if keepdims: new_shape = [1] * len(shape) np_res = np_res.reshape(new_shape) - return np_res + np.sum( - np.isnan(data).astype(data.dtype) * data, axis=-1, keepdims=keepdims - ) + return np_res + np.sum(np.isnan(data).astype(data.dtype) * data) def np_medain_min_axis(data, axis=None, keepdims=False): @@ -119,8 +117,8 @@ def check_numpy_res(self, np1, np2): * (np1_isnan.astype('int32') - np2_isnan.astype('int32')) ) self.assertEqual(nan_mismatch, 0) - np1[np.isnan(np1)] = 0.0 - np2[np.isnan(np2)] = 0.0 + np1 = np.where(np.isnan(np1), 0.0, np1) + np2 = np.where(np.isnan(np2), 0.0, np2) mismatch = np.sum((np1 - np2) * (np1 - np2)) self.assertAlmostEqual(mismatch, 0, delta=DELTA) @@ -322,9 +320,14 @@ def test_float16(self): ] for axis, keepdims in lis_tests: res_np = np_medain_min_axis(x, axis=axis, keepdims=keepdims) - res_pd, _ = paddle.median( - paddle.to_tensor(x), axis, keepdims, mode='min' - ) + if axis is None: + res_pd = paddle.median( + paddle.to_tensor(x), axis, keepdims, mode='min' + ) + else: + res_pd, _ = paddle.median( + paddle.to_tensor(x), axis, keepdims, mode='min' + ) np.testing.assert_allclose(res_pd.numpy(False), res_np) From 658bd7c5ddeaf609c066a5c872c24f35986bbed9 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Tue, 21 May 2024 10:31:35 +0800 Subject: [PATCH 05/11] fix idx calculation branch --- python/paddle/tensor/stat.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index c2f23b5c21e11c..2ef43c8a3df7bb 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -573,12 +573,18 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): ).astype(x.dtype) if need_idx: # replace index using the first nan value's index on axis for out_idx - x_isnan = paddle.isnan(x) - contain_nan, nan_index = paddle.topk( - x_isnan.astype("int64"), k=1, axis=axis + # topk is not stable on cpu device, use argsort instead + x_isnan = paddle.isnan(x).astype("int64") + x_all_zero = paddle.zeros_like(x_isnan) + index_along_axis = paddle.argsort( + x_all_zero, axis=axis, stable=True ) + nan_index = paddle.sum( + index_along_axis * x_isnan, axis=axis, keepdim=True + ) + nan_index_mask = paddle.sum(x_isnan, axis=axis, keepdim=True) out_idx = ( - out_idx * (~contain_nan.astype("bool")).astype("int64") + out_idx * (~nan_index_mask.astype("bool")).astype("int64") + nan_index ) From 24e9936084f57ffd3b206afbc4b6a90a4bd4bc82 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Tue, 21 May 2024 11:01:15 +0800 Subject: [PATCH 06/11] fix code example --- python/paddle/tensor/stat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 2ef43c8a3df7bb..75062118fe5c94 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -484,7 +484,7 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): Tensor(shape=[3, 1], dtype=int64, place=Place(cpu), stop_gradient=True, [[3], [1], - [0]])) + [0]]) """ if not isinstance(x, (Variable, paddle.pir.Value)): raise TypeError("In median, the input x should be a Tensor.") From 21461784d9590ea44d49af00b810b5f8be519fc4 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Tue, 21 May 2024 12:19:02 +0800 Subject: [PATCH 07/11] fix code example --- python/paddle/tensor/stat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 75062118fe5c94..02ce0ec4990697 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -468,7 +468,7 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): >>> x = paddle.to_tensor(np.array([[1,2,3,float('nan')],[1,2,3,4],[float('nan'),1,2,3]]) >>> y6 = paddle.median(x, axis=-1, keepdim=True) - >>> print(76) + >>> print(y6) Tensor(shape=[3, 1], dtype=float64, place=Place(cpu), stop_gradient=True, [[nan ], [2.50000000], From 0cb37fef9315e3bb32ce27eebe583a4b2bfcb2ad Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Fri, 24 May 2024 13:46:19 +0800 Subject: [PATCH 08/11] update --- python/paddle/tensor/stat.py | 21 +++++++++++++-------- test/legacy_test/test_median.py | 31 ++++++++++++++++++++++++++++++- 2 files changed, 43 insertions(+), 9 deletions(-) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 02ce0ec4990697..8d2b4e1d534ff8 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -465,7 +465,7 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): [1, 1, 1]) >>> # cases containing nan values - >>> x = paddle.to_tensor(np.array([[1,2,3,float('nan')],[1,2,3,4],[float('nan'),1,2,3]]) + >>> x = paddle.to_tensor(np.array([[1,float('nan'),3,float('nan')],[1,2,3,4],[float('nan'),1,2,3]])) >>> y6 = paddle.median(x, axis=-1, keepdim=True) >>> print(y6) @@ -482,7 +482,7 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): [nan]]) >>> print(median_indices) Tensor(shape=[3, 1], dtype=int64, place=Place(cpu), stop_gradient=True, - [[3], + [[1], [1], [0]]) """ @@ -522,13 +522,13 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): sz = x.shape[axis] kth = sz >> 1 tensor_topk, idx = paddle.topk(x, kth + 1, axis=axis, largest=False) - dtype = ( - 'float64' - if x.dtype - in [core.VarDesc.VarType.FP64, paddle.base.core.DataType.FLOAT64] - else 'float32' - ) if mode == 'avg': + dtype = ( + 'float64' + if x.dtype + in [core.VarDesc.VarType.FP64, paddle.base.core.DataType.FLOAT64] + else 'float32' + ) if sz & 1 == 0: out_tensor = paddle.slice( tensor_topk, axes=[axis], starts=[kth - 1], ends=[kth] @@ -579,6 +579,11 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): index_along_axis = paddle.argsort( x_all_zero, axis=axis, stable=True ) + + # find the index of the leading one in x_isnan + cumsum = x_isnan.cumsum(axis=axis) + x_isnan = x_isnan * paddle.where(cumsum > 1, 0, 1) + nan_index = paddle.sum( index_along_axis * x_isnan, axis=axis, keepdim=True ) diff --git a/test/legacy_test/test_median.py b/test/legacy_test/test_median.py index fb9c3645cd1a89..bb88776eddc10f 100644 --- a/test/legacy_test/test_median.py +++ b/test/legacy_test/test_median.py @@ -215,6 +215,19 @@ def test_float16(self): res_np = np.median(x, axis=axis, keepdims=keepdims) res_pd = paddle.median(paddle.to_tensor(x), axis, keepdims) self.check_numpy_res(res_pd.numpy(False), res_np.astype('float64')) + np.testing.assert_equal(res_pd.numpy(False).dtype, np.float32) + + def test_output_dtype(self): + supported_dypes = ['float32', 'float64', 'int32', 'int64'] + for inp_dtype in supported_dypes: + x = np.random.randint(low=-100, high=100, size=[2, 4, 5]).astype( + inp_dtype + ) + res = paddle.median(paddle.to_tensor(x), mode='avg') + if inp_dtype == 'float64': + np.testing.assert_equal(res.numpy().dtype, np.float64) + else: + np.testing.assert_equal(res.numpy().dtype, np.float32) class TestMedianMin(unittest.TestCase): @@ -292,7 +305,13 @@ def test_index_odd_case(self): def test_nan(self): paddle.disable_static() x = np.array( - [[1, 2, 3, float('nan')], [1, 2, 3, 4], [float('nan'), 1, 2, 3]] + [ + [1, 2, 3, float('nan')], + [1, 2, 3, 4], + [float('nan'), 1, 2, 3], + [1, float('nan'), 3, float('nan')], + [float('nan'), float('nan'), 3, float('nan')], + ] ) lis_tests = [ [x.astype(dtype), axis, keepdims] @@ -329,6 +348,16 @@ def test_float16(self): paddle.to_tensor(x), axis, keepdims, mode='min' ) np.testing.assert_allclose(res_pd.numpy(False), res_np) + np.testing.assert_equal(res_pd.numpy(False).dtype, np.float16) + + def test_output_dtype(self): + supported_dypes = ['float32', 'float64', 'int32', 'int64'] + for inp_dtype in supported_dypes: + x = np.random.randint(low=-100, high=100, size=[2, 4, 5]).astype( + inp_dtype + ) + res = paddle.median(paddle.to_tensor(x), mode='min') + np.testing.assert_equal(res.numpy().dtype, np.dtype(inp_dtype)) if __name__ == '__main__': From e72dda9373af746a805f01377418ec3be5c56275 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Fri, 24 May 2024 14:48:52 +0800 Subject: [PATCH 09/11] update docs --- python/paddle/tensor/stat.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 8d2b4e1d534ff8..ceafd593d7fd2a 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -423,6 +423,7 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): .. code-block:: python >>> import paddle + >>> import numpy as np >>> x = paddle.arange(12).reshape([3, 4]) >>> print(x) From 7d6c4e7dfc8930e395cbac106d02f0a505135d04 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Tue, 28 May 2024 09:57:19 +0800 Subject: [PATCH 10/11] update --- python/paddle/tensor/stat.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index ceafd593d7fd2a..d18163116f169b 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -589,10 +589,7 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): index_along_axis * x_isnan, axis=axis, keepdim=True ) nan_index_mask = paddle.sum(x_isnan, axis=axis, keepdim=True) - out_idx = ( - out_idx * (~nan_index_mask.astype("bool")).astype("int64") - + nan_index - ) + out_idx = out_idx * paddle.logical_not(nan_index_mask) + nan_index if is_flatten: if keepdim: From b5333183c5d7cf97c87912a2f25462b3a916fe57 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Tue, 28 May 2024 12:17:43 +0800 Subject: [PATCH 11/11] update --- python/paddle/tensor/stat.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index d18163116f169b..b4dfdbe717f3eb 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -589,7 +589,10 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): index_along_axis * x_isnan, axis=axis, keepdim=True ) nan_index_mask = paddle.sum(x_isnan, axis=axis, keepdim=True) - out_idx = out_idx * paddle.logical_not(nan_index_mask) + nan_index + out_idx = ( + out_idx * paddle.logical_not(nan_index_mask).astype('int64') + + nan_index + ) if is_flatten: if keepdim: