Skip to content

Commit 0adb29d

Browse files
authored
[x86] add x86 matmul_v2 (#9137)
1 parent 9729635 commit 0adb29d

4 files changed

Lines changed: 339 additions & 1 deletion

File tree

lite/kernels/x86/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ if(WITH_MKL)
6565
endif()
6666

6767
# x86 math kernels (matmul_v2 registered alongside the original matmul).
add_kernel(matmul_compute_x86 X86 basic SRCS matmul_compute.cc)
add_kernel(matmul_v2_compute_x86 X86 basic SRCS matmul_v2_compute.cc)
add_kernel(box_coder_compute_x86 X86 basic SRCS box_coder_compute.cc)
add_kernel(density_prior_box_compute_x86 X86 basic SRCS density_prior_box_compute.cc)
add_kernel(interpolate_compute_x86 X86 basic SRCS interpolate_compute.cc)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/matmul_v2_compute.h"

// Register the float NCHW matmul_v2 kernel for the x86 target.
// X and Y inputs and the Out output are all plain x86 tensors.
REGISTER_LITE_KERNEL(matmul_v2,
                     kX86,
                     kFloat,
                     kNCHW,
                     paddle::lite::kernels::x86::MatMulV2Compute<float>,
                     def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
    .Finalize();
Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "lite/backends/x86/math/blas.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/types.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
// INIT_PARAM expands (inside MatMulV2Compute::Run) to the shape-inference
// prologue shared by every matmul_v2 case.  It declares and fills:
//   ctx                        - the X86 device context
//   param                      - this op's MatMulParam
//   x_dims / y_dims            - input tensor dimensions
//   m, n, k                    - GEMM sizes (Out is m x n, reduction dim k)
//   lda, ldb, ldc              - leading dimensions for X, Y and Out
//   x_transpose / y_transpose  - transpose flags from the op attributes
// Shape mismatches fail via CHECK_EQ; unsupported rank combinations abort
// via LOG(FATAL).
#define INIT_PARAM                                                           \
  auto& ctx = this->ctx_->template As<X86Context>();                         \
  auto& param = *param_.get_mutable<operators::MatMulParam>();               \
  auto x_dims = param.X->dims();                                             \
  auto y_dims = param.Y->dims();                                             \
  int m, n, k;                                                               \
  int lda, ldb, ldc;                                                         \
  bool x_transpose = param.transpose_X;                                      \
  bool y_transpose = param.transpose_Y;                                      \
  if ((x_dims.size() >= 2 && y_dims.size() >= 2) &&                          \
      (x_dims.size() != 2 || y_dims.size() != 2)) {                          \
    if (!x_transpose) {                                                      \
      m = x_dims[x_dims.size() - 2];                                         \
      k = x_dims[x_dims.size() - 1];                                         \
      lda = k;                                                               \
    } else {                                                                 \
      m = x_dims[x_dims.size() - 1];                                         \
      k = x_dims[x_dims.size() - 2];                                         \
      lda = m;                                                               \
    }                                                                        \
    if (!y_transpose) {                                                      \
      n = y_dims[y_dims.size() - 1];                                         \
      ldb = n;                                                               \
      CHECK_EQ(k, y_dims[y_dims.size() - 2])                                 \
          << "k must be equal y_dims[y_dims.size() - 2]";                    \
    } else {                                                                 \
      n = y_dims[y_dims.size() - 2];                                         \
      ldb = k;                                                               \
      CHECK_EQ(k, y_dims[y_dims.size() - 1])                                 \
          << "k must be equal y_dims[y_dims.size() - 1]";                    \
    }                                                                        \
    ldc = n;                                                                 \
    if (x_dims.size() > 2 && y_dims.size() > 2) {                            \
      auto sum_x = x_dims.count(0, x_dims.size() - 2);                       \
      auto sum_y = y_dims.count(0, y_dims.size() - 2);                       \
      CHECK_EQ(sum_x, sum_y)                                                 \
          << "sum_x(x_dims[0]+..x_dims[size()-2]) must be equal with "       \
             "sum_y(y_dims[0]+..y_dims[size()-2])";                          \
    }                                                                        \
  } else if ((x_dims.size() == 2 && y_dims.size() == 2) ||                   \
             (x_dims.size() == 2 && y_dims.size() == 1)) {                   \
    if (!x_transpose) {                                                      \
      m = x_dims[0];                                                         \
      k = x_dims[1];                                                         \
      lda = k;                                                               \
    } else {                                                                 \
      m = x_dims[1];                                                         \
      k = x_dims[0];                                                         \
      lda = m;                                                               \
    }                                                                        \
    if (!y_transpose) {                                                      \
      if (y_dims.size() > 1) {                                               \
        n = y_dims[1];                                                       \
      } else {                                                               \
        n = 1;                                                               \
      }                                                                      \
      ldb = n;                                                               \
      CHECK_EQ(k, y_dims[0]) << "k must be equal y_dims[0]";                 \
    } else {                                                                 \
      if (y_dims.size() > 1) {                                               \
        n = y_dims[0];                                                       \
        CHECK_EQ(k, y_dims[1]) << "k must be equal y_dims[1]";               \
      } else {                                                               \
        n = 1;                                                               \
        CHECK_EQ(k, y_dims[0]) << "k must be equal y_dims[0]";               \
      }                                                                      \
      ldb = k;                                                               \
    }                                                                        \
    ldc = n;                                                                 \
  } else if (x_dims.size() >= 2 && y_dims.size() == 1) {                     \
    n = 1;                                                                   \
    k = y_dims[0];                                                           \
    if (!x_transpose) {                                                      \
      m = x_dims.count(0, x_dims.size() - 1);                                \
      CHECK_EQ(k, x_dims[x_dims.size() - 1])                                 \
          << "k must be equal x_dims[x_dims.size() - 1]";                    \
    } else {                                                                 \
      m = x_dims.count(1, x_dims.size() - 1);                                \
      CHECK_EQ(k, x_dims[0]) << "k must be equal x_dims[0]";                 \
    }                                                                        \
    lda = k;                                                                 \
    ldb = n;                                                                 \
    ldc = n;                                                                 \
  } else if (y_dims.size() >= 2 && x_dims.size() == 1) {                     \
    m = 1;                                                                   \
    k = x_dims[0];                                                           \
    if (!y_transpose) {                                                      \
      n = y_dims.count(1, y_dims.size());                                    \
      CHECK_EQ(k, y_dims[0]) << "k must be equal y_dims[0]";                 \
    } else {                                                                 \
      n = y_dims.count(0, y_dims.size() - 1);                                \
      CHECK_EQ(k, y_dims[y_dims.size() - 1])                                 \
          << "k must be equal y_dims[y_dims.size() - 1]";                    \
    }                                                                        \
    lda = k;                                                                 \
    ldb = n;                                                                 \
    ldc = n;                                                                 \
  } else if (x_dims.size() == 1 && y_dims.size() == 1) {                     \
    m = 1;                                                                   \
    n = 1;                                                                   \
    k = x_dims[0];                                                           \
    if (x_transpose == true && y_transpose == true) {                        \
      m = x_dims[0];                                                         \
      k = 1;                                                                 \
      n = y_dims[0];                                                         \
    } else if (x_transpose == false && y_transpose == false) {               \
      CHECK_EQ(x_dims[0], y_dims[0]) << "x_dims[0] must be equal y_dims[0]"; \
    } else {                                                                 \
      LOG(FATAL) << "not supported x_dims(" << x_dims << ") and y_dims("    \
                 << y_dims << ")"                                            \
                 << ", when x_transpose is " << x_transpose                  \
                 << " and y_transpose is " << y_transpose;                   \
    }                                                                        \
    lda = k;                                                                 \
    ldb = n;                                                                 \
    ldc = n;                                                                 \
  } else {                                                                   \
    LOG(FATAL) << "This x_dims: " << x_dims << " and y_dims: " << y_dims    \
               << " doesn't support!";                                       \
  }
template <typename T>
147+
class MatMulV2Compute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
148+
public:
149+
using param_t = operators::MatMulParam;
150+
151+
void Run() override {
152+
INIT_PARAM;
153+
const auto* x_data = param.X->template data<T>();
154+
const auto* y_data = param.Y->template data<T>();
155+
auto* o_data = param.Out->template mutable_data<T>();
156+
auto o_dims = param.Out->dims();
157+
auto alpha = param.alpha;
158+
159+
auto blas = lite::x86::math::GetBlas<lite::TargetType::kX86, T>(ctx);
160+
161+
if ((x_dims.size() >= 2 && y_dims.size() >= 2) &&
162+
(x_dims.size() != 2 || y_dims.size() != 2)) {
163+
// x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N]
164+
// x: [B, M, K], y: [K, N], out: [B, M, N]
165+
// or
166+
// x: [M, K], y: [B, ..., K, N], out: [B, ..., M, N]
167+
// x: [M, K], y: [B, K, N], out: [B, M, N]
168+
int x_inner = x_dims[x_dims.size() - 2] * x_dims[x_dims.size() - 1];
169+
int y_inner = y_dims[y_dims.size() - 2] * y_dims[y_dims.size() - 1];
170+
int out_inner = o_dims[o_dims.size() - 2] * o_dims[o_dims.size() - 1];
171+
172+
if (x_dims.size() > 2 && y_dims.size() > 2) {
173+
for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
174+
blas.GEMM(x_transpose,
175+
y_transpose,
176+
m,
177+
n,
178+
k,
179+
alpha,
180+
x_data + i * x_inner,
181+
lda,
182+
y_data + i * y_inner,
183+
ldb,
184+
0.f,
185+
o_data + i * out_inner,
186+
ldc);
187+
}
188+
} else if (x_dims.size() > 2 && y_dims.size() == 2) {
189+
for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
190+
blas.GEMM(x_transpose,
191+
y_transpose,
192+
m,
193+
n,
194+
k,
195+
alpha,
196+
x_data + i * x_inner,
197+
lda,
198+
y_data,
199+
ldb,
200+
0.f,
201+
o_data + i * out_inner,
202+
ldc);
203+
}
204+
} else if (x_dims.size() == 2 && y_dims.size() > 2) {
205+
for (size_t i = 0; i < y_dims.count(0, y_dims.size() - 2); ++i) {
206+
blas.GEMM(x_transpose,
207+
y_transpose,
208+
m,
209+
n,
210+
k,
211+
alpha,
212+
x_data,
213+
lda,
214+
y_data + i * y_inner,
215+
ldb,
216+
0.f,
217+
o_data + i * out_inner,
218+
ldc);
219+
}
220+
}
221+
} else if (x_dims.size() == 2 && y_dims.size() == 2) {
222+
// x: [M, K], y: [K, N], out: [M, N]
223+
blas.GEMM(x_transpose,
224+
y_transpose,
225+
m,
226+
n,
227+
k,
228+
alpha,
229+
x_data,
230+
lda,
231+
y_data,
232+
ldb,
233+
0.f,
234+
o_data,
235+
ldc);
236+
} else if (x_dims.size() >= 2 && y_dims.size() == 1) {
237+
// x: [B, M, K], y: [K], out: [B, M]
238+
blas.GEMM(x_transpose,
239+
false,
240+
m,
241+
n,
242+
k,
243+
alpha,
244+
x_data,
245+
lda,
246+
y_data,
247+
ldb,
248+
0.f,
249+
o_data,
250+
ldc);
251+
} else if (y_dims.size() >= 2 && x_dims.size() == 1) {
252+
// y: [B, K, N], x: [K], out: [B, N]
253+
blas.GEMM(false,
254+
y_transpose,
255+
m,
256+
n,
257+
k,
258+
alpha,
259+
x_data,
260+
lda,
261+
y_data,
262+
ldb,
263+
0.f,
264+
o_data,
265+
ldc);
266+
} else if (x_dims.size() == 1 && y_dims.size() == 1) {
267+
// x: [K], y: [K], out: [1]
268+
if (x_transpose == false && y_transpose == false) {
269+
o_data[0] = 0.;
270+
for (size_t i = 0; i < x_dims[0]; ++i) {
271+
o_data[0] += x_data[i] * y_data[i] * alpha;
272+
}
273+
} else if (x_transpose == true && y_transpose == true) {
274+
blas.GEMM(false,
275+
false,
276+
m,
277+
n,
278+
k,
279+
alpha,
280+
x_data,
281+
lda,
282+
y_data,
283+
ldb,
284+
0.f,
285+
o_data,
286+
ldc);
287+
} else {
288+
LOG(FATAL) << "not supported x_dims.(" << x_dims << ") and y_dims("
289+
<< y_dims << ")"
290+
<< ", and x_transpose: " << x_transpose
291+
<< ", y_transpose: " << y_transpose;
292+
}
293+
} else {
294+
LOG(FATAL) << "not supported x_dims(" << x_dims << ") and y_dims("
295+
<< y_dims << ")";
296+
}
297+
}
298+
299+
virtual ~MatMulV2Compute() = default;
300+
};
301+
302+
} // namespace x86
303+
} // namespace kernels
304+
} // namespace lite
305+
} // namespace paddle

