Skip to content
Merged
1 change: 1 addition & 0 deletions lite/api/paddle_use_passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,4 @@ USE_MIR_PASS(reshape_calc_offline_pass);
USE_MIR_PASS(keepdims_convert_pass);
USE_MIR_PASS(op_fusion_minimal_set_pass);
USE_MIR_PASS(lite_sigmoid_elementmul_fuse_pass);
USE_MIR_PASS(transformer_attention_fuse_pass);
4 changes: 4 additions & 0 deletions lite/backends/arm/math/conv_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,7 @@ void conv1x1s1_gemm_int8(const int8_t* i_data,
n,
k,
flag_bias,
GemmMBias,
false,
scale_group,
act_param,
Expand All @@ -749,6 +750,7 @@ void conv1x1s1_gemm_int8(const int8_t* i_data,
n,
k,
flag_bias,
GemmMBias,
false,
scale_group,
act_param,
Expand Down Expand Up @@ -1603,6 +1605,7 @@ void conv_im2col_gemm_int8(const int8_t* i_data,
n,
k,
flag_bias,
GemmMBias,
false,
scale_group,
act_param,
Expand All @@ -1617,6 +1620,7 @@ void conv_im2col_gemm_int8(const int8_t* i_data,
n,
k,
flag_bias,
GemmMBias,
false,
scale_group,
act_param,
Expand Down
752 changes: 602 additions & 150 deletions lite/backends/arm/math/gemm_prepacked_int8.cc

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions lite/backends/arm/math/gemm_prepacked_int8.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,20 @@ namespace arm {
namespace math {

const int KBLOCK_INT8 = 4;
// Direction along which the optional bias vector is added in the int8
// GEMM kernels. GemmMBias/GemmNBias presumably mean a bias broadcast
// along the M (rows of C) or N (cols of C) dimension respectively, and
// GemmMNBias both -- TODO confirm against gemm_prepacked_int8.cc usage.
typedef enum {
  GemmNoBias = 0,
  GemmMBias = 1,
  GemmNBias = 2,
  GemmMNBias = 3,
} GemmBiasDirection;

// Direction along which the dequantization scale is applied in the int8
// GEMM kernels; mirrors GemmBiasDirection (per-M, per-N, or both).
// NOTE(review): exact broadcast semantics should be verified against the
// kernel implementation in gemm_prepacked_int8.cc.
typedef enum {
  GemmNoScale = 0,
  GemmMScale = 1,
  GemmNScale = 2,
  GemmMNScale = 3,
} GemmScaleDirection;

#ifdef __aarch64__
// for int7/int8 gemm
// const int HBLOCK = 4;
Expand Down Expand Up @@ -94,6 +108,7 @@ void gemm_prepack_int8(const int8_t* A_packed,
int N,
int K,
bool is_bias,
GemmBiasDirection bias_direction,
bool is_transB,
const float* scale,
const operators::ActivationParam act_param,
Expand Down
21 changes: 19 additions & 2 deletions lite/backends/arm/math/gemm_s8.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ void gemm_s8(bool is_transA,
Dtype* C,
const float* bias,
bool is_bias,
GemmBiasDirection bias_direction,
const float* scale,
const operators::ActivationParam act_param,
ARMContext* ctx) {
Expand Down Expand Up @@ -83,8 +84,19 @@ void gemm_s8(bool is_transA,
int lda = is_transA ? M : K;
prepackA_int8(packed_A, A, lda, 0, M, 0, K, is_transA, ctx);

gemm_prepack_int8<Dtype>(
packed_A, B, bias, C, M, N, K, is_bias, is_transB, scale, act_param, ctx);
gemm_prepack_int8<Dtype>(packed_A,
B,
bias,
C,
M,
N,
K,
is_bias,
bias_direction,
is_transB,
scale,
act_param,
ctx);
}

template void gemm_s8<float>(bool is_transA,
Expand All @@ -97,6 +109,7 @@ template void gemm_s8<float>(bool is_transA,
float* C,
const float* bias,
bool is_bias,
GemmBiasDirection bias_direction,
const float* scale,
const operators::ActivationParam act_param,
ARMContext* ctx);
Expand All @@ -111,6 +124,7 @@ template void gemm_s8<int8_t>(bool is_transA,
int8_t* C,
const float* bias,
bool is_bias,
GemmBiasDirection bias_direction,
const float* scale,
const operators::ActivationParam act_param,
ARMContext* ctx);
Expand All @@ -127,6 +141,7 @@ void gemm_sve(bool is_transA,
Dtype* C,
const float* bias,
bool is_bias,
GemmBiasDirection bias_direction,
const float* scale,
const operators::ActivationParam act_param,
ARMContext* ctx) {
Expand Down Expand Up @@ -203,6 +218,7 @@ template void gemm_sve<float>(bool is_transA,
float* C,
const float* bias,
bool is_bias,
GemmBiasDirection bias_direction,
const float* scale,
const operators::ActivationParam act_param,
ARMContext* ctx);
Expand All @@ -217,6 +233,7 @@ template void gemm_sve<int8_t>(bool is_transA,
int8_t* C,
const float* bias,
bool is_bias,
GemmBiasDirection bias_direction,
const float* scale,
const operators::ActivationParam act_param,
ARMContext* ctx);
Expand Down
2 changes: 2 additions & 0 deletions lite/backends/arm/math/gemm_s8.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ void gemm_s8(bool is_transA,
Dtype* C,
const float* bias,
bool is_bias,
GemmBiasDirection bias_direction,
const float* scale,
const operators::ActivationParam act_param,
ARMContext* ctx);
Expand All @@ -51,6 +52,7 @@ void gemm_sve(bool is_transA,
Dtype* C,
const float* bias,
bool is_bias,
GemmBiasDirection bias_direction,
const float* scale,
const operators::ActivationParam act_param,
ARMContext* ctx);
Expand Down
2 changes: 2 additions & 0 deletions lite/backends/arm/math/gru_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,7 @@ struct GRUUnitFunctor {
out_data.get(),
nullptr,
false,
lite::arm::math::GemmNoBias,
scales.data(),
act_param,
ctx);
Expand Down Expand Up @@ -550,6 +551,7 @@ struct GRUUnitFunctor {
out_data.get(),
nullptr,
false,
lite::arm::math::GemmNoBias,
scales.data(),
act_param,
ctx);
Expand Down
3 changes: 1 addition & 2 deletions lite/core/optimizer/mir/fusion/fc_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
}
}
}
if (!(has_int8 && has_weight_quant) && has_arm && !is_nnadapter) {
// only support FP32/FP16
if (has_arm && !is_nnadapter) {
mul_types.push_back("matmul");
mul_types.push_back("matmul_v2");
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/core/optimizer/mir/fusion/transformer_attention_fuse_pass.h"
#include <list>
#include <memory>
#include <vector>
#include "lite/core/optimizer/mir/fusion/transformer_attention_fuser.h"
#include "lite/core/optimizer/mir/pass_registry.h"

namespace paddle {
namespace lite {
namespace mir {

// Runs the transformer-attention fusion over the program graph.
// The fusion is only valid for quantized (int8) graphs, so the fuser is
// invoked only when at least one valid place has kInt8 precision;
// otherwise the pass is a no-op.
void TransformerAttentionFusePass::Apply(
    const std::unique_ptr<SSAGraph>& graph) {
  fusion::TransformerAttentionFuser fuser;
  bool has_int8 = false;
  for (auto& place : graph->valid_places()) {
    if (place.precision == PRECISION(kInt8)) {
      has_int8 = true;
      break;  // one int8 place is enough; no need to keep scanning
    }
  }
  // Reviewer note addressed: the previous `else { return; }` was
  // redundant -- falling off the end of the function is equivalent.
  if (has_int8) {
    fuser(graph.get());
  }
}

} // namespace mir
} // namespace lite
} // namespace paddle

// Register the pass for ARM targets only; excluded on backends that have
// their own graph pipelines. Binds the fused kernel so the framework
// knows this pass produces `fused_attention` ops.
REGISTER_MIR_PASS(transformer_attention_fuse_pass,
                  paddle::lite::mir::TransformerAttentionFusePass)
    .BindTargets({TARGET(kARM)})
    .ExcludeTargets(
        {TARGET(kXPU), TARGET(kOpenCL), TARGET(kMetal), TARGET(kNNAdapter)})
    .BindKernel("fused_attention");
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include "lite/core/optimizer/mir/pass.h"

namespace paddle {
namespace lite {
namespace mir {

// Program-level MIR pass that applies fusion::TransformerAttentionFuser
// to the SSA graph. Per the .cc implementation, the fuser only runs when
// the graph's valid places include int8 precision (quantized models).
class TransformerAttentionFusePass : public ProgramPass {
 public:
  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};

} // namespace mir
} // namespace lite
} // namespace paddle
Loading