Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 50 additions & 6 deletions python/paddle/tensor/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None):
Compute the median along the specified axis.

Args:
x (Tensor): The input Tensor, it's data type can be bool, float16, float32, float64, int32, int64.
x (Tensor): The input Tensor, it's data type can be float16, float32, float64, int32, int64.
axis (int, optional): The axis along which to perform median calculations ``axis`` should be int.
``axis`` should be in range [-D, D), where D is the dimensions of ``x`` .
If ``axis`` is less than 0, it works the same way as :math:`axis + D`.
Expand Down Expand Up @@ -463,6 +463,28 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None):
>>> print(median_indices)
Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True,
[1, 1, 1])

>>> # cases containing nan values
>>> x = paddle.to_tensor(np.array([[1,2,3,float('nan')],[1,2,3,4],[float('nan'),1,2,3]])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个x不是float64的Tensor吗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是的,这里是想加一下如果输入里有nan的例子


>>> y6 = paddle.median(x, axis=-1, keepdim=True)
>>> print(y6)
Tensor(shape=[3, 1], dtype=float64, place=Place(cpu), stop_gradient=True,
[[nan ],
[2.50000000],
[nan ]])

>>> median_value, median_indices = paddle.median(x, axis=1, keepdim=True, mode='min')
>>> print(median_value)
Tensor(shape=[3, 1], dtype=float64, place=Place(cpu), stop_gradient=True,
[[nan],
[2. ],
[nan]])
>>> print(median_indices)
Tensor(shape=[3, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
[[3],
[1],
[0]])
"""
if not isinstance(x, (Variable, paddle.pir.Value)):
raise TypeError("In median, the input x should be a Tensor.")
Expand Down Expand Up @@ -521,6 +543,11 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None):
),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

第525行dtype的设置,这个设置其实不太合理,仅放在avg分支下吧,不影响min的分支

dtype=dtype,
)
out_tensor = out_tensor + paddle.sum(
paddle.cast(paddle.isnan(x), dtype=dtype) * x.astype(dtype),
axis=axis,
keepdim=True,
)
else: # mode == 'min'
if sz & 1 == 0:
out_tensor = paddle.slice(
Expand All @@ -538,12 +565,29 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None):
out_idx = paddle.slice(
idx, axes=[axis], starts=[kth], ends=[kth + 1]
)
# if contain nan on axis, return nan for that axis
out_tensor = out_tensor + paddle.sum(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

最后这一个 astype(x.dtype) 不需要吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

paddle.sum在输入是int32时输出会变成int64,最后这个 astype(x.dtype)是针对int32这种情况
https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/sum_cn.html

paddle.cast(paddle.isnan(x), dtype=x.dtype) * x,
axis=axis,
keepdim=True,
).astype(x.dtype)
if need_idx:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

关于nan的问题就先不大改吧,按之前的逻辑来,主要是适配dtype的影响

Copy link
Contributor Author

@NKNaN NKNaN May 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,这里的 if need_idx 分支是对输入有nan且需要输出index的情况处理,需要删掉吗?如果要删掉的话就是输入有nan的时候不输出index这样?目前torch的median输入有nan的时候会输出index,之前添加min分支的时候没有考虑这个情况,所以这里想补一下

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那就补上吧

# replace index using the first nan value's index on axis for out_idx
# topk is not stable on cpu device, use argsort instead
x_isnan = paddle.isnan(x).astype("int64")
x_all_zero = paddle.zeros_like(x_isnan)
index_along_axis = paddle.argsort(
x_all_zero, axis=axis, stable=True
)
nan_index = paddle.sum(
Copy link
Contributor

@zhwesky2010 zhwesky2010 May 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果有多个nan,取paddle.sum好像也会出问题吧。多个nan应该按第一个nan的坐标来计算

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image
改了一下,多个nan按第一个nan的坐标计算

index_along_axis * x_isnan, axis=axis, keepdim=True
)
nan_index_mask = paddle.sum(x_isnan, axis=axis, keepdim=True)
out_idx = (
Copy link
Contributor

@zhwesky2010 zhwesky2010 May 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以简化下写法:

out_idx = out_idx * paddle.logical_not(nan_index_mask) + nan_index

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

out_idx * (~nan_index_mask.astype("bool")).astype("int64")
+ nan_index
)

out_tensor = out_tensor + paddle.sum(
paddle.cast(paddle.isnan(x), dtype=dtype) * x.astype(dtype),
axis=axis,
keepdim=True,
)
if is_flatten:
if keepdim:
out_tensor = out_tensor.reshape([1] * dims)
Expand Down
134 changes: 117 additions & 17 deletions test/legacy_test/test_median.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np

import paddle
from paddle.base import core
from paddle.pir_utils import test_with_pir_api

DELTA = 1e-6
Expand All @@ -28,9 +29,11 @@ def np_medain_min(data, keepdims=False):
data_flat = data.flatten()
data_cnt = len(data_flat)

data_flat[np.isnan(data_flat)] = np.inf
if data.dtype != 'int32' and data.dtype != 'int64':
data_flat[np.isnan(data_flat)] = np.inf
data_sort = np.sort(data_flat)
data_sort[np.isinf(data_sort)] = np.nan
if data.dtype != 'int32' and data.dtype != 'int64':
data_sort[np.isinf(data_sort)] = np.nan

if data_cnt % 2:
is_odd = False
Expand All @@ -44,9 +47,8 @@ def np_medain_min(data, keepdims=False):
np_res = data_sort[i]
if keepdims:
new_shape = [1] * len(shape)
return np_res.reshape(new_shape)
else:
return np_res
np_res = np_res.reshape(new_shape)
return np_res + np.sum(np.isnan(data).astype(data.dtype) * data)


def np_medain_min_axis(data, axis=None, keepdims=False):
Expand All @@ -73,9 +75,11 @@ def np_medain_min_axis(data, axis=None, keepdims=False):
shape=data_flat.shape[:-1], fill_value=data_flat.shape[-1]
)

data_flat[np.isnan(data_flat)] = np.inf
if data.dtype != 'int32' and data.dtype != 'int64':
data_flat[np.isnan(data_flat)] = np.inf
data_sort = np.sort(data_flat, axis=-1)
data_sort[np.isinf(data_sort)] = np.nan
if data.dtype != 'int32' and data.dtype != 'int64':
data_sort[np.isinf(data_sort)] = np.nan

is_odd = data_cnt % 2

Expand All @@ -95,16 +99,28 @@ def np_medain_min_axis(data, axis=None, keepdims=False):
if keepdims:
shape = list(data.shape)
shape[axis] = 1
return np.reshape(np_res, shape)
np_res = np.reshape(np_res, shape)
else:
return np.reshape(np_res, reshape[:-1])
np_res = np.reshape(np_res, reshape[:-1])
return np_res + np.sum(
np.isnan(data).astype(data.dtype) * data, axis=axis, keepdims=keepdims
)


class TestMedianAvg(unittest.TestCase):
def check_numpy_res(self, np1, np2):
self.assertEqual(np1.shape, np2.shape)
np1_isnan = np.isnan(np1)
np2_isnan = np.isnan(np2)
nan_mismatch = np.sum(
(np1_isnan.astype('int32') - np2_isnan.astype('int32'))
* (np1_isnan.astype('int32') - np2_isnan.astype('int32'))
)
self.assertEqual(nan_mismatch, 0)
np1 = np.where(np.isnan(np1), 0.0, np1)
np2 = np.where(np.isnan(np2), 0.0, np2)
mismatch = np.sum((np1 - np2) * (np1 - np2))
self.assertAlmostEqual(mismatch, 0, DELTA)
self.assertAlmostEqual(mismatch, 0, delta=DELTA)

def static_single_test_median(self, lis_test):
paddle.enable_static()
Expand Down Expand Up @@ -133,9 +149,10 @@ def test_median_static(self):
l = 2
x = np.arange(h * w * l).reshape([h, w, l])
lis_tests = [
[x, axis, keepdims]
[x.astype(dtype), axis, keepdims]
for axis in [-1, 0, 1, 2, None]
for keepdims in [False, True]
for dtype in ['float32', 'float64', 'int32', 'int64']
]
for lis_test in lis_tests:
self.static_single_test_median(lis_test)
Expand All @@ -147,9 +164,10 @@ def test_median_dygraph(self):
l = 2
x = np.arange(h * w * l).reshape([h, w, l])
lis_tests = [
[x, axis, keepdims]
[x.astype(dtype), axis, keepdims]
for axis in [-1, 0, 1, 2, None]
for keepdims in [False, True]
for dtype in ['float32', 'float64', 'int32', 'int64']
]
for lis_test in lis_tests:
self.dygraph_single_test_median(lis_test)
Expand All @@ -164,6 +182,40 @@ def test_median_exception(self):
self.assertRaises(ValueError, paddle.median, x, 2, False, 'max')
self.assertRaises(ValueError, paddle.median, paddle.to_tensor([]))

def test_nan(self):
Copy link
Contributor

@zhwesky2010 zhwesky2010 May 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

单独专门测一下int32/int64吧,关于nan的问题先不用深究了,保持之前的逻辑就行

paddle.disable_static()
x = np.array(
[[1, 2, 3, float('nan')], [1, 2, 3, 4], [float('nan'), 1, 2, 3]]
)
lis_tests = [
[x.astype(dtype), axis, keepdims]
for axis in [-1, 0, 1, None]
for keepdims in [False, True]
for dtype in ['float32', 'float64']
]
for lis_test in lis_tests:
self.dygraph_single_test_median(lis_test)

@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_float16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and do not support float16",
)
def test_float16(self):
paddle.disable_static(core.CUDAPlace(0))
x = np.array(
[[1, 2, 3, float('nan')], [1, 2, 3, 4], [float('nan'), 1, 2, 3]]
).astype('float16')
lis_tests = [
[axis, keepdims]
for axis in [-1, 0, 1, None]
for keepdims in [False, True]
]
for axis, keepdims in lis_tests:
res_np = np.median(x, axis=axis, keepdims=keepdims)
res_pd = paddle.median(paddle.to_tensor(x), axis, keepdims)
self.check_numpy_res(res_pd.numpy(False), res_np.astype('float64'))


class TestMedianMin(unittest.TestCase):
def static_single_test_median(self, lis_test):
Expand All @@ -183,9 +235,14 @@ def static_single_test_median(self, lis_test):
def dygraph_single_test_median(self, lis_test):
x, axis, keepdims = lis_test
res_np = np_medain_min_axis(x, axis=axis, keepdims=keepdims)
res_pd, _ = paddle.median(
paddle.to_tensor(x), axis, keepdims, mode='min'
)
if axis is None:
res_pd = paddle.median(
paddle.to_tensor(x), axis, keepdims, mode='min'
)
else:
res_pd, _ = paddle.median(
paddle.to_tensor(x), axis, keepdims, mode='min'
)
np.testing.assert_allclose(res_pd.numpy(False), res_np)

@test_with_pir_api
Expand All @@ -195,9 +252,10 @@ def test_median_static(self):
l = 2
x = np.arange(h * w * l).reshape([h, w, l]).astype("float32")
lis_tests = [
[x, axis, keepdims]
[x.astype(dtype), axis, keepdims]
for axis in [-1, 0, 1, 2]
for keepdims in [False, True]
for dtype in ['float32', 'float64', 'int32', 'int64']
]
for lis_test in lis_tests:
self.static_single_test_median(lis_test)
Expand All @@ -209,9 +267,10 @@ def test_median_dygraph(self):
l = 2
x = np.arange(h * w * l).reshape([h, w, l]).astype("float32")
lis_tests = [
[x, axis, keepdims]
[x.astype(dtype), axis, keepdims]
for axis in [-1, 0, 1, 2]
for keepdims in [False, True]
for dtype in ['float32', 'float64', 'int32', 'int64']
]
for lis_test in lis_tests:
self.dygraph_single_test_median(lis_test)
Expand All @@ -230,6 +289,47 @@ def test_index_odd_case(self):
np.testing.assert_allclose(out.numpy(), [4.0, 14.0, 24.0])
np.testing.assert_equal(index.numpy(), [4, 4, 4])

def test_nan(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不能运行的case不是这个int32/int64的case吗,这个和nan的关系是?另外这里也没有看到有测int32/int64

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

int32/int64添加在上面 test_median_static 和 test_median_dygraph 中测了。

int32/int64不能运行的是因为我之前添加 min 分支的时候加在了处理 nan 的这部分之前,在没有nan的情况下int32/int64输入在这里就会出错,out_tensor 是 int32/int64 类型而后面的 sum 是cast成了float64
image

现在的修改是让 min 和 avg 分支分别处理 nan 的情况:avg 保持之前的处理逻辑,输入是float32时输出是float32,其他情况输出是float64;min 在这个地方改了一下,cast的dtype改成了x.dtype,让输入输出的数据类型保持一致,同时如果要输出index的话也加了对index的相应处理。

paddle.disable_static()
x = np.array(
[[1, 2, 3, float('nan')], [1, 2, 3, 4], [float('nan'), 1, 2, 3]]
)
lis_tests = [
[x.astype(dtype), axis, keepdims]
for axis in [-1, 0, 1, None]
for keepdims in [False, True]
for dtype in ['float32', 'float64']
]
for lis_test in lis_tests:
self.dygraph_single_test_median(lis_test)

@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_float16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and do not support float16",
)
def test_float16(self):
paddle.disable_static(core.CUDAPlace(0))
x = np.array(
[[1, 2, 3, float('nan')], [1, 2, 3, 4], [float('nan'), 1, 2, 3]]
).astype('float16')
lis_tests = [
[axis, keepdims]
for axis in [-1, 0, 1, None]
for keepdims in [False, True]
]
for axis, keepdims in lis_tests:
res_np = np_medain_min_axis(x, axis=axis, keepdims=keepdims)
if axis is None:
res_pd = paddle.median(
paddle.to_tensor(x), axis, keepdims, mode='min'
)
else:
res_pd, _ = paddle.median(
paddle.to_tensor(x), axis, keepdims, mode='min'
)
np.testing.assert_allclose(res_pd.numpy(False), res_np)


if __name__ == '__main__':
unittest.main()