lite/tests/unittest_py/op/test_matmul_v2_op.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ def __init__(self, *args, **kwargs):
5757
Place(TargetType.ARM, PrecisionType.FP32),
5858
Place(TargetType.Host, PrecisionType.FP32)
5959
]
60+
self.enable_testing_on_place(
61+
TargetType.X86,
62+
PrecisionType.FP32,
63+
DataLayoutType.NCHW,
64+
thread=[1, 4])
6065
self.enable_testing_on_place(places=opencl_places)
6166
self.enable_testing_on_place(TargetType.NNAdapter, PrecisionType.FP32)
6267
self.enable_devices_on_nnadapter(
@@ -253,7 +258,8 @@ def _teller1(program_config, predictor_config):
253258
def _teller2(program_config, predictor_config):
254259
x_shape = list(program_config.inputs["input_data_x"].shape)
255260
transpose_X = program_config.ops[0].attrs["trans_x"]
256-
if predictor_config.target() == TargetType.ARM:
261+
if ((predictor_config.target() == TargetType.ARM) or
262+
(predictor_config.target() == TargetType.X86)):
257263
y_shape = list(program_config.inputs["input_data_y"].shape)
258264
if len(x_shape) == 1 and len(
259265
y_shape) == 1 and transpose_X == True:

0 commit comments

Comments
 (0)