
Commit 5d22e15

【NPU】Support NPU kernel for reshape2 op (#31524)
* add reshape2 npu
* add reshape2
1 parent 581e546 commit 5d22e15

2 files changed: +228 -0 lines changed
Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>
#include <string>

#include "paddle/fluid/operators/npu_op_runner.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class Reshape2NPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<framework::Tensor>("X");
    auto shape = ctx.Attr<std::vector<int>>("shape");
    auto* out = ctx.Output<framework::Tensor>("Out");
    auto org_shape = framework::vectorize(x->dims());
    // Resolve the target shape: a 0 keeps the corresponding input
    // dimension, and a single -1 is inferred from the element count.
    int64_t shape_all = 1;
    int64_t org_shape_all = 1;
    int index = -1;
    for (size_t i = 0; i < shape.size(); i++) {
      if (shape[i] == 0) {
        shape[i] = org_shape[i];
      }
      if (shape[i] == -1) {
        index = static_cast<int>(i);
      } else {
        shape_all *= shape[i];
      }
    }
    for (size_t i = 0; i < org_shape.size(); i++) {
      org_shape_all *= org_shape[i];
    }
    if (index >= 0) {
      shape[index] = org_shape_all / shape_all;
    }
    // Reshape is a metadata change plus a device copy: copy the data,
    // then restore the resolved dims (TensorCopy also copies the
    // source dims, so Resize must be applied again afterwards).
    out->Resize(framework::make_ddim(shape));
    out->mutable_data(ctx.GetPlace(), x->type());
    framework::TensorCopy(
        *x, ctx.GetPlace(),
        ctx.template device_context<platform::DeviceContext>(), out);
    out->Resize(framework::make_ddim(shape));
  }
};

template <typename DeviceContext, typename T>
class Reshape2GradNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto in_dims = d_x->dims();

    // The gradient of reshape is the incoming gradient reshaped back
    // to the input dims: copy, then restore the original dims.
    d_x->mutable_data(ctx.GetPlace(), d_out->type());
    framework::TensorCopy(
        *d_out, ctx.GetPlace(),
        ctx.template device_context<platform::DeviceContext>(), d_x);
    d_x->Resize(in_dims);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_NPU_KERNEL(
    reshape2,
    ops::Reshape2NPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::Reshape2NPUKernel<paddle::platform::NPUDeviceContext,
                           paddle::platform::float16>);
REGISTER_OP_NPU_KERNEL(
    reshape2_grad,
    ops::Reshape2GradNPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::Reshape2GradNPUKernel<paddle::platform::NPUDeviceContext,
                               paddle::platform::float16>);
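
For reference, the dimension-resolution rule the kernel above implements (a 0 in the shape attribute keeps the corresponding input dimension; a single -1 is inferred from the remaining element count) can be mirrored in a few lines of Python. infer_shape below is a hypothetical helper written for illustration only, not part of this commit:

def infer_shape(org_shape, shape):
    # Mirror of the C++ loop above: 0 copies the input dim,
    # a single -1 absorbs whatever is left of the element count.
    shape = list(shape)
    total = 1
    for d in org_shape:
        total *= d
    known = 1
    index = -1
    for i, s in enumerate(shape):
        if s == 0:
            shape[i] = org_shape[i]
        if shape[i] == -1:
            index = i
        else:
            known *= shape[i]
    if index >= 0:
        shape[index] = total // known
    return shape

assert infer_shape((2, 60), (-1, 10)) == [12, 10]       # TestReshape2_case2
assert infer_shape((2, 5, 6), (-1, 0, 3)) == [4, 5, 3]  # TestReshape2_case3

The two assertions correspond to the -1 and 0/-1 cases exercised by the unit tests in the second file below.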
Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid

paddle.enable_static()
SEED = 2021


@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestReshape2(OpTest):
    def setUp(self):
        self.set_npu()
        self.op_type = "reshape2"
        self.place = paddle.NPUPlace(0)

        self.init_data()
        self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
        self.attrs = {"shape": self.new_shape}
        self.outputs = {
            "Out": self.inputs["X"].reshape(self.infered_shape),
            # XShape is an internal output; its value is excluded from the
            # check below, so random filler data is sufficient here.
            'XShape': np.random.random(self.ori_shape).astype("float32")
        }

    def set_npu(self):
        self.__class__.use_npu = True

    def init_data(self):
        self.ori_shape = (2, 60)
        self.new_shape = (12, 10)
        self.infered_shape = (12, 10)

    def test_check_output(self):
        self.check_output_with_place(
            self.place, check_dygraph=False, no_check_set=['XShape'])
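
Since reshape only changes the tensor's metadata, the expected "Out" above is simply "X" viewed with the new dims, and no_check_set=['XShape'] skips the XShape output because it is filled with random data. A numpy-only sanity sketch of what this base case asserts, independent of Paddle:

import numpy as np

x = np.random.random((2, 60)).astype("float32")
out = x.reshape((12, 10))
assert out.shape == (12, 10)
# The flat data is untouched; only the dims change, matching the
# TensorCopy-then-Resize implementation of the kernel.
assert np.array_equal(out.ravel(), x.ravel())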

class TestReshape2_case2(TestReshape2):
    def init_data(self):
        self.ori_shape = (2, 60)
        self.new_shape = (-1, 10)
        self.infered_shape = (12, 10)


class TestReshape2_case3(TestReshape2):
    def init_data(self):
        self.ori_shape = (2, 5, 6)
        self.new_shape = (-1, 0, 3)
        self.infered_shape = (4, 5, 3)


# TODO(ascendrc): Add grad test
# def test_check_grad(self):
#     if self.dtype == np.float16:
#         return
#     self.check_grad(['X'], 'Out')
#
@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestReshapeNet(unittest.TestCase):
    def _test(self, run_npu=True):
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        main_prog.random_seed = SEED
        startup_prog.random_seed = SEED
        np.random.seed(SEED)

        a_np = np.random.random(size=(32, 32)).astype('float32')
        b_np = np.random.random(size=(32, 32)).astype('float32')
        label_np = np.random.randint(2, size=(32, 1)).astype('int64')

        with paddle.static.program_guard(main_prog, startup_prog):
            a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
            b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
            label = paddle.static.data(
                name="label", shape=[32, 1], dtype='int64')

            sum = paddle.add(a, b)
            z = paddle.reshape(sum, shape=[32, 32])

            fc_1 = fluid.layers.fc(input=z, size=128)
            prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')

            cost = fluid.layers.cross_entropy(input=prediction, label=label)
            loss = fluid.layers.reduce_mean(cost)
            sgd = fluid.optimizer.SGD(learning_rate=0.01)
            sgd.minimize(loss)

        if run_npu:
            place = paddle.NPUPlace(0)
        else:
            place = paddle.CPUPlace()

        exe = paddle.static.Executor(place)
        exe.run(startup_prog)

        print("Start run on {}".format(place))
        for epoch in range(100):
            pred_res, loss_res = exe.run(
                main_prog,
                feed={"a": a_np,
                      "b": b_np,
                      "label": label_np},
                fetch_list=[prediction, loss])
            if epoch % 10 == 0:
                print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
                    epoch, pred_res[0], loss_res))

        return pred_res, loss_res

    def test_npu(self):
        cpu_pred, cpu_loss = self._test(False)
        npu_pred, npu_loss = self._test(True)

        self.assertTrue(np.allclose(npu_pred, cpu_pred))
        self.assertTrue(np.allclose(npu_loss, cpu_loss))


if __name__ == '__main__':
    unittest.main()
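
For completeness, a minimal dynamic-graph sketch of exercising the new kernel outside the test harness. This is an illustration under two assumptions: an NPU build of Paddle, and that the device string "npu:0" maps to the paddle.NPUPlace(0) used above:

import paddle

paddle.disable_static()     # the test file enables static mode globally
paddle.set_device("npu:0")  # assumption: NPU device string for NPUPlace(0)
x = paddle.rand([2, 60], dtype="float32")
y = paddle.reshape(x, shape=[-1, 10])  # -1 is inferred as 12
print(y.shape)  # [12, 10]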
