Skip to content

Commit f6cef4e

Browse files
author
shenyijun01
committed
Squashed commit of the following:
commit 774928b74a784ad2eb24490628d9e70a742ced58 Author: shenyijun01 <[email protected]> Date: Fri Nov 18 13:35:30 2022 +0800 [Optimizer]: add quick gelu fusion pass for ViT model.
1 parent eff8dc1 commit f6cef4e

13 files changed

Lines changed: 322 additions & 2 deletions

lite/api/paddle_use_passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ USE_MIR_PASS(__xpu__resnet_fuse_pass);
8282
USE_MIR_PASS(__xpu__multi_encoder_fuse_pass);
8383
USE_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass);
8484
USE_MIR_PASS(__xpu__fc_fuse_pass);
85+
USE_MIR_PASS(__xpu__quick_gelu_fuse_pass);
8586
USE_MIR_PASS(__xpu__mmdnn_fuse_pass);
8687
USE_MIR_PASS(__xpu__conv2d_affine_channel_fuse_pass);
8788
USE_MIR_PASS(__xpu__conv2d_fuse_pass);

lite/core/optimizer/mir/fusion/__xpu__fc_fuse_pass.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,8 @@ class XPUFcFuser : public FuseBase {
199199
{"leaky_relu", 5},
200200
{"hard_swish", 14},
201201
{"hard_sigmoid", 15},
202-
{"relu6", 17}};
202+
{"relu6", 17},
203+
{"__xpu__quick_gelu", 19}};
203204

