[0-size Tensor No.271] Add 0-size Tensor support for paddle.take_along_axis API.#73736
Conversation
|
你的PR提交成功,感谢你对开源项目的贡献! |
|
PR 标题请按照规范 |
好的,已完成修改。 |
Codecov ReportAttention: Patch coverage is
❌ Your patch status has failed because the patch coverage (28.57%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## develop #73736 +/- ##
==========================================
Coverage ? 28.57%
==========================================
Files ? 1
Lines ? 7
Branches ? 0
==========================================
Hits ? 2
Misses ? 5
Partials ? 0 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
“根据任务要求,首先尝试使用PaddleAPITest复现BUG。在PaddleAPITest的--accuracy=True精度对比模式下,由于paddle.take_along_axis(arr, index, axis)与torch.take_along_dim(input, indices, dim)的参数名不统一,反复出现TypeError: missing a required argument的参数绑定错误,无法准确定位问题。 |
PR Category
Operator Mechanism
PR Types
Bug fixes
Description
本次提交为完成任务 #72637 中关于
paddle.take_along_axis的部分。修改历程介绍如下:
问题复现与分析:
PaddleAPITest复现BUG。在PaddleAPITest的--accuracy=True精度对比模式下,由于paddle.take_along_axis(arr, index, axis)与torch.take_along_dim(input, indices, dim)的参数名不统一,反复出现TypeError: missing a required argument的参数绑定错误,无法准确定位问题。PaddleAPITest的--paddle_only=True模式后,成功触发了Python层的TypeError: take_along_axis() got an unexpected keyword argument 'index'报错,这暴露了API前后端参数名不一致的问题。unittest的OpTest框架编写单元测试。在未修复Kernel的情况下,成功复现了底层的C++错误:前向修复 (Forward Fix):
a. 定位API: 在
Paddle/python/paddle/tensor/manipulation.py中找到了def take_along_axis(...)的Python定义,其核心实现调用了_C_ops.take_along_axis。b. 定位算子定义: 使用
grep发现,该算子没有独立的.yml文件,其定义位于paddle/phi/ops/yaml/ops.yaml中。c. 检查InferMeta: 根据
ops.yaml的指引,在paddle/phi/infermeta/binary.cc中找到了TakeAlongAxisInferMeta函数。经分析,其out->set_dims(index.dims())逻辑能正确推导0-size Tensor的输出形状,无需修改。d. 修改Kernel:
* 根据
grep结果,定位到CPU Kernel文件为paddle/phi/kernels/cpu/take_along_axis_kernel.cc。* 参照标准修复范式,在
TakeAlongAxisKernel函数开头加入了对0-size情况的保护。核心逻辑是判断index.numel()是否为0,因为输出的形状完全由index决定。* 修复代码如下:
d. 依照以上原则修改CPU、GPU、XPU Kernel
反向修复 (Backward Fix):
4. 添加单测 (Add Unit Test):
在
test/legacy_test/test_take_along_axis_op.py文件中,为彻底解决因父类TestTakeAlongAxisOp的setUp方法无法兼容0-size数据而导致的CI报错(IndexError,ValueError),最终方案是放弃继承,添加了两个全新的、独立的OpTest测试类,分别覆盖两种不同的0-size边界场景。测试场景一:输入
arr为0-size,但index不为0-sizeTestTakeAlongAxis0Size1类进行验证。完整测试代码如下:测试场景二:索引
index为0-size,但arr不为0-sizeTestTakeAlongAxis0Size2类进行验证。完整测试代码如下:feature/fix_take_along_axis_0size分支上运行添加的OpTest单元测试,结果为OK,证明修复成功。--accuracy模式无法使用。在--paddle_only模式下,修复后可顺利通过。pcard-67164