Commit 6151ccd

[NPU] Support npu op: (1) cos (2) cos_grad (#34573)
* [NPU] Support npu op: (1) cos (2) cos_grad
* Update test_cos_op_npu.py
* Update activation_op_npu.cc
* rm redundant {1}
1 parent 6839994 commit 6151ccd

File tree

2 files changed (+211, -0 lines)
paddle/fluid/operators/activation_op_npu.cc

Lines changed: 65 additions & 0 deletions
@@ -472,6 +472,61 @@ class ReciprocalGradNPUKernel : public framework::OpKernel<T> {
  }
};

template <typename DeviceContext, typename T>
class CosNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* out = ctx.Output<Tensor>("Out");

    auto place = ctx.GetPlace();
    out->mutable_data<T>(place);

    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();

    const auto& runner = NpuOpRunner("Cos", {*x}, {*out}, {});
    runner.Run(stream);
  }
};

template <typename DeviceContext, typename T>
class CosGradNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* x = ctx.Input<Tensor>("X");
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));

    auto place = ctx.GetPlace();
    dx->mutable_data<T>(place);

    Tensor sin_out(x->type());  // Temporary Tensor
    sin_out.Resize(x->dims());
    sin_out.mutable_data<T>(place);

    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();
    const auto& runner = NpuOpRunner("Sin", {*x}, {sin_out}, {});
    runner.Run(stream);

    const auto& runner_dx = NpuOpRunner("Mul", {*dout, sin_out}, {*dx}, {});
    runner_dx.Run(stream);

    Tensor tmp(x->type());  // Temporary Tensor
    tmp.Resize(framework::make_ddim({1, 1}));
    tmp.mutable_data<T>(place);
    float factor = -1.;
    FillNpuTensorWithConstant<T>(&tmp, static_cast<T>(factor));

    const auto& runner_dx_ = NpuOpRunner("Xdivy", {*dx, tmp}, {*dx}, {});
    runner_dx_.Run(stream);
    // dx = -dout * Sine(x);
  }
};

}  // namespace operators
}  // namespace paddle
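Note (not part of the diff): CosGradNPUKernel builds the analytic gradient d/dx cos(x) = -sin(x) by chaining three NPU ops: Sin to get sin(x), Mul to form dout * sin(x), and Xdivy against a constant -1 tensor to flip the sign. A small NumPy sketch of that identity, for illustration only:

# Illustration only: NumPy check of the identity CosGradNPUKernel
# composes from the NPU ops Sin, Mul and Xdivy.
import numpy as np

rng = np.random.default_rng(2021)
x = rng.uniform(1, 2, size=(11, 17)).astype(np.float32)
dout = rng.standard_normal(size=(11, 17)).astype(np.float32)

dx_kernel_style = (dout * np.sin(x)) / -1.0   # Mul, then Xdivy by -1
dx_analytic = -dout * np.sin(x)               # d/dx cos(x) = -sin(x)
assert np.allclose(dx_kernel_style, dx_analytic)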

@@ -583,3 +638,13 @@ REGISTER_OP_NPU_KERNEL(
    ops::ReciprocalGradNPUKernel<paddle::platform::NPUDeviceContext, double>,
    ops::ReciprocalGradNPUKernel<paddle::platform::NPUDeviceContext,
                                 paddle::platform::float16>);

REGISTER_OP_NPU_KERNEL(
    cos, ops::CosNPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::CosNPUKernel<paddle::platform::NPUDeviceContext,
                      paddle::platform::float16>);

REGISTER_OP_NPU_KERNEL(
    cos_grad, ops::CosGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::CosGradNPUKernel<paddle::platform::NPUDeviceContext,
                          paddle::platform::float16>);
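Note (not part of the diff): once registered, these kernels are picked up whenever the cos operator runs on paddle.NPUPlace(0). A minimal static-graph sketch, assuming a PaddlePaddle build with Ascend NPU support (the unit test below exercises the same path more thoroughly):

# Illustration only: run paddle.cos on the NPU place and compare with NumPy.
import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[11, 17], dtype='float32')
    y = paddle.cos(x)  # dispatched to CosNPUKernel on an NPU place

exe = paddle.static.Executor(paddle.NPUPlace(0))
exe.run(startup_prog)
x_np = np.random.uniform(1, 2, [11, 17]).astype('float32')
y_np, = exe.run(main_prog, feed={"x": x_np}, fetch_list=[y])
assert np.allclose(y_np, np.cos(x_np), atol=1e-7)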
test_cos_op_npu.py

Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid

paddle.enable_static()
SEED = 2021


class TestCos(OpTest):
    def setUp(self):
        self.set_npu()
        self.op_type = "cos"
        self.place = paddle.NPUPlace(0)

        self.init_dtype()
        np.random.seed(SEED)
        x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
        out = np.cos(x)

        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.attrs = {}
        self.outputs = {'Out': out}

    def set_npu(self):
        self.__class__.use_npu = True

    def init_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output_with_place(self.place, atol=1e-7)

    def test_check_grad(self):
        if self.dtype == np.float16:
            return
        self.check_grad_with_place(self.place, ['X'], 'Out')


class TestCosFp16(OpTest):
    def setUp(self):
        self.set_npu()
        self.op_type = "cos"
        self.place = paddle.NPUPlace(0)

        self.init_dtype()
        np.random.seed(SEED)
        x = np.random.uniform(1, 2, [3, 4]).astype(self.dtype)
        out = np.cos(x)

        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.attrs = {}
        self.outputs = {'Out': out}

    def set_npu(self):
        self.__class__.use_npu = True
        self.__class__.no_need_check_grad = True

    def init_dtype(self):
        self.dtype = np.float16

    def test_check_output(self):
        self.check_output_with_place(self.place)


class TestCosNet(unittest.TestCase):
    def _test(self, run_npu=True):
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        main_prog.random_seed = SEED
        startup_prog.random_seed = SEED
        np.random.seed(SEED)

        a_np = np.random.random(size=(32, 32)).astype('float32')
        b_np = np.random.random(size=(32, 32)).astype('float32')
        label_np = np.random.randint(2, size=(32, 1)).astype('int64')

        with paddle.static.program_guard(main_prog, startup_prog):
            a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
            b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
            label = paddle.static.data(
                name="label", shape=[32, 1], dtype='int64')

            c = paddle.multiply(a, b)
            d = paddle.cos(c)

            fc_1 = fluid.layers.fc(input=d, size=128)
            prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')

            cost = fluid.layers.cross_entropy(input=prediction, label=label)
            loss = fluid.layers.reduce_mean(cost)
            sgd = fluid.optimizer.SGD(learning_rate=0.01)
            sgd.minimize(loss)

        if run_npu:
            place = paddle.NPUPlace(0)
        else:
            place = paddle.CPUPlace()

        exe = paddle.static.Executor(place)
        exe.run(startup_prog)

        print("Start run on {}".format(place))
        for epoch in range(100):

            pred_res, loss_res = exe.run(
                main_prog,
                feed={"a": a_np,
                      "b": b_np,
                      "label": label_np},
                fetch_list=[prediction, loss])
            if epoch % 10 == 0:
                print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
                    epoch, pred_res[0], loss_res))

        return pred_res, loss_res

    def test_npu(self):
        cpu_pred, cpu_loss = self._test(False)
        npu_pred, npu_loss = self._test(True)

        self.assertTrue(np.allclose(npu_pred, cpu_pred))
        self.assertTrue(np.allclose(npu_loss, cpu_loss))


if __name__ == '__main__':
    unittest.main()
