diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 545b6bf21ca9e4..b4dfdbe717f3eb 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -391,7 +391,7 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): Compute the median along the specified axis. Args: - x (Tensor): The input Tensor, it's data type can be bool, float16, float32, float64, int32, int64. + x (Tensor): The input Tensor, its data type can be float16, float32, float64, int32, int64. axis (int, optional): The axis along which to perform median calculations ``axis`` should be int. ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` . If ``axis`` is less than 0, it works the same way as :math:`axis + D`. @@ -423,6 +423,7 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): .. code-block:: python >>> import paddle + >>> import numpy as np >>> x = paddle.arange(12).reshape([3, 4]) >>> print(x) @@ -463,6 +464,28 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): >>> print(median_indices) Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True, [1, 1, 1]) + + >>> # cases containing nan values + >>> x = paddle.to_tensor(np.array([[1,float('nan'),3,float('nan')],[1,2,3,4],[float('nan'),1,2,3]])) + + >>> y6 = paddle.median(x, axis=-1, keepdim=True) + >>> print(y6) + Tensor(shape=[3, 1], dtype=float64, place=Place(cpu), stop_gradient=True, + [[nan ], + [2.50000000], + [nan ]]) + + >>> median_value, median_indices = paddle.median(x, axis=1, keepdim=True, mode='min') + >>> print(median_value) + Tensor(shape=[3, 1], dtype=float64, place=Place(cpu), stop_gradient=True, + [[nan], + [2. 
], + [nan]]) + >>> print(median_indices) + Tensor(shape=[3, 1], dtype=int64, place=Place(cpu), stop_gradient=True, + [[1], + [1], + [0]]) """ if not isinstance(x, (Variable, paddle.pir.Value)): raise TypeError("In median, the input x should be a Tensor.") @@ -500,13 +523,13 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): sz = x.shape[axis] kth = sz >> 1 tensor_topk, idx = paddle.topk(x, kth + 1, axis=axis, largest=False) - dtype = ( - 'float64' - if x.dtype - in [core.VarDesc.VarType.FP64, paddle.base.core.DataType.FLOAT64] - else 'float32' - ) if mode == 'avg': + dtype = ( + 'float64' + if x.dtype + in [core.VarDesc.VarType.FP64, paddle.base.core.DataType.FLOAT64] + else 'float32' + ) if sz & 1 == 0: out_tensor = paddle.slice( tensor_topk, axes=[axis], starts=[kth - 1], ends=[kth] @@ -521,6 +544,11 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): ), dtype=dtype, ) + out_tensor = out_tensor + paddle.sum( + paddle.cast(paddle.isnan(x), dtype=dtype) * x.astype(dtype), + axis=axis, + keepdim=True, + ) else: # mode == 'min' if sz & 1 == 0: out_tensor = paddle.slice( @@ -538,12 +566,34 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): out_idx = paddle.slice( idx, axes=[axis], starts=[kth], ends=[kth + 1] ) + # if the axis contains any nan, return nan for that axis + out_tensor = out_tensor + paddle.sum( + paddle.cast(paddle.isnan(x), dtype=x.dtype) * x, + axis=axis, + keepdim=True, + ).astype(x.dtype) + if need_idx: + # replace out_idx with the index of the first nan value along the axis + # topk is not stable on cpu device, use argsort instead + x_isnan = paddle.isnan(x).astype("int64") + x_all_zero = paddle.zeros_like(x_isnan) + index_along_axis = paddle.argsort( + x_all_zero, axis=axis, stable=True + ) + + # find the index of the leading one in x_isnan + cumsum = x_isnan.cumsum(axis=axis) + x_isnan = x_isnan * paddle.where(cumsum > 1, 0, 1) + + nan_index = paddle.sum( + index_along_axis * x_isnan, axis=axis, keepdim=True 
+ ) + nan_index_mask = paddle.sum(x_isnan, axis=axis, keepdim=True) + out_idx = ( + out_idx * paddle.logical_not(nan_index_mask).astype('int64') + + nan_index + ) - out_tensor = out_tensor + paddle.sum( - paddle.cast(paddle.isnan(x), dtype=dtype) * x.astype(dtype), - axis=axis, - keepdim=True, - ) if is_flatten: if keepdim: out_tensor = out_tensor.reshape([1] * dims) diff --git a/test/legacy_test/test_median.py b/test/legacy_test/test_median.py index ee38ef57f79c9e..bb88776eddc10f 100644 --- a/test/legacy_test/test_median.py +++ b/test/legacy_test/test_median.py @@ -18,6 +18,7 @@ import numpy as np import paddle +from paddle.base import core from paddle.pir_utils import test_with_pir_api DELTA = 1e-6 @@ -28,9 +29,11 @@ def np_medain_min(data, keepdims=False): data_flat = data.flatten() data_cnt = len(data_flat) - data_flat[np.isnan(data_flat)] = np.inf + if data.dtype != 'int32' and data.dtype != 'int64': + data_flat[np.isnan(data_flat)] = np.inf data_sort = np.sort(data_flat) - data_sort[np.isinf(data_sort)] = np.nan + if data.dtype != 'int32' and data.dtype != 'int64': + data_sort[np.isinf(data_sort)] = np.nan if data_cnt % 2: is_odd = False @@ -44,9 +47,8 @@ def np_medain_min(data, keepdims=False): np_res = data_sort[i] if keepdims: new_shape = [1] * len(shape) - return np_res.reshape(new_shape) - else: - return np_res + np_res = np_res.reshape(new_shape) + return np_res + np.sum(np.isnan(data).astype(data.dtype) * data) def np_medain_min_axis(data, axis=None, keepdims=False): @@ -73,9 +75,11 @@ def np_medain_min_axis(data, axis=None, keepdims=False): shape=data_flat.shape[:-1], fill_value=data_flat.shape[-1] ) - data_flat[np.isnan(data_flat)] = np.inf + if data.dtype != 'int32' and data.dtype != 'int64': + data_flat[np.isnan(data_flat)] = np.inf data_sort = np.sort(data_flat, axis=-1) - data_sort[np.isinf(data_sort)] = np.nan + if data.dtype != 'int32' and data.dtype != 'int64': + data_sort[np.isinf(data_sort)] = np.nan is_odd = data_cnt % 2 @@ -95,16 +99,28 @@ 
def np_medain_min_axis(data, axis=None, keepdims=False): if keepdims: shape = list(data.shape) shape[axis] = 1 - return np.reshape(np_res, shape) + np_res = np.reshape(np_res, shape) else: - return np.reshape(np_res, reshape[:-1]) + np_res = np.reshape(np_res, reshape[:-1]) + return np_res + np.sum( + np.isnan(data).astype(data.dtype) * data, axis=axis, keepdims=keepdims + ) class TestMedianAvg(unittest.TestCase): def check_numpy_res(self, np1, np2): self.assertEqual(np1.shape, np2.shape) + np1_isnan = np.isnan(np1) + np2_isnan = np.isnan(np2) + nan_mismatch = np.sum( + (np1_isnan.astype('int32') - np2_isnan.astype('int32')) + * (np1_isnan.astype('int32') - np2_isnan.astype('int32')) + ) + self.assertEqual(nan_mismatch, 0) + np1 = np.where(np.isnan(np1), 0.0, np1) + np2 = np.where(np.isnan(np2), 0.0, np2) mismatch = np.sum((np1 - np2) * (np1 - np2)) - self.assertAlmostEqual(mismatch, 0, DELTA) + self.assertAlmostEqual(mismatch, 0, delta=DELTA) def static_single_test_median(self, lis_test): paddle.enable_static() @@ -133,9 +149,10 @@ def test_median_static(self): l = 2 x = np.arange(h * w * l).reshape([h, w, l]) lis_tests = [ - [x, axis, keepdims] + [x.astype(dtype), axis, keepdims] for axis in [-1, 0, 1, 2, None] for keepdims in [False, True] + for dtype in ['float32', 'float64', 'int32', 'int64'] ] for lis_test in lis_tests: self.static_single_test_median(lis_test) @@ -147,9 +164,10 @@ def test_median_dygraph(self): l = 2 x = np.arange(h * w * l).reshape([h, w, l]) lis_tests = [ - [x, axis, keepdims] + [x.astype(dtype), axis, keepdims] for axis in [-1, 0, 1, 2, None] for keepdims in [False, True] + for dtype in ['float32', 'float64', 'int32', 'int64'] ] for lis_test in lis_tests: self.dygraph_single_test_median(lis_test) @@ -164,6 +182,53 @@ def test_median_exception(self): self.assertRaises(ValueError, paddle.median, x, 2, False, 'max') self.assertRaises(ValueError, paddle.median, paddle.to_tensor([])) + def test_nan(self): + paddle.disable_static() + x = 
np.array( + [[1, 2, 3, float('nan')], [1, 2, 3, 4], [float('nan'), 1, 2, 3]] + ) + lis_tests = [ + [x.astype(dtype), axis, keepdims] + for axis in [-1, 0, 1, None] + for keepdims in [False, True] + for dtype in ['float32', 'float64'] + ] + for lis_test in lis_tests: + self.dygraph_single_test_median(lis_test) + + @unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_float16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and do not support float16", + ) + def test_float16(self): + paddle.disable_static(core.CUDAPlace(0)) + x = np.array( + [[1, 2, 3, float('nan')], [1, 2, 3, 4], [float('nan'), 1, 2, 3]] + ).astype('float16') + lis_tests = [ + [axis, keepdims] + for axis in [-1, 0, 1, None] + for keepdims in [False, True] + ] + for axis, keepdims in lis_tests: + res_np = np.median(x, axis=axis, keepdims=keepdims) + res_pd = paddle.median(paddle.to_tensor(x), axis, keepdims) + self.check_numpy_res(res_pd.numpy(False), res_np.astype('float64')) + np.testing.assert_equal(res_pd.numpy(False).dtype, np.float32) + + def test_output_dtype(self): + supported_dypes = ['float32', 'float64', 'int32', 'int64'] + for inp_dtype in supported_dypes: + x = np.random.randint(low=-100, high=100, size=[2, 4, 5]).astype( + inp_dtype + ) + res = paddle.median(paddle.to_tensor(x), mode='avg') + if inp_dtype == 'float64': + np.testing.assert_equal(res.numpy().dtype, np.float64) + else: + np.testing.assert_equal(res.numpy().dtype, np.float32) + class TestMedianMin(unittest.TestCase): def static_single_test_median(self, lis_test): @@ -183,9 +248,14 @@ def static_single_test_median(self, lis_test): def dygraph_single_test_median(self, lis_test): x, axis, keepdims = lis_test res_np = np_medain_min_axis(x, axis=axis, keepdims=keepdims) - res_pd, _ = paddle.median( - paddle.to_tensor(x), axis, keepdims, mode='min' - ) + if axis is None: + res_pd = paddle.median( + paddle.to_tensor(x), axis, keepdims, mode='min' + ) + else: + res_pd, _ = paddle.median( + 
paddle.to_tensor(x), axis, keepdims, mode='min' + ) np.testing.assert_allclose(res_pd.numpy(False), res_np) @test_with_pir_api @@ -195,9 +265,10 @@ def test_median_static(self): l = 2 x = np.arange(h * w * l).reshape([h, w, l]).astype("float32") lis_tests = [ - [x, axis, keepdims] + [x.astype(dtype), axis, keepdims] for axis in [-1, 0, 1, 2] for keepdims in [False, True] + for dtype in ['float32', 'float64', 'int32', 'int64'] ] for lis_test in lis_tests: self.static_single_test_median(lis_test) @@ -209,9 +280,10 @@ def test_median_dygraph(self): l = 2 x = np.arange(h * w * l).reshape([h, w, l]).astype("float32") lis_tests = [ - [x, axis, keepdims] + [x.astype(dtype), axis, keepdims] for axis in [-1, 0, 1, 2] for keepdims in [False, True] + for dtype in ['float32', 'float64', 'int32', 'int64'] ] for lis_test in lis_tests: self.dygraph_single_test_median(lis_test) @@ -230,6 +302,63 @@ def test_index_odd_case(self): np.testing.assert_allclose(out.numpy(), [4.0, 14.0, 24.0]) np.testing.assert_equal(index.numpy(), [4, 4, 4]) + def test_nan(self): + paddle.disable_static() + x = np.array( + [ + [1, 2, 3, float('nan')], + [1, 2, 3, 4], + [float('nan'), 1, 2, 3], + [1, float('nan'), 3, float('nan')], + [float('nan'), float('nan'), 3, float('nan')], + ] + ) + lis_tests = [ + [x.astype(dtype), axis, keepdims] + for axis in [-1, 0, 1, None] + for keepdims in [False, True] + for dtype in ['float32', 'float64'] + ] + for lis_test in lis_tests: + self.dygraph_single_test_median(lis_test) + + @unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_float16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and do not support float16", + ) + def test_float16(self): + paddle.disable_static(core.CUDAPlace(0)) + x = np.array( + [[1, 2, 3, float('nan')], [1, 2, 3, 4], [float('nan'), 1, 2, 3]] + ).astype('float16') + lis_tests = [ + [axis, keepdims] + for axis in [-1, 0, 1, None] + for keepdims in [False, True] + ] + for axis, keepdims in lis_tests: + res_np = 
np_medain_min_axis(x, axis=axis, keepdims=keepdims) + if axis is None: + res_pd = paddle.median( + paddle.to_tensor(x), axis, keepdims, mode='min' + ) + else: + res_pd, _ = paddle.median( + paddle.to_tensor(x), axis, keepdims, mode='min' + ) + np.testing.assert_allclose(res_pd.numpy(False), res_np) + np.testing.assert_equal(res_pd.numpy(False).dtype, np.float16) + + def test_output_dtype(self): + supported_dypes = ['float32', 'float64', 'int32', 'int64'] + for inp_dtype in supported_dypes: + x = np.random.randint(low=-100, high=100, size=[2, 4, 5]).astype( + inp_dtype + ) + res = paddle.median(paddle.to_tensor(x), mode='min') + np.testing.assert_equal(res.numpy().dtype, np.dtype(inp_dtype)) + if __name__ == '__main__': unittest.main()