Commit 7aadfc9
[NPU] Support npu op expand_v2 and expand_v2_grad

1 parent 4d80f3d
2 files changed: +99 −29 lines

paddle/fluid/operators/expand_v2_op_npu.cc

Mode changed: 100644 → 100755
Lines changed: 89 additions & 21 deletions
@@ -18,6 +18,33 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+inline std::vector<int> get_expand_shape_npu(
+    const framework::ExecutionContext& ctx) {
+  std::vector<int> vec_expand_shape;
+  auto list_expand_shapes_tensor =
+      ctx.MultiInput<framework::Tensor>("expand_shapes_tensor");
+  if (ctx.HasInput("Shape")) {
+    auto* shape_tensor = ctx.Input<framework::LoDTensor>("Shape");
+    std::vector<int> out_data;
+    TensorToVector(*shape_tensor, ctx.device_context(), &out_data);
+    for (int i = 0; i < static_cast<int>(out_data.size()); ++i) {
+      vec_expand_shape.push_back(out_data[i]);
+    }
+    return vec_expand_shape;
+  } else if (list_expand_shapes_tensor.size() > 0) {
+    // get one shape value from each scalar tensor in the list
+    for (size_t i = 0; i < list_expand_shapes_tensor.size(); ++i) {
+      auto tensor = list_expand_shapes_tensor[i];
+      std::vector<int> out_data;
+      TensorToVector(*tensor, ctx.device_context(), &out_data);
+      vec_expand_shape.push_back(out_data[0]);
+    }
+    return vec_expand_shape;
+  } else {
+    return ctx.Attr<std::vector<int>>("shape");
+  }
+}
+
 using Tensor = framework::Tensor;
 template <typename DeviceContext, typename T>
 class ExpandV2NPUKernel : public framework::OpKernel<T> {
@@ -26,27 +53,7 @@ class ExpandV2NPUKernel : public framework::OpKernel<T> {
     auto* X = ctx.Input<framework::Tensor>("X");
     auto* Out = ctx.Output<framework::Tensor>("Out");
 
-    std::vector<int> expand_shape;
-    auto list_expand_shapes_tensor =
-        ctx.MultiInput<framework::Tensor>("expand_shapes_tensor");
-    if (ctx.HasInput("Shape")) {
-      auto* shape_tensor = ctx.Input<framework::LoDTensor>("Shape");
-      std::vector<int> out_data;
-      TensorToVector(*shape_tensor, ctx.device_context(), &out_data);
-      for (int i = 0; i < static_cast<int>(out_data.size()); ++i) {
-        expand_shape.push_back(out_data[i]);
-      }
-    } else if (list_expand_shapes_tensor.size() > 0) {
-      // get tensor from
-      for (size_t i = 0; i < list_expand_shapes_tensor.size(); ++i) {
-        auto tensor = list_expand_shapes_tensor[i];
-        std::vector<int> out_data;
-        TensorToVector(*tensor, ctx.device_context(), &out_data);
-        expand_shape.push_back(out_data[0]);
-      }
-    } else {
-      expand_shape = ctx.Attr<std::vector<int>>("shape");
-    }
+    std::vector<int> expand_shape = get_expand_shape_npu(ctx);
 
     framework::NPUAttributeMap attr_input = {{"shape", expand_shape}};
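The refactor above replaces the inline block with a call to the new get_expand_shape_npu helper, keeping the op's existing precedence for how the target shape is supplied: a whole-shape "Shape" tensor wins, then the "expand_shapes_tensor" list of scalar tensors (one per dimension), then the "shape" attribute. A minimal standalone sketch of that precedence, with the framework stubbed out and a hypothetical name resolve_expand_shape (this is not Paddle code):

#include <iostream>
#include <optional>
#include <vector>

// Same precedence as the kernel's helper, on plain containers.
std::vector<int> resolve_expand_shape(
    const std::optional<std::vector<int>>& shape_input,      // "Shape" tensor
    const std::vector<std::vector<int>>& shape_tensor_list,  // "expand_shapes_tensor"
    const std::vector<int>& shape_attr) {                    // "shape" attribute
  if (shape_input) return *shape_input;  // 1) whole-shape tensor wins
  if (!shape_tensor_list.empty()) {      // 2) one scalar tensor per dimension
    std::vector<int> shape;
    for (const auto& t : shape_tensor_list) shape.push_back(t[0]);
    return shape;
  }
  return shape_attr;                     // 3) fall back to the attribute
}

int main() {
  // No "Shape" input; two scalar tensors {2} and {3} beat the attribute.
  auto s = resolve_expand_shape(std::nullopt, {{2}, {3}}, {4, 5});
  for (int d : s) std::cout << d << ' ';  // prints: 2 3
}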

@@ -97,6 +104,62 @@ class ExpandV2NPUKernel : public framework::OpKernel<T> {
     runner.Run(stream);
   }
 };
+
+template <typename DeviceContext, typename T>
+class ExpandV2NPUGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    dx->mutable_data<T>(ctx.GetPlace());
+
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+
+    // case 1: reduce the extra leading dims of dout down to dx's rank
+    // For example: [2, 120] --> [120]
+    auto reduce_ndim = dout->dims().size() - dx->dims().size();
+    std::vector<int> axes;
+    for (auto i = 0; i < reduce_ndim; ++i) {
+      axes.push_back(i);
+    }
+    Tensor* tmp_dout = const_cast<Tensor*>(dout);
+    Tensor reduced_dout(dx->type());
+    if (axes.size() != 0) {
+      std::vector<int64_t> reduced_dout_dims;
+      for (auto i = reduce_ndim; i < dout->dims().size(); ++i) {
+        reduced_dout_dims.push_back(dout->dims()[i]);
+      }
+      reduced_dout.Resize(framework::make_ddim(reduced_dout_dims));
+      reduced_dout.mutable_data<T>(ctx.GetPlace());
+      const auto& runner = NpuOpRunner("ReduceSumD", {*dout}, {reduced_dout},
+                                       {{"axes", axes}, {"keep_dims", false}});
+      runner.Run(stream);
+      tmp_dout = &reduced_dout;
+    }
+
+    // case 2: reduce the axes of dout where the corresponding dx dim is 1
+    // For example: [12, 140] --> [1, 140]
+
+    // case 3: copy dout to dx when the shapes match and no dx dim is 1
+    // For example: [2, 10, 5] --> [2, 10, 5]
+    axes.clear();
+    for (auto i = 0; i < dx->dims().size(); ++i) {
+      if (dx->dims()[i] == 1) {
+        axes.push_back(i);
+      }
+    }
+    if (axes.size() != 0) {
+      const auto& runner = NpuOpRunner("ReduceSumD", {*tmp_dout}, {*dx},
+                                       {{"axes", axes}, {"keep_dims", true}});
+      runner.Run(stream);
+    } else {
+      framework::TensorCopySync(*tmp_dout, ctx.GetPlace(), dx);
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
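The grad kernel handles its three cases with at most two ReduceSumD passes: the extra leading axes of dout are summed away with keep_dims=false, then any remaining axis where dx's dim is 1 (a broadcast axis) is summed with keep_dims=true; if neither applies, dout is copied straight into dx. A standalone sketch of just the axis selection (plain C++, not Paddle code), using dims that exercise both reductions:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // dout was expanded from dx; derive the ReduceSumD axes the kernel would use.
  std::vector<int64_t> dout_dims = {2, 12, 140};  // grad of the expanded output
  std::vector<int64_t> dx_dims = {1, 140};        // grad of the original input

  // case 1: dout has extra leading dims; sum them away with keep_dims=false
  int reduce_ndim = static_cast<int>(dout_dims.size() - dx_dims.size());
  std::vector<int> lead_axes;
  for (int i = 0; i < reduce_ndim; ++i) lead_axes.push_back(i);

  // case 2: a dim of 1 in dx was broadcast; sum it with keep_dims=true
  // case 3: if no such dims exist, dout is copied to dx unchanged
  std::vector<int> broadcast_axes;
  for (int i = 0; i < static_cast<int>(dx_dims.size()); ++i)
    if (dx_dims[i] == 1) broadcast_axes.push_back(i);

  std::cout << "leading axes to sum:";
  for (int a : lead_axes) std::cout << ' ' << a;       // 0  -> [12, 140]
  std::cout << "\nbroadcast axes to sum (keep_dims):";
  for (int a : broadcast_axes) std::cout << ' ' << a;  // 0  -> [1, 140]
  std::cout << '\n';
}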

@@ -107,3 +170,8 @@ REGISTER_OP_NPU_KERNEL(
     ops::ExpandV2NPUKernel<paddle::platform::NPUDeviceContext,
                            paddle::platform::float16>,
     ops::ExpandV2NPUKernel<paddle::platform::NPUDeviceContext, int>);
+
+REGISTER_OP_NPU_KERNEL(
+    expand_v2_grad,
+    ops::ExpandV2NPUGradKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::ExpandV2NPUGradKernel<paddle::platform::NPUDeviceContext, int>);
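For intuition on why the backward pass is a sum: expand copies elements, so each input element's gradient must accumulate the upstream gradient of every copy. A tiny worked example (plain C++, not Paddle code) for x of shape [3] expanded to [2, 3], which is exactly case 1's ReduceSumD over axis 0:

#include <iostream>

int main() {
  // out = expand(x) duplicates x (shape [3]) along a new leading axis of 2,
  // so dx[j] must sum the upstream gradient over both copies of x[j].
  double dout[2][3] = {{1, 2, 3}, {4, 5, 6}};
  double dx[3] = {0, 0, 0};
  for (int i = 0; i < 2; ++i)      // sum over the expanded (leading) axis
    for (int j = 0; j < 3; ++j) dx[j] += dout[i][j];
  for (double v : dx) std::cout << v << ' ';  // prints: 5 7 9
}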

python/paddle/fluid/tests/unittests/npu/test_expand_v2_op_npu.py

Lines changed: 10 additions & 8 deletions
@@ -52,8 +52,8 @@ def init_data(self):
     def test_check_output(self):
         self.check_output_with_place(self.place)
 
-    # def test_check_grad(self):
-    #     self.check_grad(['X'], 'Out')
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place, ['X'], 'Out')
 
 
 class TestExpandV2OpRank2_DimExpanding(TestExpandV2NPUOpRank1):
@@ -118,8 +118,8 @@ def init_data(self):
     def test_check_output(self):
         self.check_output_with_place(self.place)
 
-    # def test_check_grad(self):
-    #     self.check_grad(['X'], 'Out')
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place, ['X'], 'Out')
 
 
 class TestExpandV2OpRank2_Corner_tensor_attr(
@@ -159,11 +159,12 @@ def init_data(self):
     def test_check_output(self):
         self.check_output_with_place(self.place)
 
-    # def test_check_grad(self):
-    #     self.check_grad(['X'], 'Out')
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place, ['X'], 'Out')
 
 
-# Situation 4: input x is float16
+# Situation 4: input x is float16
+# grad check is not supported for float16
 class TestExpandV2OpInteger(OpTest):
     def setUp(self):
         self.set_npu()
@@ -184,6 +185,7 @@ def test_check_output(self):
 
 
 # Situation 5: input x is int32
+# ReduceSumD CANN Op doesn't support grad check for int32
 class TestExpandV2OpInteger(OpTest):
     def setUp(self):
         self.set_npu()
@@ -240,7 +242,7 @@ def test_static(self):
         out_2 = paddle.expand(x, shape=[positive_2, 14])
         out_3 = paddle.expand(x, shape=expand_shape)
 
-        # g0 = fluid.backward.calc_gradient(out_2, x)
+        g0 = fluid.backward.calc_gradient(out_2, x)
 
         exe = fluid.Executor(place=paddle.NPUPlace(0))
         res_1, res_2, res_3 = exe.run(fluid.default_main_program(),
