Skip to content

Conversation

@yeliang2258
Copy link
Contributor

@yeliang2258 yeliang2258 commented Aug 2, 2021

PR types

New features

PR changes

OPs

Describe

add support npu kernel for eye op
image

@paddle-bot-old
Copy link

paddle-bot-old bot commented Aug 2, 2021

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@@ -0,0 +1,105 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copyright改成2021吧

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

其它没问题了

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/crop_op.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个头文件应该是 eye_op.h?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

PADDLE_ENFORCE_EQ(
num_rows >= 0, true,
platform::errors::InvalidArgument(
"The value of Input(num_rows) should be non-negative int."));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个不需要检查,在EyeOp::InferShape的时候已经检查过了。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

num_columns >= 0, true,
platform::errors::InvalidArgument(
"The value of Input(num_columns) should be non-negative int."));

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,这个不需要检查

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


if (ctx.HasAttr("num_columns")) {
num_columns = ctx.Attr<int64_t>("num_columns");
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个最好参考eye_op.h通过判断num_columns是否等于-1,避免单测或者python API把这个num_columns设置为-1的情况。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

self.num_rows = 100
self.num_columns = 100


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

少了一个dtype是fp32的单测,另外参考 test_eye_op.py 的 class API_TestTensorEye(unittest.TestCase),增加关于静态图API和动态图API的测试。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done
image

expected_result = np.eye(10, dtype="int32")
self.assertEqual((result == expected_result).all(), True)

paddle.disable_static()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里可以写成paddle.disable_static(paddle.NPUPlace(0)) 就可以把动态图的api跑到NPU上了

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

paddle.enable_static()
self.assertEqual((out.numpy() == expected_result).all(), True)

paddle.disable_static()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

True)
self.assertEqual((out.numpy() == expected_result).all(), True)

paddle.disable_static()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

if (ctx.HasAttr("num_columns")) {
num_columns = ctx.Attr<int64_t>("num_columns");
}
if (num_columns == -1) num_columns = num_rows;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

33-37行合并为以下2行就可以,因为在python端的def eye中输入的num_columns不是None,因此ctx.HasAttr("num_columns")在API调用端永远是true.
auto num_columns = ctx.Attr<int64_t>("num_columns");
if (num_columns == -1) num_columns = num_rows;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

明白了,done

qili93
qili93 previously approved these changes Aug 5, 2021
Copy link
Contributor

@qili93 qili93 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@qili93 qili93 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@qili93 qili93 merged commit 6e442e6 into PaddlePaddle:develop Aug 6, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants