From 88261ff4cd6656f2f29a895b4cc5d857d77b3399 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Thu, 21 Mar 2024 23:09:25 +0800 Subject: [PATCH 1/2] API Improvement: quantile and nanquantile --- python/paddle/tensor/stat.py | 141 +++++++---- .../test_quantile_and_nanquantile.py | 221 +++++++++++++++++- 2 files changed, 315 insertions(+), 47 deletions(-) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 0d931e3f9caaf3..c4ab8b4ed2ad9f 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -558,14 +558,16 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): return out_tensor -def _compute_quantile(x, q, axis=None, keepdim=False, ignore_nan=False): +def _compute_quantile( + x, q, axis=None, keepdim=False, interpolation="linear", ignore_nan=False +): """ Compute the quantile of the input along the specified axis. Args: x (Tensor): The input Tensor, it's data type can be float32, float64, int32, int64. - q (int|float|list): The q for calculate quantile, which should be in range [0, 1]. If q is a list, - each q will be calculated and the first dimension of output is same to the number of ``q`` . + q (int|float|list|Tensor): The q for calculate quantile, which should be in range [0, 1]. If q is a list, + a 1-D Tensor or a 0-D Tensor, each q will be calculated and the first dimension of output is same to the number of ``q`` . axis (int|list, optional): The axis along which to calculate quantile. ``axis`` should be int or list of int. ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` . If ``axis`` is less than 0, it works the same way as :math:`axis + D`. @@ -576,6 +578,8 @@ def _compute_quantile(x, q, axis=None, keepdim=False, ignore_nan=False): the output Tensor is the same as ``x`` except in the reduced dimensions(it is of size 1 in this case). Otherwise, the shape of the output Tensor is squeezed in ``axis`` . Default is False. + interpolation (str, optional): The interpolation method to use + when the desired quantile falls between two data points. Default is linear. ignore_nan: (bool, optional): Whether to ignore NaN of input Tensor. If ``ignore_nan`` is True, it will calculate nanquantile. Otherwise it will calculate quantile. Default is False. 
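
The new ``interpolation`` argument selects how a fractional index between two neighbouring sorted elements is resolved. A minimal NumPy sketch of the five rules implemented below, on made-up values (the array and q are illustrative only):

    import numpy as np

    x = np.array([1.0, 2.0, 3.0, 4.0])   # already sorted, no NaNs
    q = 0.35
    index = q * (len(x) - 1)              # 1.05 -> falls between positions 1 and 2
    below, upper = int(np.floor(index)), int(np.ceil(index))

    print({
        "linear":   x[below] + (index - below) * (x[upper] - x[below]),  # lerp
        "lower":    x[below],
        "higher":   x[upper],
        "nearest":  x[int(np.round(index))],
        "midpoint": (x[below] + x[upper]) / 2.0,
    })
    # {'linear': 2.05, 'lower': 2.0, 'higher': 3.0, 'nearest': 2.0, 'midpoint': 2.5}
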
@@ -594,9 +598,33 @@ def _compute_quantile(x, q, axis=None, keepdim=False, ignore_nan=False): elif isinstance(q, (list, tuple)): if len(q) <= 0: raise ValueError("q should not be empty") + elif isinstance(q, Variable): + if len(q.shape) > 1: + raise ValueError("q should be a 0-D tensor or a 1-D tensor") + if len(q.shape) == 0: + q = [q] else: - raise TypeError("Type of q should be int, float, list or tuple.") + raise TypeError( + "Type of q should be int, float, list or tuple, or tensor" + ) + for q_num in q: + if not in_dynamic_or_pir_mode() and isinstance(q_num, Variable): + break + if q_num < 0 or q_num > 1: + raise ValueError("q should be in range [0, 1]") + if interpolation not in [ + "linear", + "lower", + "higher", + "nearest", + "midpoint", + ]: + raise ValueError( + "interpolation must be one of 'linear', 'lower', 'higher', 'nearest' or 'midpoint', but got {}".format( + interpolation + ) + ) # Validate axis dims = len(x.shape) out_shape = list(x.shape) @@ -637,21 +665,16 @@ def _compute_quantile(x, q, axis=None, keepdim=False, ignore_nan=False): out_shape[axis] = 1 mask = x.isnan() - valid_counts = mask.logical_not().sum( - axis=axis, keepdim=True, dtype='float64' - ) + valid_counts = mask.logical_not().sum(axis=axis, keepdim=True) indices = [] for q_num in q: - if q_num < 0 or q_num > 1: - raise ValueError("q should be in range [0, 1]") if in_dynamic_or_pir_mode(): - q_num = paddle.to_tensor(q_num, dtype='float64') + q_num = paddle.to_tensor(q_num, dtype=x.dtype) if ignore_nan: indices.append(q_num * (valid_counts - 1)) else: - # TODO: Use paddle.index_fill instead of where index = q_num * (valid_counts - 1) last_index = x.shape[axis] - 1 nums = paddle.full_like(index, fill_value=last_index) @@ -660,47 +683,63 @@ def _compute_quantile(x, q, axis=None, keepdim=False, ignore_nan=False): sorted_tensor = paddle.sort(x, axis) - outputs = [] + def _compute_index(index): + if interpolation == "nearest": + idx = paddle.round(index).astype(paddle.int32) + return paddle.take_along_axis(sorted_tensor, idx, axis=axis) - # TODO(chenjianye): replace the for-loop to directly take elements. - for index in indices: - indices_below = paddle.floor(index).astype('int32') - indices_upper = paddle.ceil(index).astype('int32') - tensor_upper = paddle.take_along_axis( - sorted_tensor, indices_upper, axis=axis - ) + indices_below = paddle.floor(index).astype(paddle.int32) tensor_below = paddle.take_along_axis( sorted_tensor, indices_below, axis=axis ) - weights = index - indices_below.astype('float64') - out = paddle.lerp( - tensor_below.astype('float64'), - tensor_upper.astype('float64'), + if interpolation == "lower": + return tensor_below + + indices_upper = paddle.ceil(index).astype(paddle.int32) + tensor_upper = paddle.take_along_axis( + sorted_tensor, indices_upper, axis=axis + ) + if interpolation == "higher": + return tensor_upper + + if interpolation == "midpoint": + return (tensor_upper + tensor_below) / 2 + + weights = (index - indices_below).astype(x.dtype) + return paddle.lerp( + tensor_below.astype(x.dtype), + tensor_upper.astype(x.dtype), weights, ) + + outputs = [] + + # TODO(chenjianye): replace the for-loop to directly take elements. 
+ for index in indices: + out = _compute_index(index) if not keepdim: out = paddle.squeeze(out, axis=axis) else: out = out.reshape(out_shape) outputs.append(out) - if len(q) > 1: + if len(outputs) > 1: outputs = paddle.stack(outputs, 0) else: outputs = outputs[0] - + # return outputs.astype(x.dtype) return outputs -def quantile(x, q, axis=None, keepdim=False): +def quantile(x, q, axis=None, keepdim=False, interpolation="linear"): """ Compute the quantile of the input along the specified axis. If any values in a reduced row are NaN, then the quantiles for that reduction will be NaN. Args: x (Tensor): The input Tensor, it's data type can be float32, float64, int32, int64. - q (int|float|list): The q for calculate quantile, which should be in range [0, 1]. If q is a list, - each q will be calculated and the first dimension of output is same to the number of ``q`` . + q (int|float|list|Tensor): The q for calculate quantile, which should be in range [0, 1]. If q is a list, + a 1-D Tensor or a 0-D Tensor, each q will be calculated and the first dimension of output is same to the number of ``q`` . axis (int|list, optional): The axis along which to calculate quantile. ``axis`` should be int or list of int. ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` . If ``axis`` is less than 0, it works the same way as :math:`axis + D`. @@ -711,6 +750,8 @@ def quantile(x, q, axis=None, keepdim=False): the output Tensor is the same as ``x`` except in the reduced dimensions(it is of size 1 in this case). Otherwise, the shape of the output Tensor is squeezed in ``axis`` . Default is False. + interpolation (str, optional): The interpolation method to use + when the desired quantile falls between two data points. Default is linear. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. @@ -733,42 +774,49 @@ def quantile(x, q, axis=None, keepdim=False): >>> y1 = paddle.quantile(y, q=0.5, axis=[0, 1]) >>> print(y1) - Tensor(shape=[], dtype=float64, place=Place(cpu), stop_gradient=True, + Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, 3.50000000) >>> y2 = paddle.quantile(y, q=0.5, axis=1) >>> print(y2) - Tensor(shape=[4], dtype=float64, place=Place(cpu), stop_gradient=True, + Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [0.50000000, 2.50000000, 4.50000000, 6.50000000]) >>> y3 = paddle.quantile(y, q=[0.3, 0.5], axis=0) >>> print(y3) - Tensor(shape=[2, 2], dtype=float64, place=Place(cpu), stop_gradient=True, + Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True, [[1.80000000, 2.80000000], [3. , 4. ]]) >>> y[0,0] = float("nan") >>> y4 = paddle.quantile(y, q=0.8, axis=1, keepdim=True) >>> print(y4) - Tensor(shape=[4, 1], dtype=float64, place=Place(cpu), stop_gradient=True, + Tensor(shape=[4, 1], dtype=float32, place=Place(cpu), stop_gradient=True, [[nan ], [2.80000000], [4.80000000], [6.80000000]]) """ - return _compute_quantile(x, q, axis=axis, keepdim=keepdim, ignore_nan=False) + return _compute_quantile( + x, + q, + axis=axis, + keepdim=keepdim, + interpolation=interpolation, + ignore_nan=False, + ) -def nanquantile(x, q, axis=None, keepdim=False): +def nanquantile(x, q, axis=None, keepdim=False, interpolation="linear"): """ Compute the quantile of the input as if NaN values in input did not exist. If all values in a reduced row are NaN, then the quantiles for that reduction will be NaN. 
Args: x (Tensor): The input Tensor, it's data type can be float32, float64, int32, int64. - q (int|float|list): The q for calculate quantile, which should be in range [0, 1]. If q is a list, - each q will be calculated and the first dimension of output is same to the number of ``q`` . + q (int|float|list|Tensor): The q for calculate quantile, which should be in range [0, 1]. If q is a list or + a 1-D Tensor, each q will be calculated and the first dimension of output is same to the number of ``q`` . axis (int|list, optional): The axis along which to calculate quantile. ``axis`` should be int or list of int. ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` . If ``axis`` is less than 0, it works the same way as :math:`axis + D`. @@ -779,6 +827,8 @@ def nanquantile(x, q, axis=None, keepdim=False): the output Tensor is the same as ``x`` except in the reduced dimensions(it is of size 1 in this case). Otherwise, the shape of the output Tensor is squeezed in ``axis`` . Default is False. + interpolation (str, optional): The interpolation method to use + when the desired quantile falls between two data points. Default is linear. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. @@ -799,32 +849,39 @@ def nanquantile(x, q, axis=None, keepdim=False): >>> y1 = paddle.nanquantile(x, q=0.5, axis=[0, 1]) >>> print(y1) - Tensor(shape=[], dtype=float64, place=Place(cpu), stop_gradient=True, + Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, 5.) >>> y2 = paddle.nanquantile(x, q=0.5, axis=1) >>> print(y2) - Tensor(shape=[2], dtype=float64, place=Place(cpu), stop_gradient=True, + Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, [2.50000000, 7. ]) >>> y3 = paddle.nanquantile(x, q=[0.3, 0.5], axis=0) >>> print(y3) - Tensor(shape=[2, 5], dtype=float64, place=Place(cpu), stop_gradient=True, + Tensor(shape=[2, 5], dtype=float32, place=Place(cpu), stop_gradient=True, [[5. , 2.50000000, 3.50000000, 4.50000000, 5.50000000], [5. 
, 3.50000000, 4.50000000, 5.50000000, 6.50000000]]) >>> y4 = paddle.nanquantile(x, q=0.8, axis=1, keepdim=True) >>> print(y4) - Tensor(shape=[2, 1], dtype=float64, place=Place(cpu), stop_gradient=True, + Tensor(shape=[2, 1], dtype=float32, place=Place(cpu), stop_gradient=True, [[3.40000000], [8.20000000]]) >>> nan = paddle.full(shape=[2, 3], fill_value=float("nan")) >>> y5 = paddle.nanquantile(nan, q=0.8, axis=1, keepdim=True) >>> print(y5) - Tensor(shape=[2, 1], dtype=float64, place=Place(cpu), stop_gradient=True, + Tensor(shape=[2, 1], dtype=float32, place=Place(cpu), stop_gradient=True, [[nan], [nan]]) """ - return _compute_quantile(x, q, axis=axis, keepdim=keepdim, ignore_nan=True) + return _compute_quantile( + x, + q, + axis=axis, + keepdim=keepdim, + interpolation=interpolation, + ignore_nan=True, + ) diff --git a/test/legacy_test/test_quantile_and_nanquantile.py b/test/legacy_test/test_quantile_and_nanquantile.py index 815520ccfff6a2..2bb6df174454f0 100644 --- a/test/legacy_test/test_quantile_and_nanquantile.py +++ b/test/legacy_test/test_quantile_and_nanquantile.py @@ -119,6 +119,89 @@ def test_nanquantile_all_NaN(self): paddle_res.numpy(), np_res, rtol=1e-05, equal_nan=True ) + def test_nanquantile_interpolation(self): + input_data = np.random.randn(2, 3, 4) + input_data[0, 1, 1] = np.nan + x = paddle.to_tensor(input_data) + for mode in ["lower", "higher", "midpoint", "nearest"]: + paddle_res = paddle.nanquantile( + x, q=0.35, axis=0, interpolation=mode + ) + np_res = np.nanquantile(input_data, q=0.35, axis=0, method=mode) + np.testing.assert_allclose( + paddle_res.numpy(), np_res, rtol=1e-05, equal_nan=True + ) + + def test_backward(self): + def check_grad(x, q, axis, target_gard, apis=None): + x = np.array(x, dtype="float32") + paddle.disable_static() + for op, _ in apis or API_list: + x_p = paddle.to_tensor(x, dtype="float32", stop_gradient=False) + op(x_p, q, axis).sum().backward() + np.testing.assert_allclose( + x_p.grad.numpy(), + np.array(target_gard, dtype="float32"), + rtol=1e-05, + equal_nan=True, + ) + paddle.enable_static() + opt = paddle.optimizer.SGD(learning_rate=0.01) + for op, _ in apis or API_list: + s_p = paddle.static.Program() + m_p = paddle.static.Program() + with paddle.static.program_guard(m_p, s_p): + x_p = paddle.static.data( + name="x", + shape=x.shape, + dtype=paddle.float32, + ) + x_p.stop_gradient = False + q_p = paddle.static.data( + name="q", + shape=[len(q)] if isinstance(q, list) else [], + dtype=paddle.float32, + ) + loss = op(x_p, q_p, axis).sum() + opt.minimize(loss) + exe = paddle.static.Executor() + exe.run(paddle.static.default_startup_program()) + o = exe.run( + paddle.static.default_main_program(), + feed={"x": x, "q": np.array(q, dtype="float32")}, + fetch_list=["x@GRAD"], + )[0] + np.testing.assert_allclose( + o, + np.array(target_gard, dtype="float32"), + rtol=1e-05, + equal_nan=True, + ) + paddle.disable_static() + + check_grad([1, 2, 3], 0.5, 0, [0, 1, 0]) + check_grad( + [1, 2, 3, 4] * 2, [0.55, 0.7], 0, [0, 0, 0.95, 0, 0, 0.15, 0.9, 0] + ) + check_grad( + [[1, 2, 3], [4, 5, 6]], + [0.3, 0.7], + 1, + [[0.4, 1.2, 0.4], [0.4, 1.2, 0.4]], + ) + # quantile + check_grad( + [1, float("nan"), 3], 0.5, 0, [0, 1, 0], [(paddle.quantile, None)] + ) + # nanquantile + check_grad( + [1, float("nan"), 3], + 0.5, + 0, + [0.5, 0, 0.5], + [(paddle.nanquantile, None)], + ) + class TestMuitlpleQ(unittest.TestCase): """ @@ -150,6 +233,24 @@ def test_quantile_multiple_axis_keepdim(self): ) np.testing.assert_allclose(paddle_res.numpy(), np_res, rtol=1e-05) 
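
The hand-written gradients in ``test_backward`` follow from the linear-interpolation path: a fractional index splits the incoming gradient between the two neighbouring sorted elements with lerp weights (1 - w) and w. A small dynamic-graph sketch; the expected values assume that linear rule and the behaviour asserted by the test above:

    import paddle

    x = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False)
    # q = 0.7 gives index 0.7 * (3 - 1) = 1.4, so w = 0.4
    paddle.quantile(x, q=0.7, axis=0).sum().backward()
    print(x.grad.numpy())   # expected: [0. , 0.6, 0.4]
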
+ def test_quantile_with_tensor_input(self): + x = paddle.to_tensor(self.input_data) + paddle_res = paddle.quantile( + x, q=paddle.to_tensor([0.1, 0.2]), axis=[1, 2], keepdim=True + ) + np_res = np.quantile( + self.input_data, q=[0.1, 0.2], axis=[1, 2], keepdims=True + ) + np.testing.assert_allclose(paddle_res.numpy(), np_res, rtol=1e-05) + + def test_quantile_with_zero_dim_tensor_input(self): + x = paddle.to_tensor(self.input_data) + paddle_res = paddle.quantile( + x, q=paddle.to_tensor(0.1), axis=[1, 2], keepdim=True + ) + np_res = np.quantile(self.input_data, q=0.1, axis=[1, 2], keepdims=True) + np.testing.assert_allclose(paddle_res.numpy(), np_res, rtol=1e-05) + class TestError(unittest.TestCase): """ @@ -210,6 +311,26 @@ def test_axis_value_error_2(): self.assertRaises(ValueError, test_axis_value_error_2) + # Test error when q is not a 1-D tensor + def test_tensor_input_1(): + paddle_res = paddle.quantile( + self.x, q=paddle.randn((2, 3)), axis=[1, -10] + ) + + self.assertRaises(ValueError, test_tensor_input_1) + + def test_type_q(): + paddle_res = paddle.quantile(self.x, q={1}, axis=[1, -10]) + + self.assertRaises(TypeError, test_type_q) + + def test_interpolation(): + paddle_res = paddle.quantile( + self.x, q={1}, axis=[1, -10], interpolation=" " + ) + + self.assertRaises(TypeError, test_interpolation) + class TestQuantileRuntime(unittest.TestCase): """ @@ -255,9 +376,9 @@ def test_static(self): ) results = func(x, q=0.5, axis=1) - np_input_data = self.input_data.astype('float32') + np_input_data = self.input_data.astype("float32") results_fp64 = func(x_fp64, q=0.5, axis=1) - np_input_data_fp64 = self.input_data.astype('float64') + np_input_data_fp64 = self.input_data.astype("float64") exe = paddle.static.Executor(device) paddle_res, paddle_res_fp64 = exe.run( @@ -267,11 +388,101 @@ def test_static(self): ) np_res = res_func(np_input_data, q=0.5, axis=1) np_res_fp64 = res_func(np_input_data_fp64, q=0.5, axis=1) - self.assertTrue( - np.allclose(paddle_res, np_res) - and np.allclose(paddle_res_fp64, np_res_fp64) + np.testing.assert_allclose(paddle_res, np_res, rtol=1e-05) + np.testing.assert_allclose( + paddle_res_fp64, np_res_fp64, rtol=1e-05 ) + def test_static_tensor(self): + paddle.enable_static() + for func, res_func in API_list: + s_p = paddle.static.Program() + m_p = paddle.static.Program() + with paddle.static.program_guard(m_p, s_p): + for device in self.devices: + x = paddle.static.data( + name="x", + shape=self.input_data.shape, + dtype=paddle.float32, + ) + q = paddle.static.data( + name="q", shape=(3,), dtype=paddle.float32 + ) + x_fp64 = paddle.static.data( + name="x_fp64", + shape=self.input_data.shape, + dtype=paddle.float64, + ) + + results = func(x, q=q, axis=1) + np_input_data = self.input_data.astype("float32") + results_fp64 = func(x_fp64, q=q, axis=1) + np_input_data_fp64 = self.input_data.astype("float64") + q_data = np.array([0.5, 0.5, 0.5]).astype("float32") + + exe = paddle.static.Executor(device) + paddle_res, paddle_res_fp64 = exe.run( + paddle.static.default_main_program(), + feed={ + "x": np_input_data, + "x_fp64": np_input_data_fp64, + "q": q_data, + }, + fetch_list=[results, results_fp64], + ) + np_res = res_func(np_input_data, q=[0.5, 0.5, 0.5], axis=1) + np_res_fp64 = res_func( + np_input_data_fp64, q=[0.5, 0.5, 0.5], axis=1 + ) + np.testing.assert_allclose(paddle_res, np_res, rtol=1e-05) + np.testing.assert_allclose( + paddle_res_fp64, np_res_fp64, rtol=1e-05 + ) + + def test_static_0d_tensor(self): + paddle.enable_static() + for func, res_func in 
API_list: + for device in self.devices: + s_p = paddle.static.Program() + m_p = paddle.static.Program() + with paddle.static.program_guard(m_p, s_p): + x = paddle.static.data( + name="x", + shape=self.input_data.shape, + dtype=paddle.float32, + ) + q = paddle.static.data( + name="q", shape=[], dtype=paddle.float32 + ) + x_fp64 = paddle.static.data( + name="x_fp64", + shape=self.input_data.shape, + dtype=paddle.float64, + ) + + results = func(x, q=q, axis=1) + np_input_data = self.input_data.astype("float32") + results_fp64 = func(x_fp64, q=q, axis=1) + np_input_data_fp64 = self.input_data.astype("float64") + q_data = np.array(0.3).astype("float32") + + exe = paddle.static.Executor(device) + paddle_res, paddle_res_fp64 = exe.run( + paddle.static.default_main_program(), + feed={ + "x": np_input_data, + "x_fp64": np_input_data_fp64, + "q": q_data, + }, + fetch_list=[results, results_fp64], + ) + np_res = res_func(np_input_data, q=0.3, axis=1) + np_res_fp64 = res_func(np_input_data_fp64, q=0.3, axis=1) + np.testing.assert_allclose(paddle_res, np_res, rtol=1e-05) + np.testing.assert_allclose( + paddle_res_fp64, np_res_fp64, rtol=1e-05 + ) + if __name__ == '__main__': unittest.main() From 6f9f6b2772e41522d4384c520ff0e8693bf4a44b Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Fri, 22 Mar 2024 16:10:50 +0800 Subject: [PATCH 2/2] update docstring and add test --- python/paddle/tensor/stat.py | 34 ++++++++++++------- .../test_quantile_and_nanquantile.py | 17 +++++----- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index c4ab8b4ed2ad9f..c88d8fa367e209 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -566,8 +566,9 @@ def _compute_quantile( Args: x (Tensor): The input Tensor, it's data type can be float32, float64, int32, int64. - q (int|float|list|Tensor): The q for calculate quantile, which should be in range [0, 1]. If q is a list, - a 1-D Tensor or a 0-D Tensor, each q will be calculated and the first dimension of output is same to the number of ``q`` . + q (int|float|list|Tensor): The q for calculate quantile, which should be in range [0, 1]. If q is a list or + a 1-D Tensor, each element of q will be calculated and the first dimension of output is same to the number of ``q`` . + If q is a 0-D Tensor, it will be treated as an integer or float. axis (int|list, optional): The axis along which to calculate quantile. ``axis`` should be int or list of int. ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` . If ``axis`` is less than 0, it works the same way as :math:`axis + D`. @@ -579,7 +580,8 @@ def _compute_quantile( dimensions(it is of size 1 in this case). Otherwise, the shape of the output Tensor is squeezed in ``axis`` . Default is False. interpolation (str, optional): The interpolation method to use - when the desired quantile falls between two data points. Default is linear. + when the desired quantile falls between two data points. Must be one of linear, higher, + lower, midpoint and nearest. Default is linear. ignore_nan: (bool, optional): Whether to ignore NaN of input Tensor. If ``ignore_nan`` is True, it will calculate nanquantile. Otherwise it will calculate quantile. Default is False. 
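
The accepted forms of ``q`` and the resulting output shapes described in the docstring above, as a short dynamic-graph sketch (shapes assume the semantics this patch adds):

    import paddle

    x = paddle.arange(8, dtype="float32").reshape([4, 2])
    print(paddle.quantile(x, q=0.5, axis=0).shape)                           # [2]
    print(paddle.quantile(x, q=[0.3, 0.5], axis=0).shape)                    # [2, 2], one row per q
    print(paddle.quantile(x, q=paddle.to_tensor([0.3, 0.5]), axis=0).shape)  # [2, 2], same as the list form
    print(paddle.quantile(x, q=paddle.to_tensor(0.5), axis=0).shape)         # [2], 0-D q behaves like a scalar
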
@@ -608,6 +610,7 @@ def _compute_quantile( "Type of q should be int, float, list or tuple, or tensor" ) for q_num in q: + # we do not validate tensor q in static mode if not in_dynamic_or_pir_mode() and isinstance(q_num, Variable): break if q_num < 0 or q_num > 1: @@ -689,9 +692,11 @@ def _compute_index(index): return paddle.take_along_axis(sorted_tensor, idx, axis=axis) indices_below = paddle.floor(index).astype(paddle.int32) - tensor_below = paddle.take_along_axis( - sorted_tensor, indices_below, axis=axis - ) + if interpolation != "higher": + # avoid unnecessary compute + tensor_below = paddle.take_along_axis( + sorted_tensor, indices_below, axis=axis + ) if interpolation == "lower": return tensor_below @@ -706,6 +711,7 @@ def _compute_index(index): return (tensor_upper + tensor_below) / 2 weights = (index - indices_below).astype(x.dtype) + # "linear" return paddle.lerp( tensor_below.astype(x.dtype), tensor_upper.astype(x.dtype), @@ -738,8 +744,9 @@ def quantile(x, q, axis=None, keepdim=False, interpolation="linear"): Args: x (Tensor): The input Tensor, it's data type can be float32, float64, int32, int64. - q (int|float|list|Tensor): The q for calculate quantile, which should be in range [0, 1]. If q is a list, - a 1-D Tensor or a 0-D Tensor, each q will be calculated and the first dimension of output is same to the number of ``q`` . + q (int|float|list|Tensor): The q for calculate quantile, which should be in range [0, 1]. If q is a list or + a 1-D Tensor, each element of q will be calculated and the first dimension of output is same to the number of ``q`` . + If q is a 0-D Tensor, it will be treated as an integer or float. axis (int|list, optional): The axis along which to calculate quantile. ``axis`` should be int or list of int. ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` . If ``axis`` is less than 0, it works the same way as :math:`axis + D`. @@ -751,13 +758,13 @@ def quantile(x, q, axis=None, keepdim=False, interpolation="linear"): dimensions(it is of size 1 in this case). Otherwise, the shape of the output Tensor is squeezed in ``axis`` . Default is False. interpolation (str, optional): The interpolation method to use - when the desired quantile falls between two data points. Default is linear. + when the desired quantile falls between two data points. Must be one of linear, higher, + lower, midpoint and nearest. Default is linear. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: Tensor, results of quantile along ``axis`` of ``x``. - In order to obtain higher precision, data type of results will be float64. Examples: .. code-block:: python @@ -816,7 +823,8 @@ def nanquantile(x, q, axis=None, keepdim=False, interpolation="linear"): Args: x (Tensor): The input Tensor, it's data type can be float32, float64, int32, int64. q (int|float|list|Tensor): The q for calculate quantile, which should be in range [0, 1]. If q is a list or - a 1-D Tensor, each q will be calculated and the first dimension of output is same to the number of ``q`` . + a 1-D Tensor, each element of q will be calculated and the first dimension of output is same to the number of ``q`` . + If q is a 0-D Tensor, it will be treated as an integer or float. axis (int|list, optional): The axis along which to calculate quantile. ``axis`` should be int or list of int. ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` . 
If ``axis`` is less than 0, it works the same way as :math:`axis + D`. @@ -828,13 +836,13 @@ def nanquantile(x, q, axis=None, keepdim=False, interpolation="linear"): dimensions(it is of size 1 in this case). Otherwise, the shape of the output Tensor is squeezed in ``axis`` . Default is False. interpolation (str, optional): The interpolation method to use - when the desired quantile falls between two data points. Default is linear. + when the desired quantile falls between two data points. Must be one of linear, higher, + lower, midpoint and nearest. Default is linear. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: Tensor, results of quantile along ``axis`` of ``x``. - In order to obtain higher precision, data type of results will be float64. Examples: .. code-block:: python diff --git a/test/legacy_test/test_quantile_and_nanquantile.py b/test/legacy_test/test_quantile_and_nanquantile.py index 2bb6df174454f0..e28bcd1f569643 100644 --- a/test/legacy_test/test_quantile_and_nanquantile.py +++ b/test/legacy_test/test_quantile_and_nanquantile.py @@ -119,18 +119,17 @@ def test_nanquantile_all_NaN(self): paddle_res.numpy(), np_res, rtol=1e-05, equal_nan=True ) - def test_nanquantile_interpolation(self): + def test_interpolation(self): input_data = np.random.randn(2, 3, 4) input_data[0, 1, 1] = np.nan x = paddle.to_tensor(input_data) - for mode in ["lower", "higher", "midpoint", "nearest"]: - paddle_res = paddle.nanquantile( - x, q=0.35, axis=0, interpolation=mode - ) - np_res = np.nanquantile(input_data, q=0.35, axis=0, method=mode) - np.testing.assert_allclose( - paddle_res.numpy(), np_res, rtol=1e-05, equal_nan=True - ) + for op, ref_op in API_list: + for mode in ["lower", "higher", "midpoint", "nearest"]: + paddle_res = op(x, q=0.35, axis=0, interpolation=mode) + np_res = ref_op(input_data, q=0.35, axis=0, method=mode) + np.testing.assert_allclose( + paddle_res.numpy(), np_res, rtol=1e-05, equal_nan=True + ) def test_backward(self): def check_grad(x, q, axis, target_gard, apis=None):
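
For reference, the static-graph path exercised by the new ``test_static_tensor`` and ``test_static_0d_tensor`` cases boils down to feeding ``q`` as a placeholder rather than a Python scalar. A hedged end-to-end sketch; the names ("x", "q") and shapes are illustrative only:

    import numpy as np
    import paddle

    paddle.enable_static()
    main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name="x", shape=[3, 4], dtype="float32")
        q = paddle.static.data(name="q", shape=[2], dtype="float32")
        out = paddle.quantile(x, q=q, axis=1)

    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(startup_prog)
    (res,) = exe.run(
        main_prog,
        feed={
            "x": np.random.rand(3, 4).astype("float32"),
            "q": np.array([0.25, 0.75], dtype="float32"),
        },
        fetch_list=[out],
    )
    print(res.shape)  # (2, 3): one leading row per fed q value
    paddle.disable_static()
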