Skip to content

Commit 45b5712

Browse files
author
shenyijun01
committed
[Optimizer]: add quick gelu fusion pass for ViT model.
1 parent 382489d commit 45b5712

12 files changed

Lines changed: 324 additions & 2 deletions

lite/api/paddle_use_passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ USE_MIR_PASS(__xpu__resnet_fuse_pass);
8181
USE_MIR_PASS(__xpu__multi_encoder_fuse_pass);
8282
USE_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass);
8383
USE_MIR_PASS(__xpu__fc_fuse_pass);
84+
USE_MIR_PASS(__xpu__quick_gelu_fuse_pass);
8485
USE_MIR_PASS(__xpu__mmdnn_fuse_pass);
8586
USE_MIR_PASS(__xpu__conv2d_affine_channel_fuse_pass);
8687
USE_MIR_PASS(__xpu__conv2d_fuse_pass);

lite/core/optimizer/mir/fusion/__xpu__fc_fuse_pass.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,8 @@ class XPUFcFuser : public FuseBase {
199199
{"leaky_relu", 5},
200200
{"hard_swish", 14},
201201
{"hard_sigmoid", 15},
202-
{"relu6", 17}};
202+
{"relu6", 17},
203+
{"__xpu__quick_gelu", 19}};
203204

204205
float act_param_ = 0.0f;
205206
if (act_type_ == "leaky_relu") {
@@ -281,6 +282,7 @@ class XPUFcFusePass : public ProgramPass {
281282
for (auto with_bias : {true, false}) {
282283
for (auto act_type : {"relu",
283284
"gelu",
285+
"__xpu__quick_gelu",
284286
/*"sigmoid",
285287
"tanh",
286288
"leaky_relu",
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <memory>
16+
#include <string>
17+
#include "lite/backends/xpu/math.h"
18+
#include "lite/core/optimizer/mir/pass_registry.h"
19+
#include "lite/core/optimizer/mir/pattern_matcher_high_api.h"
20+
21+
namespace paddle {
22+
namespace lite {
23+
namespace mir {
24+
namespace fusion {
25+
26+
class XPUQuickGELUFuser : public FuseBase {
27+
public:
28+
explicit XPUQuickGELUFuser(const std::string& op_type,
29+
const std::string& act_type) {
30+
op_type_ = op_type;
31+
act_type_ = act_type;
32+
}
33+
34+
void BuildPattern() override {
35+
auto scale_teller = [](const Node* node) -> bool {
36+
bool bias_after_scale =
37+
const_cast<Node*>(node)->AsStmt().op_info()->GetAttr<bool>(
38+
"bias_after_scale");
39+
bool has_act = const_cast<Node*>(node)->AsStmt().op_info()->HasAttr(
40+
"activation_type");
41+
return bias_after_scale && (!has_act);
42+
};
43+
44+
/* _____________________
45+
/ \
46+
Create node: X----scale----sigmoid---elementwise_mul---output
47+
*/
48+
auto* x = VarNode("x")->assert_is_op_input("scale", "X");
49+
auto* scale = OpNode("scale", "scale")
50+
->assert_is_op("scale")
51+
->assert_node_satisfied(scale_teller);
52+
auto* scale_out = VarNode("scale_out");
53+
auto* sigmoid = OpNode("sigmoid", act_type_);
54+
auto* sigmoid_out = VarNode("sigmoid_out");
55+
auto* element_mul =
56+
OpNode("elementwise_mul", op_type_)
57+
->assert_op_attr_satisfied<int>(
58+
"axis", [](int attr) { return attr == -1 || attr == 0; });
59+
auto* output = VarNode("Out");
60+
61+
// Construct the topological structure for scale-sigmoid-elementwise_mul
62+
*x >> *scale >> *scale_out >> *sigmoid >> *sigmoid_out;
63+
std::vector<PMNode*> element_mul_inputs{x, sigmoid_out};
64+
element_mul_inputs >> *element_mul >> *output;
65+
66+
// Some op specialities.
67+
scale->AsIntermediate();
68+
scale_out->AsIntermediate();
69+
sigmoid->AsIntermediate();
70+
sigmoid_out->AsIntermediate();
71+
element_mul->AsIntermediate();
72+
}
73+
74+
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) {
75+
auto op_desc = *matched.at("scale")->stmt()->op_info();
76+
float scale_val = op_desc.GetAttr<float>("scale");
77+
op_desc.mutable_inputs()->clear();
78+
op_desc.mutable_outputs()->clear();
79+
op_desc.SetType("__xpu__quick_gelu");
80+
op_desc.SetInput("X", {matched.at("x")->arg()->name});
81+
op_desc.SetOutput("Out", {matched.at("Out")->arg()->name});
82+
op_desc.SetAttr("scale", scale_val);
83+
return op_desc;
84+
}
85+
86+
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
87+
// get op_desc for gelu op.
88+
auto op_desc = GenOpDesc(matched);
89+
// Create gelu op.
90+
auto gelu_op = LiteOpRegistry::Global().Create("__xpu__quick_gelu");
91+
92+
// find scope and valid_places of original scale op.
93+
auto scale = matched.at("scale")->stmt()->op();
94+
auto* scope = scale->scope();
95+
auto& valid_places = scale->valid_places();
96+
97+
// set gelu op's scope and valid_places which aligned with scale op.
98+
gelu_op->Attach(op_desc, scope);
99+
auto* new_op_node = graph->GraphCreateInstructNode(gelu_op, valid_places);
100+
101+
// link IO to the new op node.
102+
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
103+
IR_NODE_LINK_TO(new_op_node, matched.at("Out"));
104+
}
105+
106+
private:
107+
std::string op_type_;
108+
std::string act_type_;
109+
};
110+
111+
} // namespace fusion
112+
113+
class XPUQuickGELUFusePass : public ProgramPass {
114+
public:
115+
void Apply(const std::unique_ptr<SSAGraph>& graph) override {
116+
fusion::XPUQuickGELUFuser fuser("elementwise_mul", "sigmoid");
117+
fuser(graph.get());
118+
}
119+
};
120+
121+
} // namespace mir
122+
} // namespace lite
123+
} // namespace paddle
124+
125+
REGISTER_MIR_PASS(__xpu__quick_gelu_fuse_pass,
126+
paddle::lite::mir::XPUQuickGELUFusePass)
127+
.BindTargets({TARGET(kXPU)})
128+
.BindKernel("__xpu__quick_gelu");

