
Commit 6839994

[NPU] Add relu6 and relu6_grad npu op (#34596)
* Add relu6 and relu6_grad npu op
* fixed pre-commit-config.yaml
* fixed for CI
1 parent 012d12b commit 6839994
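For context, relu6 caps ReLU activations at a threshold of 6. A minimal NumPy sketch of the forward and backward math the new kernels implement (function names here are illustrative, not part of the commit):

import numpy as np


def relu6_reference(x, threshold=6.0):
    # Forward: clip the input to the range [0, threshold] (threshold defaults to 6).
    return np.minimum(np.maximum(x, 0.0), threshold)


def relu6_grad_reference(dout, out, threshold=6.0):
    # Backward: the gradient passes through only where 0 < out < threshold.
    return dout * ((out > 0) & (out < threshold)).astype(dout.dtype)


x = np.array([-1.0, 0.5, 3.0, 7.0], dtype=np.float32)
print(relu6_reference(x))                                         # [0.  0.5 3.  6.]
print(relu6_grad_reference(np.ones_like(x), relu6_reference(x)))  # [0. 1. 1. 0.]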

2 files changed: +218 / -0 lines

paddle/fluid/operators/activation_op_npu.cc

Lines changed: 52 additions & 0 deletions
@@ -144,6 +144,47 @@ class ReluGradNPUKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T>
+class Relu6NPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* out = ctx.Output<Tensor>("Out");
+
+    out->mutable_data<T>(ctx.GetPlace());
+
+    const auto& runner = NpuOpRunner("Relu6",
+                                     {
+                                         *x,
+                                     },
+                                     {*out}, {});
+
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+    runner.Run(stream);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class Relu6GradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* out = ctx.Input<Tensor>("Out");
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+
+    dx->mutable_data<T>(ctx.GetPlace());
+    const auto& runner = NpuOpRunner("Relu6Grad", {*dout, *out}, {*dx}, {});
+
+    runner.Run(stream);
+  }
+};
+
 template <typename DeviceContext, typename T>
 class SqrtNPUKernel : public framework::OpKernel<T> {
  public:
@@ -457,6 +498,17 @@ REGISTER_OP_NPU_KERNEL(
     ops::ReluGradNPUKernel<paddle::platform::NPUDeviceContext,
                            paddle::platform::float16>);
 
+REGISTER_OP_NPU_KERNEL(
+    relu6, ops::Relu6NPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::Relu6NPUKernel<paddle::platform::NPUDeviceContext,
+                        paddle::platform::float16>);
+
+REGISTER_OP_NPU_KERNEL(
+    relu6_grad,
+    ops::Relu6GradNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::Relu6GradNPUKernel<paddle::platform::NPUDeviceContext,
+                            paddle::platform::float16>);
+
 REGISTER_OP_NPU_KERNEL(
     sqrt, ops::SqrtNPUKernel<paddle::platform::NPUDeviceContext, float>,
     ops::SqrtNPUKernel<paddle::platform::NPUDeviceContext,
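
With the kernels registered, relu6 can be exercised from Python and dispatched to the NPU. A minimal static-graph sketch, assuming a PaddlePaddle build with Ascend NPU support (it mirrors the calls used in the new unit test below and is illustrative, not part of the commit):

import numpy as np
import paddle
import paddle.nn.functional as F

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
    y = F.relu6(x)  # runs as the relu6 NPU kernel when executed on an NPUPlace

place = paddle.NPUPlace(0)  # assumes the build exposes NPU devices
exe = paddle.static.Executor(place)
exe.run(startup_prog)

x_np = np.random.uniform(-1, 10, [4, 8]).astype("float32")
out, = exe.run(main_prog, feed={"x": x_np}, fetch_list=[y])
print(out.min(), out.max())  # outputs are clipped to [0, 6]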
Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import paddle.fluid as fluid
import paddle
from op_test import OpTest

import numpy as np
import unittest
import sys
sys.path.append("..")

paddle.enable_static()
SEED = 2021


def ref_relu6(x, threshold=6.0):
    out = np.copy(x)
    out[np.abs(x - threshold) < 0.005] = threshold + 0.02
    out = np.minimum(np.maximum(x, 0), threshold)
    return out


class TestRelu6(OpTest):
    def setUp(self):
        self.set_npu()
        self.op_type = "relu6"
        self.place = paddle.NPUPlace(0)

        self.init_dtype()
        np.random.seed(SEED)
        x = np.random.uniform(-1, 10, [10, 12]).astype(self.dtype)
        x[np.abs(x) < 0.005] = 0.02
        out = ref_relu6(x)

        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.attrs = {'threshold': 6.0}
        self.outputs = {'Out': out}

    def set_npu(self):
        self.__class__.use_npu = True

    def test_check_output(self):
        self.check_output_with_place(self.place)

    def test_check_grad(self):
        if self.dtype == np.float16:
            return
        self.check_grad_with_place(self.place, ['X'], 'Out')

    def init_dtype(self):
        self.dtype = np.float32


class TestRelu6Float16(TestRelu6):
    def set_npu(self):
        self.__class__.use_npu = True
        self.__class__.no_need_check_grad = True

    def set_attrs(self):
        self.dtype = np.float16

    def test_check_output(self):
        self.check_output_with_place(self.place)


class TestReluNeg(TestRelu6):
    def setUp(self):
        self.set_npu()
        self.op_type = "relu6"
        self.place = paddle.NPUPlace(0)

        self.init_dtype()
        np.random.seed(SEED)
        x = np.random.uniform(-10, -1, [10, 12]).astype(self.dtype)
        x[np.abs(x) < 0.005] = 0.02
        out = ref_relu6(x)

        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.attrs = {'threshold': 6.0}
        self.outputs = {'Out': out}

    def set_npu(self):
        self.__class__.use_npu = True

    def init_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output_with_place(self.place)


class TestRelu6Net(unittest.TestCase):
    def _test(self, run_npu=True):
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        main_prog.random_seed = SEED
        startup_prog.random_seed = SEED
        np.random.seed(SEED)

        a_np = np.random.random(size=(32, 32)).astype('float32')
        b_np = np.random.random(size=(32, 32)).astype('float32')
        label_np = np.random.randint(2, size=(32, 1)).astype('int64')

        with paddle.static.program_guard(main_prog, startup_prog):
            a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
            b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
            label = paddle.static.data(
                name="label", shape=[32, 1], dtype='int64')

            sum = paddle.add(a, b)
            z = paddle.nn.functional.relu6(sum)

            fc_1 = fluid.layers.fc(input=z, size=128)
            prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')

            cost = fluid.layers.cross_entropy(input=prediction, label=label)
            loss = fluid.layers.reduce_mean(cost)
            sgd = fluid.optimizer.SGD(learning_rate=0.01)
            sgd.minimize(loss)

        if run_npu:
            place = paddle.NPUPlace(0)
        else:
            place = paddle.CPUPlace()

        exe = paddle.static.Executor(place)
        exe.run(startup_prog)

        print("Start run on {}".format(place))
        for epoch in range(100):

            pred_res, loss_res = exe.run(
                main_prog,
                feed={"a": a_np,
                      "b": b_np,
                      "label": label_np},
                fetch_list=[prediction, loss])
            if epoch % 10 == 0:
                print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
                    epoch, pred_res[0], loss_res))

        return pred_res, loss_res

    def test_npu(self):
        cpu_pred, cpu_loss = self._test(False)
        npu_pred, npu_loss = self._test(True)

        self.assertTrue(np.allclose(npu_pred, cpu_pred))
        self.assertTrue(np.allclose(npu_loss, cpu_loss))


if __name__ == '__main__':
    unittest.main()
