
Commit 581e546

【NPU】add relu op for npu (#31515)
* add relu npu
* fixed
* fix
1 parent cfeeb4b commit 581e546

File tree

2 files changed: +227 −0

paddle/fluid/operators/activation_op_npu.cc

Lines changed: 51 additions & 0 deletions
@@ -103,6 +103,46 @@ class PowGradNPUKernel : public framework::OpKernel<T> {
  }
};

template <typename DeviceContext, typename T>
class ReluNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* out = ctx.Output<Tensor>("Out");

    out->mutable_data<T>(ctx.GetPlace());

    auto runner = NpuOpRunner("Relu", {*x}, {*out}, {});

    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();
    runner.Run(stream);
  }
};

template <typename DeviceContext, typename T>
class ReluGradNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* out = ctx.Input<Tensor>("Out");
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));

    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();

    dx->mutable_data<T>(ctx.GetPlace());
    auto runner = NpuOpRunner("ReluGrad", {*dout, *out}, {*dx}, {});

    runner.Run(stream);
  }
};

}  // namespace operators
}  // namespace paddle
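For reference (not part of the commit), here is a minimal NumPy sketch of the semantics the two kernels above are expected to match; the helper names are hypothetical:

import numpy as np

def relu_reference(x):
    # Forward: Relu(x) = max(x, 0).
    return np.maximum(x, 0.0)

def relu_grad_reference(out, dout):
    # Backward: ReluGrad consumes Out and dOut and passes the upstream
    # gradient through only where the forward output is positive.
    return dout * (out > 0).astype(dout.dtype)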

@@ -117,3 +157,14 @@ REGISTER_OP_NPU_KERNEL(
    pow_grad, ops::PowGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::PowGradNPUKernel<paddle::platform::NPUDeviceContext,
                          paddle::platform::float16>);

REGISTER_OP_NPU_KERNEL(
    relu, ops::ReluNPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::ReluNPUKernel<paddle::platform::NPUDeviceContext,
                       paddle::platform::float16>);

REGISTER_OP_NPU_KERNEL(
    relu_grad,
    ops::ReluGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::ReluGradNPUKernel<paddle::platform::NPUDeviceContext,
                           paddle::platform::float16>);
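With relu and relu_grad now registered for float and float16, a static-graph program run on paddle.NPUPlace(0) should dispatch relu to the kernels above. A minimal usage sketch (not part of the commit; it only uses APIs that also appear in the unit test below):

import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[3, 2], dtype="float32")
    out = paddle.nn.functional.relu(x)

# Fall back to CPU when the build has no NPU support.
place = paddle.NPUPlace(0) if paddle.is_compiled_with_npu() else paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)

x_np = np.random.rand(3, 2).astype("float32") - 0.5
out_np, = exe.run(main_prog, feed={"x": x_np}, fetch_list=[out])
print(out_np)  # negative entries of x_np should come back as 0.0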
Lines changed: 176 additions & 0 deletions (new NPU relu unit test file)
@@ -0,0 +1,176 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid

paddle.enable_static()
SEED = 2021


@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestRelu(OpTest):
    def setUp(self):
        self.set_npu()
        self.op_type = "relu"
        self.place = paddle.NPUPlace(0)

        self.init_dtype()
        np.random.seed(SEED)
        x = np.random.rand(3, 2).astype(self.dtype)
        out = x

        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.attrs = {}
        self.outputs = {'Out': out}

    def set_npu(self):
        self.__class__.use_npu = True

    def init_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output_with_place(self.place, check_dygraph=False)


@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestReluFp16(OpTest):
    def setUp(self):
        self.set_npu()
        self.op_type = "relu"
        self.place = paddle.NPUPlace(0)

        self.init_dtype()
        np.random.seed(SEED)
        x = np.random.rand(3, 2).astype(self.dtype)
        out = x

        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.attrs = {}
        self.outputs = {'Out': out}

    def set_npu(self):
        self.__class__.use_npu = True
        self.__class__.no_need_check_grad = True

    def init_dtype(self):
        self.dtype = np.float16

    def test_check_output(self):
        self.check_output_with_place(self.place, check_dygraph=False, atol=1e-5)


@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestReluNeg(OpTest):
    def setUp(self):
        self.set_npu()
        self.op_type = "relu"
        self.place = paddle.NPUPlace(0)

        self.init_dtype()
        np.random.seed(SEED)
        x = np.array([0.1, -0.1, -1.0]).astype(self.dtype)
        out = np.array([0.1, 0.0, 0.0]).astype(self.dtype)

        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.attrs = {}
        self.outputs = {'Out': out}

    def set_npu(self):
        self.__class__.use_npu = True

    def init_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output_with_place(self.place, check_dygraph=False)


@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestReluNet(unittest.TestCase):
    def _test(self, run_npu=True):
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        main_prog.random_seed = SEED
        startup_prog.random_seed = SEED
        np.random.seed(SEED)

        a_np = np.random.random(size=(32, 32)).astype('float32')
        b_np = np.random.random(size=(32, 32)).astype('float32')
        label_np = np.random.randint(2, size=(32, 1)).astype('int64')

        with paddle.static.program_guard(main_prog, startup_prog):
            a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
            b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
            label = paddle.static.data(
                name="label", shape=[32, 1], dtype='int64')

            sum = paddle.add(a, b)
            z = paddle.nn.functional.relu(sum)

            fc_1 = fluid.layers.fc(input=z, size=128)
            prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')

            cost = fluid.layers.cross_entropy(input=prediction, label=label)
            loss = fluid.layers.reduce_mean(cost)
            sgd = fluid.optimizer.SGD(learning_rate=0.01)
            sgd.minimize(loss)

        if run_npu:
            place = paddle.NPUPlace(0)
        else:
            place = paddle.CPUPlace()

        exe = paddle.static.Executor(place)
        exe.run(startup_prog)

        print("Start run on {}".format(place))
        for epoch in range(100):
            pred_res, loss_res = exe.run(
                main_prog,
                feed={"a": a_np,
                      "b": b_np,
                      "label": label_np},
                fetch_list=[prediction, loss])
            if epoch % 10 == 0:
                print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
                    epoch, pred_res[0], loss_res))

        return pred_res, loss_res

    def test_npu(self):
        cpu_pred, cpu_loss = self._test(False)
        npu_pred, npu_loss = self._test(True)

        self.assertTrue(np.allclose(npu_pred, cpu_pred))
        self.assertTrue(np.allclose(npu_loss, cpu_loss))


if __name__ == '__main__':
    unittest.main()
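The OpTest cases above check the forward output, and TestReluNet exercises the backward pass indirectly through SGD training. The ReluGrad kernel could also be compared against the NumPy reference directly; the following sketch assumes paddle.static.gradients is available in this Paddle build (it is not used anywhere in the commit itself):

import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[3], dtype="float32")
    x.stop_gradient = False
    out = paddle.nn.functional.relu(x)
    # Assumed API: with no target gradient supplied, the upstream gradient
    # defaults to ones, so dx should be 1 where out > 0 and 0 elsewhere.
    dx = paddle.static.gradients([out], [x])[0]

exe = paddle.static.Executor(paddle.NPUPlace(0))
exe.run(startup_prog)

x_np = np.array([0.1, -0.1, -1.0], dtype="float32")
out_np, dx_np = exe.run(main_prog, feed={"x": x_np}, fetch_list=[out, dx])

assert np.allclose(out_np, np.maximum(x_np, 0.0))
assert np.allclose(dx_np, (x_np > 0).astype("float32"))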
