Skip to content

Commit ac3d821

Browse files
authored
[NPU] add npu kernel for equal op (#31393)

* add npu kernel for equal op
* refine code
* add more ut
* update year
1 parent 0310945 commit ac3d821

2 files changed

Lines changed: 143 additions & 0 deletions

File tree

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14+
15+
#include <algorithm>
16+
#include <string>
17+
#include <vector>
18+
19+
#include "paddle/fluid/framework/op_registry.h"
20+
#include "paddle/fluid/framework/op_version_registry.h"
21+
#include "paddle/fluid/operators/controlflow/compare_op.h"
22+
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
23+
#include "paddle/fluid/operators/npu_op_runner.h"
24+
25+
namespace paddle {
26+
namespace operators {
27+
28+
template <typename T>
29+
class EqualNPUKernel : public framework::OpKernel<T> {
30+
public:
31+
void Compute(const framework::ExecutionContext& ctx) const override {
32+
auto* x = ctx.Input<framework::LoDTensor>("X");
33+
auto* y = ctx.Input<framework::LoDTensor>("Y");
34+
auto* out = ctx.Output<framework::LoDTensor>("Out");
35+
out->mutable_data<bool>(ctx.GetPlace());
36+
37+
auto runner = NpuOpRunner("Equal", {*x, *y}, {*out}, {});
38+
auto stream =
39+
ctx.template device_context<paddle::platform::NPUDeviceContext>()
40+
.stream();
41+
runner.Run(stream);
42+
}
43+
};
44+
45+
} // namespace operators
46+
} // namespace paddle
47+
48+
namespace ops = paddle::operators;
49+
namespace plat = paddle::platform;
50+
51+
REGISTER_OP_NPU_KERNEL(equal, ops::EqualNPUKernel<float>,
52+
ops::EqualNPUKernel<plat::float16>,
53+
ops::EqualNPUKernel<int>);
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15+
from __future__ import print_function
16+
17+
import numpy as np
18+
import unittest
19+
import sys
20+
sys.path.append("..")
21+
from op_test import OpTest
22+
import paddle
23+
import paddle.fluid as fluid
24+
25+
paddle.enable_static()
26+
SEED = 2021
27+
28+
29+
@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestEqual(OpTest):
    """Checks the NPU `equal` kernel on two independent random inputs."""

    def setUp(self):
        self.set_npu()
        self.op_type = "equal"
        self.place = paddle.NPUPlace(0)

        self.init_dtype()
        np.random.seed(SEED)
        lhs = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
        rhs = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
        # Independent random floats: elements are almost surely unequal.
        expected = lhs == rhs

        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(lhs),
            'Y': OpTest.np_dtype_to_fluid_dtype(rhs)
        }
        self.outputs = {'Out': expected}

    def set_npu(self):
        # Tells the OpTest machinery to run this case on the NPU place.
        self.__class__.use_npu = True

    def init_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        # Dygraph checking is disabled; only static-graph output compared.
        self.check_output_with_place(self.place, check_dygraph=False)
57+
58+
59+
class TestEqual2(TestEqual):
    """Variant where Y is a copy of X with a single element overwritten."""

    def setUp(self):
        self.set_npu()
        self.op_type = "equal"
        self.place = paddle.NPUPlace(0)

        self.init_dtype()
        np.random.seed(SEED)
        base = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
        other = base.copy()
        # Overwrite one element so (for floats) only position [0][1] differs.
        other[0][1] = 1
        expected = base == other

        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(base),
            'Y': OpTest.np_dtype_to_fluid_dtype(other)
        }
        self.outputs = {'Out': expected}
77+
78+
79+
class TestEqual2FP16(TestEqual2):
    """float16 variant of TestEqual2 (single mismatching element)."""

    def init_dtype(self):
        self.dtype = np.float16
82+
83+
84+
class TestEqual2Int(TestEqual2):
    """int32 variant of TestEqual2.

    Overrides setUp because the inherited float path casts
    `uniform(1, 2)` values to int32, truncating every element to 1 —
    then `y[0][1] = 1` changes nothing and the case degenerates to an
    all-equal comparison, never exercising a False output for ints.
    Here we use genuine random integers and force one mismatch.
    """

    def init_dtype(self):
        self.dtype = np.int32

    def setUp(self):
        self.set_npu()
        self.op_type = "equal"
        self.place = paddle.NPUPlace(0)

        self.init_dtype()
        np.random.seed(SEED)
        x = np.random.randint(1, 100, size=[11, 17]).astype(self.dtype)
        y = x.copy()
        y[0][1] += 1  # guarantee exactly position [0][1] is unequal
        out = x == y

        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(x),
            'Y': OpTest.np_dtype_to_fluid_dtype(y)
        }
        self.outputs = {'Out': out}
87+
88+
89+
if __name__ == '__main__':
    # Run all NPU equal-op test cases when invoked directly.
    unittest.main()

0 commit comments

Comments
 (0)