-
Notifications
You must be signed in to change notification settings - Fork 1.6k
fix some bugs for 0d output and fix some typoes #10282
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
d4ff3ea
18dc1ae
8c93e3d
c849b82
ac17a3a
f0bdd01
d56c459
1272daf
b1e0220
6233729
ddb5d86
6fa79e1
ec5a10c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -39,15 +39,16 @@ bool ReduceOp::CheckShape() const { | |
| auto dims = param_.dim; | ||
| auto x_dims = param_.X->dims(); | ||
| int x_rank = x_dims.size(); | ||
| return x_rank >= 0; | ||
| // dim at least is [0] | ||
| CHECK_GT(dims.size(), 0) | ||
| << "The input dim should be greater than 0. But received the dim = " | ||
| << dims.size(); | ||
| for (int i = 0; i < dims.size(); i++) { | ||
| CHECK(dims[i] <= x_rank && dims[i] + x_rank >= 0) | ||
| << "dims[i] is " << dims[i] << ", x_rank is " << x_rank; | ||
| } | ||
| return true; | ||
| // CHECK_GT(dims.size(), 0) | ||
|
||
| // << "The input dim should be greater than 0. But received the dim = " | ||
| // << dims.size(); | ||
| // for (int i = 0; i < dims.size(); i++) { | ||
| // CHECK(dims[i] <= x_rank && dims[i] + x_rank >= 0) | ||
| // << "dims[i] is " << dims[i] << ", x_rank is " << x_rank; | ||
| // } | ||
| // return true; | ||
| } | ||
|
|
||
| bool ReduceOp::InferShapeImpl() const { | ||
|
|
@@ -58,10 +59,10 @@ bool ReduceOp::InferShapeImpl() const { | |
| bool keep_dim = param_.keep_dim; | ||
|
|
||
| for (int i = 0; i < dims.size(); i++) { | ||
| CHECK(dims[i] <= x_rank && dims[i] + x_rank >= 0) | ||
| << "dims[i] is " << dims[i] << ", x_rank is " << x_rank; | ||
| // CHECK(dims[i] <= x_rank && dims[i] + x_rank >= 0) | ||
|
||
| // << "dims[i] is " << dims[i] << ", x_rank is " << x_rank; | ||
| if (dims[i] < 0) { | ||
| dims[i] = x_rank + dims[i]; | ||
| dims[i] = x_rank + dims[i] >= 0 ? x_rank + dims[i] : 0; | ||
| } | ||
| } | ||
| // recompute reduce_all | ||
|
|
@@ -79,8 +80,9 @@ bool ReduceOp::InferShapeImpl() const { | |
| if (reduce_all) { | ||
| if (keep_dim) | ||
| param_.Out->Resize(std::vector<int64_t>(x_rank, 1)); | ||
| else | ||
| param_.Out->Resize(std::vector<int64_t>({1})); | ||
| else { | ||
| param_.Out->Resize(std::vector<int64_t>({})); | ||
| } | ||
| } else { | ||
| std::vector<int64_t> dims_vector(x_rank, 1); | ||
| for (int i = 0; i < x_rank; i++) dims_vector[i] = x_dims[i]; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
直接删了就行,不用注释
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已改