[0-size Tensor No.353] Add 0-size Tensor support for unflatten API.#73986
DanielSun11 merged 14 commits into PaddlePaddle:develop
Conversation
Please use pre-commit to normalize the code style.
/re-run all-failed
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## develop #73986 +/- ##
==========================================
Coverage ? 95.65%
==========================================
Files ? 1
Lines ? 23
Branches ? 0
==========================================
Hits ? 22
Misses ? 1
Partials ? 0
if (unk_dim_idx != -1) {
  size_t in_dims_zero_cnt = 0;
  for (size_t i = 0; i < in_dims_vec.size(); ++i)
    if (in_dims_vec[i] == 0) in_dims_zero_cnt++;
  if (shape_zero_cnt == in_dims_zero_cnt) {
    int64_t in_dims_pdt = 1;
    int64_t shape_pdt = 1;
    for (size_t i = 0; i < shape.size(); ++i)
      if (shape[i] != 0 && shape[i] != -1) shape_pdt *= shape[i];
    for (size_t i = 0; i < in_dims_vec.size(); ++i)
      if (in_dims_vec[i] != 0 && in_dims_vec[i] != -1)
        in_dims_pdt *= in_dims_vec[i];
    output_shape[unk_dim_idx] = in_dims_pdt / shape_pdt;
    PADDLE_ENFORCE_EQ(
        output_shape[unk_dim_idx] * shape_pdt,
        in_dims_pdt,
        common::errors::InvalidArgument(
            "The 'shape' attribute in ReshapeOp is invalid. "
            "The input tensor X'size must be divisible by known "
            "capacity of 'shape'. "
            "But received X's shape = [%s], "
            "'shape' is [%s].",
            in_dims,
            common::make_ddim(shape)));
    return common::make_ddim(output_shape);
  } else if (shape_zero_cnt > in_dims_zero_cnt) {
    int64_t in_dims_pdt = 1;
    int64_t shape_pdt = 1;
    for (size_t i = 0; i < shape.size(); ++i)
      if (shape[i] != 0 && shape[i] != -1) shape_pdt *= shape[i];
    for (size_t i = 0; i < in_dims_vec.size(); ++i)
      if (in_dims_vec[i] != 0 && in_dims_vec[i] != -1)
        in_dims_pdt *= in_dims_vec[i];
    PADDLE_ENFORCE_EQ(
        shape_pdt,
        in_dims_pdt,
        common::errors::InvalidArgument(
            "Provided sizes don't multiply up to the size of dim given "
            "in the input tensor"));
  }
}
PADDLE_ENFORCE_EQ(unk_dim_idx,
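The code logic is fine. Please add some comments explaining the current logic so that later maintenance does not become difficult. Also, does the symbolic shape inference need to be updated in sync? ValidateShape should be a helper function in infermeta; please check whether the symbolic inference has an equivalent helper and whether the symbolic inference for unflatten needs to change.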
class TestUnflattenInputZeroSize(TestUnflattenAPI):
    def set_args(self):
        self.x = np.random.rand(4, 0, 16).astype('int16')
        self.axis = 0
        self.shape = (2, 2)
        self.shape_is_tensor = False
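The unit tests do not fully cover the logic added to infermeta. The current tests only exercise the shape_zero_cnt == in_dims_zero_cnt case. Please add a test for shape_zero_cnt > in_dims_zero_cnt, which should raise an exception; see how the existing unit tests write their error cases.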




PR Category
Operator Mechanism
PR Types
Bug fixes
Description
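a. Error analysis
Search the error logs for paddle.unflatten in PaddleAPITest report/0size_tensor.
Tracing the error to the source (the line number is no longer 2209, because other branches have since been merged on GitHub) shows that the failing function is ValidateShape in /paddle/paddle/phi/infermeta/unary.cc. This function computes the shape of the result tensor.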
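The function takes two inputs, shape and in_dims. in_dims is the shape of the original input tensor. shape (written vshape from here on, to distinguish it from the shape parameter of paddle.unflatten) is the input shape with the requested unflatten shape spliced in. Take the following execution as an example:
Here the in_dims passed to ValidateShape is [4, 6, 2] and vshape is [4, 6, -1, 1]; the values in shape are passed through without any substitution.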
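As a rough illustration of how vshape relates to the unflatten arguments, the sketch below splices the requested shape into in_dims at the chosen axis. The helper name `build_vshape` is hypothetical, and `axis=2, shape=(-1, 1)` is only one call that would be consistent with the values quoted above; this is plain Python, not Paddle's actual implementation.

```python
def build_vshape(in_dims, axis, shape):
    """Replace the dimension at `axis` with the entries of `shape`.

    Hypothetical helper for illustration only; values such as -1 are
    passed through unchanged, matching the description above.
    """
    axis = axis % len(in_dims)  # normalize a negative axis
    return list(in_dims[:axis]) + list(shape) + list(in_dims[axis + 1:])


# Assumed example call: unflatten the last axis of a [4, 6, 2] tensor.
print(build_vshape([4, 6, 2], axis=2, shape=(-1, 1)))  # [4, 6, -1, 1]
```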
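Before the failing code is reached, the function constructs "std::vector<int64_t> output_shape(shape.size(), 0)" to hold the shape of the result tensor and fills it from the values in vshape according to certain rules (for example, it raises an error if more than one -1 is found, because only one dimension may be left unspecified).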
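The failing code raises because in_dims contains a 0 and, when execution reaches this point, unk_dim_idx is not -1.
unk_dim_idx records the position of the -1 in shape. This situation is rejected because the later code has to determine the value of the -1, which it does by dividing the product of all elements of in_dims by the product of all elements of shape other than -1, i.e. "output_shape[unk_dim_idx] = in_size / capacity" in the source.
When the output is a 0-size tensor this division clearly cannot work, so the original code raises an error whenever vshape contains both -1 and 0.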
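The pre-fix behavior described above can be sketched in plain Python as follows. This mirrors the description (a single -1, resolved as in_size / capacity, rejected when vshape also contains a 0) rather than the exact C++ in unary.cc; the function name `infer_unknown_dim` and the error messages are assumptions made for illustration.

```python
from math import prod


def infer_unknown_dim(in_dims, vshape):
    """Resolve the single -1 in vshape, without any 0-size support."""
    if vshape.count(-1) != 1:
        raise ValueError("exactly one dimension may be -1 in this sketch")
    if 0 in vshape:
        # in_size / capacity would be 0 / 0, so the -1 cannot be inferred.
        raise ValueError("cannot infer -1 when the shape also contains 0")
    unk_dim_idx = vshape.index(-1)
    in_size = prod(in_dims)
    capacity = prod(d for d in vshape if d != -1)
    if in_size % capacity != 0:
        raise ValueError("input size is not divisible by the known capacity")
    output_shape = list(vshape)
    output_shape[unk_dim_idx] = in_size // capacity
    return output_shape


print(infer_unknown_dim([4, 6, 2], [4, 6, -1, 1]))  # [4, 6, 2, 1]
```

With in_dims = [4, 0, 16] and vshape = [2, -1, 0, 16], the same function raises, which is the failure this PR addresses.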
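b. Fix
When vshape contains -1 and in_dims contains 0 at the point where the error was raised, there are three cases to consider:
1. vshape and in_dims contain the same number of zeros
This means vshape was formed by replacing one nonzero dimension of in_dims with a shape that contains -1. The -1 is obtained by dividing the product of the nonzero entries of in_dims by the product of the entries of vshape that are neither -1 nor 0.
2. vshape contains fewer zeros than in_dims
This means a 0 in in_dims was replaced by a shape that contains -1. The value of the -1 in vshape cannot be determined, so an error should be raised. The existing code path already does this and needs no additional change.
3. vshape contains more zeros than in_dims
This case splits into two sub-cases:
Sub-case 1: a nonzero dimension of in_dims is replaced by shape, and shape contains both -1 and 0. Following torch's behavior, this should raise an error.
Sub-case 2: a 0 in in_dims is replaced by shape, and shape contains both -1 and at least two zeros. The value of the -1 cannot be determined, so an error should be raised. The existing code path already does this and needs no additional change.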
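The three cases can be sketched in plain Python as below. This follows the structure of the C++ hunk in the review (zero counts compared first, then products taken over entries that are neither 0 nor -1), but it is only an illustration: the function name is hypothetical, it assumes in_dims contains a 0 and vshape contains exactly one -1, and the final generic error stands in for the pre-existing check that the C++ falls through to.

```python
from math import prod


def resolve_unknown_dim_zero_size(in_dims, vshape):
    """Sketch of the 0-size handling for a vshape containing one -1."""
    unk_dim_idx = vshape.index(-1)
    shape_zero_cnt = list(vshape).count(0)
    in_dims_zero_cnt = list(in_dims).count(0)
    # Products ignore both 0 and -1 entries.
    shape_pdt = prod(d for d in vshape if d not in (0, -1))
    in_dims_pdt = prod(d for d in in_dims if d not in (0, -1))

    if shape_zero_cnt == in_dims_zero_cnt:
        # Case 1: a nonzero dim was unflattened, so -1 is well defined.
        value, remainder = divmod(in_dims_pdt, shape_pdt)
        if remainder != 0:
            raise ValueError("input size not divisible by known capacity")
        output_shape = list(vshape)
        output_shape[unk_dim_idx] = value
        return output_shape
    if shape_zero_cnt > in_dims_zero_cnt and shape_pdt != in_dims_pdt:
        # Case 3, sub-case 1: the requested sizes do not match the input.
        raise ValueError("provided sizes don't multiply up to the input dim")
    # Case 2 and case 3, sub-case 2: -1 cannot be determined
    # (stand-in for the existing check, which is assumed to raise).
    raise ValueError("cannot infer -1 for a 0-size dimension")


# Case 1: in_dims [4, 0, 16], axis 0 unflattened to (2, -1).
print(resolve_unknown_dim_zero_size([4, 0, 16], [2, -1, 0, 16]))  # [2, 2, 0, 16]
```

Under these assumptions, vshape = [4, -1, 2, 16] (case 2), [-1, 0, 0, 16] (case 3, sub-case 1) and [4, -1, 0, 0, 16] (case 3, sub-case 2) all raise for in_dims = [4, 0, 16].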
Related tests:
Unit test results:
PaddleApiTest test results:
pcard-67164