204205
float act_param_ = 0.0f;
205206
if (act_type_ == "leaky_relu") {
@@ -281,6 +282,7 @@ class XPUFcFusePass : public ProgramPass {
281282
for (auto with_bias : {true, false}) {
282283
for (auto act_type : {"relu",
283284
"gelu",
285+
"__xpu__quick_gelu",
284286
/*"sigmoid",
285287
"tanh",
286288
"leaky_relu",

lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1422,7 +1422,7 @@ class XPUMultiEncoderFusePass : public ProgramPass {
14221422
void Apply(const std::unique_ptr<SSAGraph>& graph) override {
14231423
if (GetBoolFromEnv("XPU_ENABLE_XTCL")) return;
14241424
// TODO(miaotianxiang): backup graph, recover from failed match
1425-
std::vector<std::string> act_types{"gelu", "relu"};
1425+
std::vector<std::string> act_types{"gelu", "relu", "__xpu__quick_gelu"};
14261426
std::vector<std::string> input_poss{"X", "Y"};
14271427
std::vector<std::string> qkv_ln_2_out_poss{"X", "Y"};
14281428
std::vector<std::string> matmul_types{"matmul", "matmul_v2"};
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <math.h>
16+
#include <memory>
17+
#include <string>
18+
#include "lite/backends/xpu/math.h"
19+
#include "lite/core/optimizer/mir/pass_registry.h"
20+
#include "lite/core/optimizer/mir/pattern_matcher_high_api.h"
21+
22+
namespace paddle {
23+
namespace lite {
24+
namespace mir {
25+
namespace fusion {
26+
27+
class XPUQuickGELUFuser : public FuseBase {
28+
public:
29+
XPUQuickGELUFuser() {}
30+
31+
void BuildPattern() override {
32+
auto scale_teller = [](const Node* node) -> bool {
33+
float bias_v =
34+
const_cast<Node*>(node)->AsStmt().op_info()->GetAttr<float>("bias");
35+
float scale_v =
36+
const_cast<Node*>(node)->AsStmt().op_info()->GetAttr<float>("scale");
37+
bool expect_bias = (bias_v == 0.0) ? true : false;
38+
bool expect_scale = (abs(scale_v - 1.702) < 1e-5) ? true : false;
39+
bool has_act = const_cast<Node*>(node)->AsStmt().op_info()->HasAttr(
40+
"activation_type");
41+
return (expect_bias) && (expect_scale) && (!has_act);
42+
};
43+
44+
/* _____________________
45+
/ \
46+
Create node: X----scale----sigmoid---elementwise_mul---output
47+
*/
48+
auto* x = VarNode("x")->assert_is_op_input("scale", "X")->AsInput();
49+
auto* scale = OpNode("scale", "scale")->assert_node_satisfied(scale_teller);
50+
auto* scale_out = VarNode("scale_out");
51+
auto* sigmoid = OpNode("sigmoid", "sigmoid");
52+
auto* sigmoid_out = VarNode("sigmoid_out");
53+
auto* element_mul =
54+
OpNode("elementwise_mul", "elementwise_mul")
55+
->assert_op_attr_satisfied<int>(
56+
"axis", [](int attr) { return attr == -1 || attr == 0; });
57+
auto* output = VarNode("Out")->AsOutput();
58+
59+
// Construct the topological structure for scale-sigmoid-elementwise_mul
60+
*x >> *scale >> *scale_out >> *sigmoid >> *sigmoid_out;
61+
std::vector<PMNode*> element_mul_inputs{x, sigmoid_out};
62+
element_mul_inputs >> *element_mul >> *output;
63+
64+
// Some op specialities.
65+
scale->AsIntermediate();
66+
scale_out->AsIntermediate();
67+
sigmoid->AsIntermediate();
68+
sigmoid_out->AsIntermediate();
69+
element_mul->AsIntermediate();
70+
}
71+
72+
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) {
73+
auto op_desc = *matched.at("scale")->stmt()->op_info();
74+
float scale_val = op_desc.GetAttr<float>("scale");
75+
op_desc.mutable_inputs()->clear();
76+
op_desc.mutable_outputs()->clear();
77+
op_desc.SetType("__xpu__quick_gelu");
78+
op_desc.SetInput("X", {matched.at("x")->arg()->name});
79+
op_desc.SetOutput("Out", {matched.at("Out")->arg()->name});
80+
op_desc.SetAttr("scale", scale_val);
81+
return op_desc;
82+
}
83+
84+
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
85+
// get op_desc for gelu op.
86+
auto op_desc = GenOpDesc(matched);
87+
// Create gelu op.
88+
auto gelu_op = LiteOpRegistry::Global().Create("__xpu__quick_gelu");
89+
90+
// find scope and valid_places of original scale op.
91+
auto scale = matched.at("scale")->stmt()->op();
92+
auto* scope = scale->scope();
93+
auto& valid_places = scale->valid_places();
94+
95+
// set gelu op's scope and valid_places which aligned with scale op.
96+
gelu_op->Attach(op_desc, scope);
97+
auto* new_op_node = graph->GraphCreateInstructNode(gelu_op, valid_places);
98+
99+
// link IO to the new op node.
100+
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
101+
IR_NODE_LINK_TO(new_op_node, matched.at("Out"));
102+
}
103+
};
104+
105+
} // namespace fusion
106+
107+
class XPUQuickGELUFusePass : public ProgramPass {
108+
public:
109+
void Apply(const std::unique_ptr<SSAGraph>& graph) override {
110+
fusion::XPUQuickGELUFuser fuser;
111+
fuser(graph.get());
112+
}
113+
};
114+
115+
} // namespace mir
116+
} // namespace lite
117+
} // namespace paddle
118+
119+
REGISTER_MIR_PASS(__xpu__quick_gelu_fuse_pass,
120+
paddle::lite::mir::XPUQuickGELUFusePass)
121+
.BindTargets({TARGET(kXPU)})
122+
.BindKernel("__xpu__quick_gelu");

