Skip to content

Commit 2164ad6

Browse files
[npu]add unsqueeze2_grad,test=develop (#34733)
1 parent 3f71e8d commit 2164ad6

File tree

2 files changed

+157
-0
lines changed

2 files changed

+157
-0
lines changed

paddle/fluid/operators/unsqueeze_op_npu.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,20 @@ REGISTER_OP_NPU_KERNEL(
3838
ops::UnsqueezeKernel<plat::NPUDeviceContext, int>,
3939
ops::UnsqueezeKernel<plat::NPUDeviceContext, int8_t>,
4040
ops::UnsqueezeKernel<plat::NPUDeviceContext, int64_t>);
41+
REGISTER_OP_NPU_KERNEL(
42+
unsqueeze_grad, ops::UnsqueezeGradKernel<plat::NPUDeviceContext, float>,
43+
ops::UnsqueezeGradKernel<plat::NPUDeviceContext, double>,
44+
ops::UnsqueezeGradKernel<plat::NPUDeviceContext, plat::float16>,
45+
ops::UnsqueezeGradKernel<plat::NPUDeviceContext, bool>,
46+
ops::UnsqueezeGradKernel<plat::NPUDeviceContext, int>,
47+
ops::UnsqueezeGradKernel<plat::NPUDeviceContext, int8_t>,
48+
ops::UnsqueezeGradKernel<plat::NPUDeviceContext, int64_t>);
49+
REGISTER_OP_NPU_KERNEL(
50+
unsqueeze2_grad, ops::Unsqueeze2GradKernel<plat::NPUDeviceContext, float>,
51+
ops::Unsqueeze2GradKernel<plat::NPUDeviceContext, double>,
52+
ops::Unsqueeze2GradKernel<plat::NPUDeviceContext, plat::float16>,
53+
ops::Unsqueeze2GradKernel<plat::NPUDeviceContext, bool>,
54+
ops::Unsqueeze2GradKernel<plat::NPUDeviceContext, int>,
55+
ops::Unsqueeze2GradKernel<plat::NPUDeviceContext, int8_t>,
56+
ops::Unsqueeze2GradKernel<plat::NPUDeviceContext, int64_t>);
4157
#endif
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
import numpy as np
18+
import unittest
19+
import sys
20+
sys.path.append("..")
21+
from op_test import OpTest
22+
import paddle
23+
import paddle.fluid as fluid
24+
import paddle.fluid.core as core
25+
from paddle.fluid import Program, program_guard
26+
27+
paddle.enable_static()
28+
29+
30+
# unsqueeze
31+
class TestUnsqueezeOp(OpTest):
32+
def setUp(self):
33+
self.set_npu()
34+
self.op_type = "unsqueeze"
35+
self.place = paddle.NPUPlace(0)
36+
self.init_test_case()
37+
self.x = np.random.random(self.ori_shape).astype("float32")
38+
self.inputs = {"X": OpTest.np_dtype_to_fluid_dtype(self.x)}
39+
self.init_attrs()
40+
self.outputs = {"Out": self.x.reshape(self.new_shape), }
41+
42+
def set_npu(self):
43+
self.__class__.use_npu = True
44+
45+
def test_check_output(self):
46+
self.check_output_with_place(self.place)
47+
48+
def test_check_grad(self):
49+
self.check_grad_with_place(self.place, ["X"], "Out")
50+
51+
def init_test_case(self):
52+
self.ori_shape = (3, 40)
53+
self.axes = (0, 2)
54+
self.new_shape = (1, 3, 1, 40)
55+
56+
def init_attrs(self):
57+
self.attrs = {"axes": self.axes}
58+
59+
60+
class TestUnsqueezeOp1(TestUnsqueezeOp):
61+
def init_test_case(self):
62+
self.ori_shape = (3, 40)
63+
self.axes = (0, -2)
64+
self.new_shape = (1, 3, 1, 40)
65+
66+
67+
# No axes input.
68+
class TestUnsqueezeOp2(TestUnsqueezeOp):
69+
def init_test_case(self):
70+
self.ori_shape = (20, 5)
71+
self.axes = ()
72+
self.new_shape = (1, 20, 5)
73+
74+
75+
# Just part of axes be squeezed.
76+
class TestUnsqueezeOp3(TestUnsqueezeOp):
77+
def init_test_case(self):
78+
self.ori_shape = (6, 5, 1, 4)
79+
self.axes = (1, -1)
80+
self.new_shape = (6, 1, 5, 1, 4, 1)
81+
82+
83+
# unsqueeze 2
84+
class TestUnsqueeze2Op(OpTest):
85+
def setUp(self):
86+
self.set_npu()
87+
self.op_type = "unsqueeze2"
88+
self.place = paddle.NPUPlace(0)
89+
self.init_test_case()
90+
self.x = np.random.random(self.ori_shape).astype("float32")
91+
self.inputs = {"X": OpTest.np_dtype_to_fluid_dtype(self.x)}
92+
self.init_attrs()
93+
self.outputs = {
94+
"Out": self.x.reshape(self.new_shape),
95+
"XShape": np.random.random(self.ori_shape).astype("float32")
96+
}
97+
98+
def set_npu(self):
99+
self.__class__.use_npu = True
100+
101+
def test_check_output(self):
102+
self.check_output_with_place(self.place, no_check_set=['XShape'])
103+
104+
def test_check_grad(self):
105+
self.check_grad_with_place(self.place, ["X"], "Out")
106+
107+
def init_test_case(self):
108+
self.ori_shape = (3, 40)
109+
self.axes = (0, 2)
110+
self.new_shape = (1, 3, 1, 40)
111+
112+
def init_attrs(self):
113+
self.attrs = {"axes": self.axes}
114+
115+
116+
# Correct: There is mins axis.
117+
class TestUnsqueeze2Op1(TestUnsqueeze2Op):
118+
def init_test_case(self):
119+
self.ori_shape = (20, 5)
120+
self.axes = (0, -2)
121+
self.new_shape = (1, 20, 1, 5)
122+
123+
124+
# Correct: No axes input.
125+
class TestUnsqueeze2Op2(TestUnsqueeze2Op):
126+
def init_test_case(self):
127+
self.ori_shape = (20, 5)
128+
self.axes = ()
129+
self.new_shape = (1, 20, 5)
130+
131+
132+
# Correct: Just part of axes be squeezed.
133+
class TestUnsqueeze2Op3(TestUnsqueeze2Op):
134+
def init_test_case(self):
135+
self.ori_shape = (6, 5, 1, 4)
136+
self.axes = (1, -1)
137+
self.new_shape = (6, 1, 5, 1, 4, 1)
138+
139+
140+
if __name__ == "__main__":
141+
unittest.main()

0 commit comments

Comments
 (0)