lite/core/optimizer/optimizer.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ std::unique_ptr<RuntimeProgram> RunDefaultOptimizer(
200200
"__xpu__mmdnn_fuse_pass",
201201
"__xpu__bigru_fuse_pass",
202202
"__xpu__roformer_relative_pos_fuse_pass",
203+
"__xpu__quick_gelu_fuse_pass",
203204
"__xpu__multi_encoder_fuse_pass",
204205
"__xpu__embedding_with_eltwise_add_fuse_pass",
205206
"__xpu__fc_fuse_pass",

lite/kernels/xpu/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ add_kernel(__xpu__resnet50_compute_xpu XPU extra SRCS __xpu__resnet50_compute.cc
110110
add_kernel(__xpu__multi_encoder_compute_xpu XPU extra SRCS __xpu__multi_encoder_compute.cc)
111111
add_kernel(__xpu__embedding_with_eltwise_add_compute_xpu XPU extra SRCS __xpu__embedding_with_eltwise_add_compute.cc)
112112
add_kernel(__xpu__fc_compute_xpu XPU extra SRCS __xpu__fc_compute.cc)
113+
add_kernel(__xpu__quick_gelu_compute_xpu XPU extra SRCS __xpu__quick_gelu_compute.cc)
113114
add_kernel(__xpu__search_attention_compute_xpu XPU extra SRCS __xpu__search_attention_compute.cc)
114115
add_kernel(__xpu__search_attention_2_compute_xpu XPU extra SRCS __xpu__search_attention_2_compute.cc)
115116
add_kernel(__xpu__mmdnn_compute_xpu XPU extra SRCS __xpu__mmdnn_compute.cc)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lite/kernels/xpu/__xpu__quick_gelu_compute.h"
16+
#include "lite/backends/xpu/xpu_header_sitter.h"
17+
#include "lite/core/op_registry.h"
18+
19+
namespace paddle {
20+
namespace lite {
21+
namespace kernels {
22+
namespace xpu {
23+
24+
template <typename T, PrecisionType PType>
25+
void QuickGeluCompute<T, PType>::Run() {
26+
auto& param = this->template Param<param_t>();
27+
auto& ctx = this->ctx_->template As<XPUContext>();
28+
29+
int r = xdnn::quick_gelu(ctx.GetRawContext(),
30+
param.X->template data<T>(),
31+
param.Out->template mutable_data<T>(TARGET(kXPU)),
32+
param.X->numel());
33+
CHECK_EQ(r, 0);
34+
}
35+
36+
} // namespace xpu
37+
} // namespace kernels
38+
} // namespace lite
39+
} // namespace paddle
40+
41+
using quick_gelu_FP32 =
42+
paddle::lite::kernels::xpu::QuickGeluCompute<float, PRECISION(kFloat)>;
43+
using qucik_gelu_FP16 =
44+
paddle::lite::kernels::xpu::QuickGeluCompute<float16, PRECISION(kFP16)>;
45+
REGISTER_LITE_KERNEL(quick_gelu, kXPU, kFloat, kNCHW, quick_gelu_FP32, def)
46+
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
47+
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
48+
.Finalize();
49+
REGISTER_LITE_KERNEL(
50+
quick_gelu, kXPU, kFP16, kNCHW, qucik_gelu_FP16, qucik_gelu_FP16)
51+
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
52+
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
53+
.Finalize();
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
#include "lite/core/kernel.h"
17+
18+
namespace paddle {
19+
namespace lite {
20+
namespace kernels {
21+
namespace xpu {
22+
23+
template <typename T, PrecisionType PType>
24+
class QuickGeluCompute : public KernelLite<TARGET(kXPU), PType> {
25+
public:
26+
using param_t = operators::XPUQuickGeluParam;
27+
28+
virtual void Run();
29+
30+
virtual ~QuickGeluCompute() = default;
31+
};
32+
33+
} // namespace xpu
34+
} // namespace kernels
35+
} // namespace lite
36+
} // namespace paddle

lite/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ add_operator(__xpu__softmax_topk_op extra SRCS __xpu__softmax_topk_op.cc)
233233
add_operator(__xpu__multi_encoder_op extra SRCS __xpu__multi_encoder_op.cc)
234234
add_operator(__xpu__embedding_with_eltwise_add_op extra SRCS __xpu__embedding_with_eltwise_add_op.cc)
235235
add_operator(__xpu__fc_op extra SRCS __xpu__fc_op.cc)
236+
add_operator(__xpu__quick_gelu_op extra SRCS __xpu__quick_gelu_op.cc)
236237
add_operator(__xpu__roformer_relative_embedding_op extra SRCS __xpu__roformer_relative_embedding_op.cc)
237238
add_operator(__xpu__search_attention_op extra SRCS __xpu__search_attention_op.cc)
238239
add_operator(__xpu__mmdnn_op extra SRCS __xpu__mmdnn_op.cc)
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lite/operators/__xpu__quick_gelu_op.h"
16+
#include "lite/core/op_registry.h"
17+
18+
namespace paddle {
19+
namespace lite {
20+
namespace operators {
21+
22+
bool XPUQuickGeluOp::CheckShape() const {
  // Both tensors must have been attached before shape inference.
  CHECK_OR_FALSE(param_.X);
  CHECK_OR_FALSE(param_.Out);
  return true;
}
27+
28+
bool XPUQuickGeluOp::InferShapeImpl() const {
  // Elementwise activation: the output mirrors the input's dims and LoD.
  const auto& input = *param_.X;
  param_.Out->Resize(input.dims());
  *param_.Out->mutable_lod() = input.lod();
  return true;
}
34+
35+
bool XPUQuickGeluOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
  // Single input "X", single output "Out". Copy the names out of the
  // temporary vectors returned by Input()/Output() (a reference would dangle).
  auto input_name = opdesc.Input("X").front();
  auto output_name = opdesc.Output("Out").front();
  param_.X = scope->FindVar(input_name)->GetMutable<lite::Tensor>();
  param_.Out = scope->FindVar(output_name)->GetMutable<lite::Tensor>();
  return true;
}
42+
43+
} // namespace operators
44+
} // namespace lite
45+
} // namespace paddle
46+
47+
REGISTER_LITE_OP(__xpu__quick_gelu, paddle::lite::operators::XPUQuickGeluOp);
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
#include <string>
17+
#include "lite/core/op_lite.h"
18+
#ifdef LITE_WITH_PROFILE
19+
#include "lite/api/paddle_place.h"
20+
#endif
21+
22+
namespace paddle {
23+
namespace lite {
24+
namespace operators {
25+
26+
class XPUQuickGeluOp : public OpLite {
27+
public:
28+
explicit XPUQuickGeluOp(const std::string& type) : OpLite(type) {}
29+
30+
bool CheckShape() const override;
31+
32+
bool InferShapeImpl() const override;
33+
34+
bool InferType() override { return true; }
35+
36+
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
37+
38+
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
39+
40+
std::string DebugString() const override { return "XPUQuickGelu"; }
41+
42+
private:
43+
mutable operators::XPUQuickGeluParam param_;
44+
};
45+
46+
} // namespace operators
47+
} // namespace lite
48+
} // namespace paddle

0 commit comments

Comments
 (0)