Skip to content

Commit 8fa178a

Browse files
committed
add fused_attention test=develop
1 parent 6875626 commit 8fa178a

27 files changed

+1680
-148
lines changed

lite/api/paddle_use_passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,4 @@ USE_MIR_PASS(reshape_calc_offline_pass);
111111
USE_MIR_PASS(keepdims_convert_pass);
112112
USE_MIR_PASS(op_fusion_minimal_set_pass);
113113
USE_MIR_PASS(lite_sigmoid_elementmul_fuse_pass);
114+
USE_MIR_PASS(lite_ernie_attention_fuse_pass);

lite/backends/arm/math/conv_impl.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,7 @@ void conv1x1s1_gemm_int8(const int8_t* i_data,
735735
n,
736736
k,
737737
flag_bias,
738+
GemmMBias,
738739
false,
739740
scale_group,
740741
act_param,
@@ -749,6 +750,7 @@ void conv1x1s1_gemm_int8(const int8_t* i_data,
749750
n,
750751
k,
751752
flag_bias,
753+
GemmMBias,
752754
false,
753755
scale_group,
754756
act_param,
@@ -1603,6 +1605,7 @@ void conv_im2col_gemm_int8(const int8_t* i_data,
16031605
n,
16041606
k,
16051607
flag_bias,
1608+
GemmMBias,
16061609
false,
16071610
scale_group,
16081611
act_param,
@@ -1617,6 +1620,7 @@ void conv_im2col_gemm_int8(const int8_t* i_data,
16171620
n,
16181621
k,
16191622
flag_bias,
1623+
GemmMBias,
16201624
false,
16211625
scale_group,
16221626
act_param,

lite/backends/arm/math/gemm_prepacked_int8.cc

Lines changed: 380 additions & 96 deletions
Large diffs are not rendered by default.

lite/backends/arm/math/gemm_prepacked_int8.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,20 @@ namespace arm {
2424
namespace math {
2525

2626
const int KBLOCK_INT8 = 4;
27+
typedef enum {
28+
GemmNoBias = 0,
29+
GemmMBias = 1,
30+
GemmNBias = 2,
31+
GemmMNBias = 3,
32+
} GemmBiasDirection;
33+
34+
typedef enum {
35+
GemmNoScale = 0,
36+
GemmMScale = 1,
37+
GemmNScale = 2,
38+
GemmMNScale = 3,
39+
} GemmScaleDirection;
40+
2741
#ifdef __aarch64__
2842
// for int7/int8 gemm
2943
// const int HBLOCK = 4;
@@ -94,6 +108,7 @@ void gemm_prepack_int8(const int8_t* A_packed,
94108
int N,
95109
int K,
96110
bool is_bias,
111+
GemmBiasDirection bias_direction,
97112
bool is_transB,
98113
const float* scale,
99114
const operators::ActivationParam act_param,

lite/backends/arm/math/gemm_s8.cc

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ void gemm_s8(bool is_transA,
3333
Dtype* C,
3434
const float* bias,
3535
bool is_bias,
36+
GemmBiasDirection bias_direction,
3637
const float* scale,
3738
const operators::ActivationParam act_param,
3839
ARMContext* ctx) {
@@ -83,8 +84,19 @@ void gemm_s8(bool is_transA,
8384
int lda = is_transA ? M : K;
8485
prepackA_int8(packed_A, A, lda, 0, M, 0, K, is_transA, ctx);
8586

86-
gemm_prepack_int8<Dtype>(
87-
packed_A, B, bias, C, M, N, K, is_bias, is_transB, scale, act_param, ctx);
87+
gemm_prepack_int8<Dtype>(packed_A,
88+
B,
89+
bias,
90+
C,
91+
M,
92+
N,
93+
K,
94+
is_bias,
95+
bias_direction,
96+
is_transB,
97+
scale,
98+
act_param,
99+
ctx);
88100
}
89101

90102
template void gemm_s8<float>(bool is_transA,
@@ -97,6 +109,7 @@ template void gemm_s8<float>(bool is_transA,
97109
float* C,
98110
const float* bias,
99111
bool is_bias,
112+
GemmBiasDirection bias_direction,
100113
const float* scale,
101114
const operators::ActivationParam act_param,
102115
ARMContext* ctx);
@@ -111,6 +124,7 @@ template void gemm_s8<int8_t>(bool is_transA,
111124
int8_t* C,
112125
const float* bias,
113126
bool is_bias,
127+
GemmBiasDirection bias_direction,
114128
const float* scale,
115129
const operators::ActivationParam act_param,
116130
ARMContext* ctx);
@@ -127,6 +141,7 @@ void gemm_sve(bool is_transA,
127141
Dtype* C,
128142
const float* bias,
129143
bool is_bias,
144+
GemmBiasDirection bias_direction,
130145
const float* scale,
131146
const operators::ActivationParam act_param,
132147
ARMContext* ctx) {
@@ -203,6 +218,7 @@ template void gemm_sve<float>(bool is_transA,
203218
float* C,
204219
const float* bias,
205220
bool is_bias,
221+
GemmBiasDirection bias_direction,
206222
const float* scale,
207223
const operators::ActivationParam act_param,
208224
ARMContext* ctx);
@@ -217,6 +233,7 @@ template void gemm_sve<int8_t>(bool is_transA,
217233
int8_t* C,
218234
const float* bias,
219235
bool is_bias,
236+
GemmBiasDirection bias_direction,
220237
const float* scale,
221238
const operators::ActivationParam act_param,
222239
ARMContext* ctx);

lite/backends/arm/math/gemm_s8.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ void gemm_s8(bool is_transA,
3535
Dtype* C,
3636
const float* bias,
3737
bool is_bias,
38+
GemmBiasDirection bias_direction,
3839
const float* scale,
3940
const operators::ActivationParam act_param,
4041
ARMContext* ctx);
@@ -51,6 +52,7 @@ void gemm_sve(bool is_transA,
5152
Dtype* C,
5253
const float* bias,
5354
bool is_bias,
55+
GemmBiasDirection bias_direction,
5456
const float* scale,
5557
const operators::ActivationParam act_param,
5658
ARMContext* ctx);

lite/backends/arm/math/gru_utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,7 @@ struct GRUUnitFunctor {
502502
out_data.get(),
503503
nullptr,
504504
false,
505+
lite::arm::math::GemmNoBias,
505506
scales.data(),
506507
act_param,
507508
ctx);
@@ -550,6 +551,7 @@ struct GRUUnitFunctor {
550551
out_data.get(),
551552
nullptr,
552553
false,
554+
lite::arm::math::GemmNoBias,
553555
scales.data(),
554556
act_param,
555557
ctx);
lite/core/optimizer/mir/fusion/ernie_attention_fuse_pass.cc
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lite/core/optimizer/mir/fusion/ernie_attention_fuse_pass.h"
16+
#include <list>
17+
#include <memory>
18+
#include <vector>
19+
#include "lite/core/optimizer/mir/fusion/ernie_attention_fuser.h"
20+
#include "lite/core/optimizer/mir/pass_registry.h"
21+
22+
namespace paddle {
23+
namespace lite {
24+
namespace mir {
25+
26+
void ErnieAttentionFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
27+
fusion::ErnieAttentionFuser fuser("matmul_v2");
28+
bool has_int8 = false;
29+
bool has_arm = false;
30+
bool has_opencl = false;
31+
for (auto& place : graph->valid_places()) {
32+
if (place.precision == PRECISION(kInt8)) {
33+
has_int8 = true;
34+
}
35+
if (place.target == TARGET(kARM)) {
36+
has_arm = true;
37+
}
38+
if (place.target == TARGET(kOpenCL)) {
39+
has_opencl = true;
40+
}
41+
}
42+
if ((has_arm && has_int8) || has_opencl) {
43+
fuser(graph.get());
44+
} else {
45+
return;
46+
}
47+
}
48+
49+
} // namespace mir
50+
} // namespace lite
51+
} // namespace paddle
52+
53+
REGISTER_MIR_PASS(lite_ernie_attention_fuse_pass,
54+
paddle::lite::mir::ErnieAttentionFusePass)
55+
.BindTargets({TARGET(kOpenCL), TARGET(kARM)})
56+
.BindKernel("fused_attention");
lite/core/optimizer/mir/fusion/ernie_attention_fuse_pass.h
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <memory>
18+
#include <string>
19+
#include "lite/core/optimizer/mir/pass.h"
20+
21+
namespace paddle {
22+
namespace lite {
23+
namespace mir {
24+
25+
class ErnieAttentionFusePass : public ProgramPass {
26+
public:
27+
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
28+
};
29+
30+
} // namespace mir
31+
} // namespace lite
32+
} // namespace paddle

0 commit comments

Comments (0)