Skip to content

Conversation

@Enigmatisms
Copy link
Contributor

PR Category

Operator Mechanism

PR Types

Bug fixes

Description

修复了大 Tensor 情况下的。argsort 算子出错。argsort 算子出错原因:
cub::DeviceSegmentedRadixSort 目前不支持超过 INT_MAX 个元素的sort,所以对于大tensor,需要一个kernel call拆分机制,将一次 sort 拆分为多次。比如下面这个例子:

paddle.argsort(Tensor([228170138, 10],"int64"), axis=1, stable=True, )

由于 Tensor numel 为 2281701380 大于 INT_MAX,直接用 DeviceSegmentedRadixSort::SortPairs 计算 temp buffer 大小将会出现溢出。目前的解决方案:每次参与 sort 的元素不超过 2^30 个(不选择 2^31 - 1 是为了避免过大的 temp buffer,显存压力大),按照 sorting 维度计算 batch,拆分为多个 cub kernel。比如上述例子将会被拆分为三次计算:

  • [2^30, 10] radix sort 第一次
  • [2^30, 10] radix sort 第二次
  • [2281701380 - 2^31, 10] 第一次

关于 cub::DeviceSegmentedRadixSort 的 API 支持范围,见:

经过性能测试,此 kernel 目前平均运行时间约为 torch 的 1.25倍,部分较好的 shape 下(比如[5, 456340276] axis = -1)可达到将近50%的计算时间。

Pcard-89620

@paddle-bot
Copy link

paddle-bot bot commented May 14, 2025

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@lshpku lshpku merged commit 6ef8d19 into PaddlePaddle:develop May 15, 2025
50 checks passed
wanghuancoder pushed a commit to wanghuancoder/Paddle that referenced this pull request May 27, 2025
* [PHI] Fixed argsort big tensor bug

* [PHI] Fixed shape mismatch problem.
wanghuancoder added a commit that referenced this pull request Jun 3, 2025
* refine forrange (#72360)

* refine forrange

* refine forrange

* reduce support big tensor (#71970)

* reduce support big tensor

* [PHI] Fix gridDim limit for reduce kernel (#72507)

* [API] isclose support bigtensor (#72516)

* isclose support bigtensor

* refine

* [API] isnan isinf isfinite support bigtensor (#72517)

* isnan isinf isfinite support bigtensor

* refine

* [PHI] Fix cum kernel for big tensor (#72562)

* [PHI] Preliminary fix for elementwise broadcast int32 shape overflow (#72584)

* [PHI] Align linalg.solve kernel with torch (#72608)

* Update strided copy kernel (#72662)

* [PHI] Fix grid sample kernel for big tensor (#72628)

* [PHI] Fix argsort big tensor bug (#72712)

* [PHI] Fixed argsort big tensor bug

* [PHI] Fixed shape mismatch problem.

* [PHI] Fix contiguous kernel for big tensor (#72705)

* [PHI] Fix flatten and split kernel for big tensor (#72634)

* [PHI] Fix out-of-bound issue of paddle.take_along_axis (#72757)

* [PHI] fix paddle.diag with big tensor (#72638)

* [API] fix paddle.cross with big tensor (#72652)

* [PHI] Fix paddle.where api for big tensor (#72717)

* [PHI] Fix bincount kernel for big tensor (#72706)

* fix bincount kernel for big tensor

* use HostAlloc to alloc memory

* add cpu test case

* [PHI] Fix full_like kernel for big tensor (#72831)

* [API] Fix int overflow and float16 support for paddle.frac (#72815)

* [PHI] Align paddle.inner with torch in matmul logic (#72843)

* [PHI] Fix paddle.var & paddle.std float16 overflow (#72650)

* [PHI] Fix logsumexp precision problem (#72681)

* [PHI] Debug for logsumexp, bug source found

* [PHI] Removed GetNumBlocks func to get correct logsumexp

* [PHI] Removed redundant debug VLOG

* [PHI] Elegant grid bounded solution

* [Accuracy diff No.55-56、76-77] Fix accuracy diff for var&std API (#72879)

* [Accuracy diff No.21] Fix accuracy diff for heaviside API (#72894)

---------

Co-authored-by: Shuhao Liang <[email protected]>
Co-authored-by: Qianyue He <[email protected]>
Co-authored-by: Lei Ding <[email protected]>
Co-authored-by: ggggxm <[email protected]>
Co-authored-by: xkkkkkk23 <[email protected]>
Co-authored-by: Zx <[email protected]>
Co-authored-by: huangjiyi <[email protected]>
Co-authored-by: ooo oo <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants