diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py
index 376b0ce1d421e5..4dafa608cc2948 100644
--- a/python/paddle/tensor/stat.py
+++ b/python/paddle/tensor/stat.py
@@ -628,7 +628,35 @@ def median(
             axis += dims
     sz = x.shape[axis]
     kth = sz >> 1
-    tensor_topk, idx = paddle.topk(x, kth + 1, axis=axis, largest=False)
+    # Take the sort-based path only when BOTH conditions hold:
+    #   1. `axis` is not the last dimension (non-contiguous reduction)
+    #   2. the size along `axis` exceeds 10000 (heuristic crossover point)
+    # Rationale:
+    #   - For a non-last axis, `paddle.topk` costs O(n * k); the median needs
+    #     k = n / 2, i.e. O(n^2). See paddle/phi/kernels/gpu/top_k_kernel.cu.
+    #   - A sort costs O(n log n) regardless of the axis.
+    # NOTE(review): confirm `dims` still equals the rank of `x` at this point.
+    use_sort = (axis != dims - 1) and (sz > 10000)
+    if use_sort:
+        if need_idx:
+            # Sort once: gather the values through the stable argsort indices
+            # instead of running a second full sort just for the values.
+            idx = paddle.slice(
+                paddle.argsort(x, axis=axis, stable=True),
+                axes=[axis],
+                starts=[0],
+                ends=[kth + 1],
+            )
+            tensor_topk = paddle.take_along_axis(x, idx, axis=axis)
+        else:
+            tensor_topk = paddle.slice(
+                paddle.sort(x, axis=axis, stable=True),
+                axes=[axis],
+                starts=[0],
+                ends=[kth + 1],
+            )
+    else:
+        tensor_topk, idx = paddle.topk(x, kth + 1, axis=axis, largest=False)
     if mode == 'avg':
         dtype = (
             'float64'
diff --git a/test/legacy_test/test_median.py b/test/legacy_test/test_median.py
index 3cab9133359af8..77a9145f9205c7 100644
--- a/test/legacy_test/test_median.py
+++ b/test/legacy_test/test_median.py
@@ -401,5 +401,28 @@ def test_median_dygraph(self):
         self.dygraph_single_test_median([x, 1, False])
 
 
+class TestMedianSort(unittest.TestCase):
+    """Covers the sort-based path: large (> 10000) non-last reduction axis."""
+
+    def dygraph_single_test_median(self, lis_test):
+        x, axis, keepdims = lis_test
+        res_np = np.median(x, axis=axis, keepdims=keepdims)
+        x_pd = paddle.to_tensor(x)
+        x_pd.stop_gradient = False
+        res_pd = paddle.median(x_pd, axis, keepdims)
+        np.testing.assert_allclose(res_pd.numpy(), res_np)
+
+    def test_median_dygraph(self):
+        paddle.disable_static()
+        h = 2
+        w = 20000
+        d = 2
+        # Shuffle the values: monotonically increasing input is already
+        # sorted along every axis and would not exercise the sort itself.
+        x = np.random.RandomState(2024).permutation(h * w * d)
+        x = x.reshape([h, w, d])
+        self.dygraph_single_test_median([x, 1, False])
+
+
 if __name__ == '__main__':
     unittest.main()