
Commit 925432d

【NPU】Support npu kernel for mul op (#31584)
* add mul
* add test mul
1 parent 1e95600 commit 925432d

2 files changed: +569 −0 lines changed

Lines changed: 243 additions & 0 deletions
@@ -0,0 +1,243 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>
#include <string>

#include "paddle/fluid/operators/mul_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class MulNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<framework::Tensor>("X");
    auto* y = ctx.Input<framework::Tensor>("Y");
    auto* out = ctx.Output<framework::Tensor>("Out");
    int x_num_col_dims = ctx.Attr<int>("x_num_col_dims");
    int y_num_col_dims = ctx.Attr<int>("y_num_col_dims");
    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();
    if (x_num_col_dims == 1 && y_num_col_dims == 1) {
      if (x->dims().size() == 2 && y->dims().size() == 2) {
        out->mutable_data<T>(ctx.GetPlace());
        auto runner =
            NpuOpRunner("MatMul", {*x, *y}, {*out},
                        {{"transpose_x1", false}, {"transpose_x2", false}});

        runner.Run(stream);
      } else if (x->dims().size() == 3 && y->dims().size() == 2) {
        // reshape x from [d0, d1, d2] to [d0, d1 * d2]
        Tensor tmp_x(x->type());
        int64_t sec_dim = x->dims()[1] * x->dims()[2];
        int64_t first_dim = x->dims()[0];
        tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));
        tmp_x.mutable_data<T>(ctx.GetPlace());
        framework::TensorCopy(
            *x, ctx.GetPlace(),
            ctx.template device_context<platform::DeviceContext>(), &tmp_x);
        tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));
        out->mutable_data<T>(ctx.GetPlace());
        // matmul
        auto runner =
            NpuOpRunner("MatMul", {tmp_x, *y}, {*out},
                        {{"transpose_x1", false}, {"transpose_x2", false}});
        runner.Run(stream);
      } else {
        PADDLE_THROW(
            platform::errors::InvalidArgument("not supported dims"));
      }
      // TODO: support other dim combinations
    } else if (x->dims().size() == 3 && y->dims().size() == 2) {
      // for example: x.shape=[2, 3, 4], y.shape=[4, 5], expect [2, 3, 5]
      PADDLE_ENFORCE_EQ(x_num_col_dims, 2,
                        platform::errors::InvalidArgument(
                            "now only support x_num_col_dims == 2, but got %d",
                            x_num_col_dims));
      // flatten => x.shape=[6, 4]
      Tensor tmp_x(x->type());
      int64_t first_dim = x->dims()[0] * x->dims()[1];
      int64_t sec_dim = x->dims()[2];
      tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));
      tmp_x.mutable_data<T>(ctx.GetPlace());
      framework::TensorCopy(
          *x, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), &tmp_x);
      tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));

      // matmul [6, 4] * [4, 5] => [6, 5]
      Tensor tmp_matmul(x->type());
      tmp_matmul.Resize(framework::make_ddim({first_dim, y->dims()[1]}));
      tmp_matmul.mutable_data<T>(ctx.GetPlace());

      auto runner_matmul =
          NpuOpRunner("MatMul", {tmp_x, *y}, {tmp_matmul},
                      {{"transpose_x1", false}, {"transpose_x2", false}});

      runner_matmul.Run(stream);
      // reshape [6, 5] => [2, 3, 5]
      (*out).Resize(
          framework::make_ddim({x->dims()[0], x->dims()[1], y->dims()[1]}));
      out->mutable_data(ctx.GetPlace(), x->type());
      framework::TensorCopy(
          tmp_matmul, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), out);
      (*out).Resize(
          framework::make_ddim({x->dims()[0], x->dims()[1], y->dims()[1]}));
    }
  }
};

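// Backward of mul: with X flattened to [M, K] and Y of shape [K, N],
//   dX_flat = dOut_flat * Y^T    ([M, N] x [N, K] => [M, K])
//   dY_flat = X_flat^T * dOut_flat    ([K, M] x [M, N] => [K, N])
// The grad kernel below reproduces the forward kernel's flatten / MatMul /
// reshape steps, selecting the transposes via the MatMul attributes.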
template <typename DeviceContext, typename T>
class MulGradNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<framework::Tensor>("X");
    auto* y = ctx.Input<framework::Tensor>("Y");
    auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
    int x_num_col_dims = ctx.Attr<int>("x_num_col_dims");
    int y_num_col_dims = ctx.Attr<int>("y_num_col_dims");
    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();
    if (x_num_col_dims == 1 && y_num_col_dims == 1) {
      if (x->dims().size() == 2 && y->dims().size() == 2) {
        if (dx) {
          dx->mutable_data<T>(ctx.GetPlace());
          auto runner_dx =
              NpuOpRunner("MatMul", {*dout, *y}, {*dx},
                          {{"transpose_x1", false}, {"transpose_x2", true}});

          runner_dx.Run(stream);
        }

        if (dy) {
          dy->mutable_data<T>(ctx.GetPlace());
          auto runner_dy =
              NpuOpRunner("MatMul", {*x, *dout}, {*dy},
                          {{"transpose_x1", true}, {"transpose_x2", false}});

          runner_dy.Run(stream);
        }
      } else if (x->dims().size() == 3 && y->dims().size() == 2) {
        // for example: x.shape=[2, 3, 4] was flattened to [2, 12] in the
        // forward pass, with y.shape=[12, 5] and dout.shape=[2, 5]
        if (dx) {
          // matmul: dout [2, 5] * y^T [5, 12] => dx_flat [2, 12]
          Tensor tmp_matmul(y->type());
          tmp_matmul.Resize(
              framework::make_ddim({dout->dims()[0], y->dims()[0]}));
          tmp_matmul.mutable_data<T>(ctx.GetPlace());
          auto runner_matmul =
              NpuOpRunner("MatMul", {*dout, *y}, {tmp_matmul},
                          {{"transpose_x1", false}, {"transpose_x2", true}});
          runner_matmul.Run(stream);
          // reshape [2, 12] => [2, 3, 4]
          dx->mutable_data(ctx.GetPlace(), x->type());
          framework::TensorCopy(
              tmp_matmul, ctx.GetPlace(),
              ctx.template device_context<platform::DeviceContext>(), dx);
        }

        if (dy) {
          // flatten x to [2, 12]
          Tensor tmp_x(x->type());
          int64_t sec_dim = x->dims()[1] * x->dims()[2];
          int64_t first_dim = x->dims()[0];
          tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));
          tmp_x.mutable_data<T>(ctx.GetPlace());
          framework::TensorCopy(
              *x, ctx.GetPlace(),
              ctx.template device_context<platform::DeviceContext>(), &tmp_x);
          tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));
          dy->mutable_data<T>(ctx.GetPlace());
          auto runner_dy =
              NpuOpRunner("MatMul", {tmp_x, *dout}, {*dy},
                          {{"transpose_x1", true}, {"transpose_x2", false}});

          runner_dy.Run(stream);
        }
      }
    } else if (x->dims().size() == 3 && y->dims().size() == 2) {
      // for example: x.shape=[2, 3, 4], y.shape=[4, 5], dout.shape=[2, 3, 5]
      PADDLE_ENFORCE_EQ(x_num_col_dims, 2,
                        platform::errors::InvalidArgument(
                            "now only support x_num_col_dims == 2, but got %d",
                            x_num_col_dims));
      // tmp_dout is used by both the dx and dy paths
      Tensor tmp_dout(x->type());
      int64_t dout_first_dim = dout->dims()[0] * dout->dims()[1];
      int64_t dout_sec_dim = dout->dims()[2];
      tmp_dout.Resize(framework::make_ddim({dout_first_dim, dout_sec_dim}));
      tmp_dout.mutable_data<T>(ctx.GetPlace());
      framework::TensorCopy(
          *dout, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), &tmp_dout);
      tmp_dout.Resize(framework::make_ddim({dout_first_dim, dout_sec_dim}));

      if (dx) {
        // matmul: tmp_dout [6, 5] * y^T [5, 4] => [6, 4]
        Tensor tmp_matmul(y->type());
        tmp_matmul.Resize(framework::make_ddim({dout_first_dim, y->dims()[0]}));
        tmp_matmul.mutable_data<T>(ctx.GetPlace());
        auto runner_matmul =
            NpuOpRunner("MatMul", {tmp_dout, *y}, {tmp_matmul},
                        {{"transpose_x1", false}, {"transpose_x2", true}});
        runner_matmul.Run(stream);
        // reshape [6, 4] => [2, 3, 4]
        dx->mutable_data(ctx.GetPlace(), x->type());
        framework::TensorCopy(
            tmp_matmul, ctx.GetPlace(),
            ctx.template device_context<platform::DeviceContext>(), dx);
      }
      if (dy) {
        // flatten x.shape [2, 3, 4] => [6, 4]
        Tensor tmp_x(x->type());
        int64_t first_dim = x->dims()[0] * x->dims()[1];
        int64_t sec_dim = x->dims()[2];
        tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));
        tmp_x.mutable_data<T>(ctx.GetPlace());
        framework::TensorCopy(
            *x, ctx.GetPlace(),
            ctx.template device_context<platform::DeviceContext>(), &tmp_x);
        tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));
        // matmul: tmp_x^T [4, 6] * tmp_dout [6, 5] => dy [4, 5]
        dy->mutable_data<T>(ctx.GetPlace());
        auto runner_dy =
            NpuOpRunner("MatMul", {tmp_x, tmp_dout}, {*dy},
                        {{"transpose_x1", true}, {"transpose_x2", false}});
        runner_dy.Run(stream);
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_NPU_KERNEL(
    mul, ops::MulNPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::MulNPUKernel<paddle::platform::NPUDeviceContext,
                      paddle::platform::float16>);
REGISTER_OP_NPU_KERNEL(
    mul_grad, ops::MulGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::MulGradNPUKernel<paddle::platform::NPUDeviceContext,
                          paddle::platform::float16>);
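
For reference, below is a minimal numpy sketch (not the Python test added in this commit) of the shape handling the kernels above implement for the 3-D x / 2-D y case with x_num_col_dims == 2; the example shapes are taken from the comments in the code, and the input values are arbitrary.

import numpy as np

# Shapes from the kernel comments: x [2, 3, 4], y [4, 5], Out [2, 3, 5].
x = np.random.rand(2, 3, 4).astype(np.float32)
y = np.random.rand(4, 5).astype(np.float32)

# Forward: flatten x to [6, 4], MatMul to [6, 5], reshape to [2, 3, 5].
x_flat = x.reshape(6, 4)
out = (x_flat @ y).reshape(2, 3, 5)

# Backward: flatten dOut to [6, 5], then
#   dx = (dOut_flat @ y^T) reshaped back to x's shape,
#   dy = x_flat^T @ dOut_flat.
dout_flat = np.ones((2, 3, 5), dtype=np.float32).reshape(6, 5)
dx = (dout_flat @ y.T).reshape(2, 3, 4)
dy = x_flat.T @ dout_flat

print(out.shape, dx.shape, dy.shape)  # (2, 3, 5) (2, 3, 4) (4, 5)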