lite/core/optimizer/optimizer.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ std::unique_ptr<RuntimeProgram> RunDefaultOptimizer(
201201
"__xpu__mmdnn_fuse_pass",
202202
"__xpu__bigru_fuse_pass",
203203
"__xpu__roformer_relative_pos_fuse_pass",
204+
"__xpu__quick_gelu_fuse_pass",
204205
"__xpu__multi_encoder_fuse_pass",
205206
"__xpu__embedding_with_eltwise_add_fuse_pass",
206207
"__xpu__fc_fuse_pass",

lite/kernels/xpu/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ add_kernel(__xpu__resnet50_compute_xpu XPU extra SRCS __xpu__resnet50_compute.cc
110110
add_kernel(__xpu__multi_encoder_compute_xpu XPU extra SRCS __xpu__multi_encoder_compute.cc)
111111
add_kernel(__xpu__embedding_with_eltwise_add_compute_xpu XPU extra SRCS __xpu__embedding_with_eltwise_add_compute.cc)
112112
add_kernel(__xpu__fc_compute_xpu XPU extra SRCS __xpu__fc_compute.cc)
113+
add_kernel(__xpu__quick_gelu_compute_xpu XPU extra SRCS __xpu__quick_gelu_compute.cc)
113114
add_kernel(__xpu__search_attention_compute_xpu XPU extra SRCS __xpu__search_attention_compute.cc)
114115
add_kernel(__xpu__search_attention_2_compute_xpu XPU extra SRCS __xpu__search_attention_2_compute.cc)
115116
add_kernel(__xpu__mmdnn_compute_xpu XPU extra SRCS __xpu__mmdnn_compute.cc)

lite/kernels/xpu/__xpu__multi_encoder_compute.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,8 @@ void XPUMultiEncoderCompute::PrepareForRun() {
196196
// prepare act_type
197197
if (param.act_type == "gelu") {
198198
qkv_act = xdnn::Activation_t::GELU;
199+
} else if (param.act_type == "__xpu__quick_gelu") {
200+
qkv_act = xdnn::Activation_t::QUICK_GELU;
199201
} else if (param.act_type != "relu") {
200202
CHECK(false) << "Invalid QKV Activation Type: " << param.act_type;
201203
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lite/kernels/xpu/__xpu__quick_gelu_compute.h"
16+
#include "lite/backends/xpu/xpu_header_sitter.h"
17+
#include "lite/core/op_registry.h"
18+
19+
namespace paddle {
20+
namespace lite {
21+
namespace kernels {
22+
namespace xpu {
23+
24+
template <typename T, PrecisionType PType>
25+
void QuickGeluCompute<T, PType>::Run() {
26+
auto& param = this->template Param<param_t>();
27+
auto& ctx = this->ctx_->template As<XPUContext>();
28+
29+
int r = xdnn::quick_gelu(ctx.GetRawContext(),
30+
param.X->template data<T>(),
31+
param.Out->template mutable_data<T>(TARGET(kXPU)),
32+
param.X->numel());
33+
CHECK_EQ(r, 0);
34+
}
35+
36+
} // namespace xpu
37+
} // namespace kernels
38+
} // namespace lite
39+
} // namespace paddle
40+
41+
using quick_gelu_FP32 =
42+
paddle::lite::kernels::xpu::QuickGeluCompute<float, PRECISION(kFloat)>;
43+
using qucik_gelu_FP16 =
44+
paddle::lite::kernels::xpu::QuickGeluCompute<float16, PRECISION(kFP16)>;
45+
REGISTER_LITE_KERNEL(
46+
__xpu__quick_gelu, kXPU, kFloat, kNCHW, quick_gelu_FP32, def)
47+
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
48+
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
49+
.Finalize();
50+
REGISTER_LITE_KERNEL(
51+
__xpu__quick_gelu, kXPU, kFP16, kNCHW, qucik_gelu_FP16, qucik_gelu_FP16)
52+
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
53+
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
54+
.Finalize();
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
#include "lite/core/kernel.h"
17+
18+
namespace paddle {
19+
namespace lite {
20+
namespace kernels {
21+
namespace xpu {
22+
23+
template <typename T, PrecisionType PType>
24+
class QuickGeluCompute : public KernelLite<TARGET(kXPU), PType> {
25+
public:
26+
using param_t = operators::XPUQuickGeluParam;
27+
28+
virtual void Run();
29+
30+
virtual ~QuickGeluCompute() = default;
31+
};
32+
33+
} // namespace xpu
34+
} // namespace kernels
35+
} // namespace lite
36+
} // namespace paddle

lite/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ add_operator(__xpu__softmax_topk_op extra SRCS __xpu__softmax_topk_op.cc)
233233
add_operator(__xpu__multi_encoder_op extra SRCS __xpu__multi_encoder_op.cc)
234234
add_operator(__xpu__embedding_with_eltwise_add_op extra SRCS __xpu__embedding_with_eltwise_add_op.cc)
235235
add_operator(__xpu__fc_op extra SRCS __xpu__fc_op.cc)
236+
add_operator(__xpu__quick_gelu_op extra SRCS __xpu__quick_gelu_op.cc)
236237
add_operator(__xpu__roformer_relative_embedding_op extra SRCS __xpu__roformer_relative_embedding_op.cc)
237238
add_operator(__xpu__search_attention_op extra SRCS __xpu__search_attention_op.cc)
238239
add_operator(__xpu__mmdnn_op extra SRCS __xpu__mmdnn_op.cc)

0 commit comments

Comments
 (0)