
[CINN] Fixed gather_nd incorrect logic for negative inputs. #73940

Merged
lshpku merged 2 commits into PaddlePaddle:develop from Enigmatisms:index_put_grad
Jul 15, 2025

@Enigmatisms (Contributor) commented Jul 9, 2025

PR Category

CINN

PR Types

Bug fixes

Description

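This PR fixes the CINN lowering logic of the gather_nd operator. The original logic did not handle negative index inputs, which produced incorrect results. This PR:

  • Adds handling for negative values (based on ir::Select)
  • Removes the ineffective non-symbolic strategy
  • Adds the corresponding unit test test_cinn_gather_nd.py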
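
To illustrate the behaviour the fix targets, here is a minimal pure-Python sketch of gather_nd over a row-major flat buffer, with select-style wrapping of negative indices. It is not the CINN lowering code: the names `normalize_index` and `gather_nd_flat` are hypothetical, and the sketch assumes every index row addresses all axes, so each lookup returns a scalar.

```python
def normalize_index(index, dim):
    """Select-style wrap: a negative index counts from the end of the axis."""
    return index + dim if index < 0 else index


def gather_nd_flat(data, shape, indices):
    """Gather scalars from a row-major flat buffer (hypothetical helper).

    Each row of `indices` must contain len(shape) entries.
    """
    out = []
    for row in indices:
        offset = 0
        for idx, dim in zip(row, shape):
            idx = normalize_index(idx, dim)
            if not 0 <= idx < dim:
                raise IndexError(f"index {idx} out of range for axis of size {dim}")
            # Row-major linear offset, built one axis at a time.
            offset = offset * dim + idx
        out.append(data[offset])
    return out


data = [1, 2, 3, 4, 5, 6]  # logical shape (2, 3)
print(gather_nd_flat(data, (2, 3), [[0, 0], [-1, -1], [1, -3]]))  # [1, 6, 4]
```

Without the `normalize_index` step, the row `[-1, -1]` would yield a negative linear offset, which is the kind of incorrect access the description refers to.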

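Performance results

| Configuration | Before | This PR |
|---|---|---|
| test_llama forward.py | ~38.3 us | ~45.3 us |
| test_llama inference.py | ~38.8 us | ~45.6 us |
| gather_nd(x * y, indices) + z, (256, 128) + (100, 2) | 3.35 ms | 3.32 ms |
| gather_nd(x * y, indices) + z, (4096, 2048) + (131072, 2) | 41.3 ms | 41.5 ms |

Note: the first two test cases are Paddle unit tests that contain kernels with a fused gather_nd operator. The figures reported here are the average run time of the gather_nd-related kernels.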

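Some kernels appear to have slowed down considerably. NCU analysis shows that the throughput of these kernels actually increased and that stalls at the bottleneck stages (e.g. LG throttle) dropped sharply, but the number of SASS instructions executed rose significantly.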

The speed of the following implementations was compared further:

  • select (this PR, single kernel): ~17 us
  • Bit operation (64-bit, specifically ((-int(index < 0)) & shape) + index): ~17.5 us
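  • Bit operation (32-bit): ~14.5 us
  • mod operation ((index + shape) % shape): ~12 us, the same as before the change (even marginally faster). This approach would have been ideal (implementing the CINN logic with mod does not even require recasting the operands of any min/max operations that may appear in shape), but the current expression simplification logic does not yet support negative values: (a + N) % N is simplified to a % N, which is not a valid simplification when a is negative and N is a constant. This optimization was therefore dropped for now, since the performance regression introduced by the select implementation in this PR has little effect on the whole pass (the affected kernels account for less than 1% of the time).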
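
The variants compared above can be checked against each other with a short Python sketch. It assumes the generated code uses C/CUDA-style remainder (quotient truncated toward zero), which Python's own `%` does not, so the hypothetical helper `c_rem` emulates it. All function names are illustrative and are not CINN APIs. The bitmask form is written with `index < 0` as the condition, since that is what makes it agree with the select form.

```python
def c_rem(a, n):
    """Remainder with C/CUDA semantics: the quotient is truncated toward zero."""
    q = abs(a) // abs(n)
    if (a < 0) != (n < 0):
        q = -q
    return a - q * n


def wrap_select(index, dim):
    # Select form: pick index + dim for negative indices, index otherwise.
    return index + dim if index < 0 else index


def wrap_bitmask(index, dim):
    # -int(index < 0) is -1 (all bits set) for negative indices, else 0,
    # so the mask keeps `dim` only when the index is negative.
    return ((-int(index < 0)) & dim) + index


def wrap_mod(index, dim):
    # Mod form: (index + dim) % dim with truncating remainder.
    return c_rem(index + dim, dim)


def wrap_mod_oversimplified(index, dim):
    # What remains if (a + N) % N is rewritten to a % N.
    return c_rem(index, dim)


dim = 5
for index in range(-dim, dim):
    expected = wrap_select(index, dim)
    assert wrap_bitmask(index, dim) == expected
    assert wrap_mod(index, dim) == expected

print(wrap_mod(-2, 5), wrap_mod_oversimplified(-2, 5))  # 3 -2
```

For every index in [-dim, dim) the three forms agree, while the rewritten `a % N` leaves negative indices negative under truncating remainder, which is why the mod variant cannot be used until the simplifier accounts for negative operands.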

Pcard-89620

@Enigmatisms force-pushed the index_put_grad branch 2 times, most recently from 9d0aa9b to 4fcb05f, July 11, 2025 07:32
@lshpku merged commit e1842d4 into PaddlePaddle:develop, Jul 15, 2025
105 of 110 checks passed
@Enigmatisms deleted the index_put_grad branch, August 29, 2025 05:05