-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Support npu kernel for eye op #34543
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Thanks for your contribution! |
| @@ -0,0 +1,105 @@ | |||
| # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
copyright改成2021吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
其它没问题了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
paddle/fluid/operators/eye_op_npu.cc
Outdated
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include "paddle/fluid/operators/crop_op.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个头文件应该是 eye_op.h?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
paddle/fluid/operators/eye_op_npu.cc
Outdated
| PADDLE_ENFORCE_EQ( | ||
| num_rows >= 0, true, | ||
| platform::errors::InvalidArgument( | ||
| "The value of Input(num_rows) should be non-negative int.")); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个不需要检查,在EyeOp::InferShape的时候已经检查过了。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
paddle/fluid/operators/eye_op_npu.cc
Outdated
| num_columns >= 0, true, | ||
| platform::errors::InvalidArgument( | ||
| "The value of Input(num_columns) should be non-negative int.")); | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,这个不需要检查
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
paddle/fluid/operators/eye_op_npu.cc
Outdated
|
|
||
| if (ctx.HasAttr("num_columns")) { | ||
| num_columns = ctx.Attr<int64_t>("num_columns"); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个最好参考eye_op.h通过判断num_columns是否等于-1,避免单测或者python API把这个num_columns设置为-1的情况。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
| self.num_rows = 100 | ||
| self.num_columns = 100 | ||
|
|
||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
少了一个dtype是fp32的单测,另外参考 test_eye_op.py 的 class API_TestTensorEye(unittest.TestCase),增加关于静态图API和动态图API的测试。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| expected_result = np.eye(10, dtype="int32") | ||
| self.assertEqual((result == expected_result).all(), True) | ||
|
|
||
| paddle.disable_static() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里可以写成paddle.disable_static(paddle.NPUPlace(0)) 就可以把动态图的api跑到NPU上了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
| paddle.enable_static() | ||
| self.assertEqual((out.numpy() == expected_result).all(), True) | ||
|
|
||
| paddle.disable_static() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
| True) | ||
| self.assertEqual((out.numpy() == expected_result).all(), True) | ||
|
|
||
| paddle.disable_static() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
| if (ctx.HasAttr("num_columns")) { | ||
| num_columns = ctx.Attr<int64_t>("num_columns"); | ||
| } | ||
| if (num_columns == -1) num_columns = num_rows; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
33-37行合并为以下2行就可以,因为在python端的def eye中输入的num_columns不是None,因此ctx.HasAttr("num_columns")在API调用端永远是true.
auto num_columns = ctx.Attr<int64_t>("num_columns");
if (num_columns == -1) num_columns = num_rows;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
明白了,done
qili93
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
qili93
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM

PR types
New features
PR changes
OPs
Describe
add support npu kernel for eye op
