Skip to content

Commit 24bf512

Browse files
LLee233 authored and co63oc committed
[PIR][oneDNN] Add matmul_transpose_reshape_fuse_pass (PaddlePaddle#63151)
1 parent 9f582f3 commit 24bf512

5 files changed

Lines changed: 468 additions & 0 deletions

File tree

paddle/fluid/inference/api/paddle_pass_builder.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,7 @@ const std::vector<std::string> kPirMkldnnPasses{
627627
"conv3d_bias_fuse_pass",
628628
"batch_norm_act_fuse_pass",
629629
"reshape_transpose_matmul_fuse_pass",
630+
"matmul_transpose_reshape_fuse_pass",
630631
"matmul_elementwise_add_fuse_pass",
631632
"matmul_activation_fuse_pass",
632633
"conv_elementwise_add_onednn_fuse_pass"};
Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/pir/transforms/onednn/matmul_transpose_reshape_fuse_pass.h"
16+
17+
#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h"
18+
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
19+
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
20+
21+
#include "paddle/pir/include/pass/pass.h"
22+
#include "paddle/pir/include/pass/pass_registry.h"
23+
24+
namespace {
25+
class MatmulTransposeReshapeFusePattern : public paddle::drr::DrrPatternBase {
26+
private:
27+
std::string matmul_name_;
28+
std::string fused_matmul_name_;
29+
uint32_t benefit_;
30+
31+
public:
32+
MatmulTransposeReshapeFusePattern(const std::string &matmul_name,
33+
const std::string &fused_matmul_name,
34+
uint32_t benefit)
35+
: matmul_name_(matmul_name),
36+
fused_matmul_name_(fused_matmul_name),
37+
benefit_(benefit) {}
38+
39+
std::string name() const override {
40+
return "MatmulTransposeReshapeFusePattern";
41+
}
42+
43+
uint32_t benefit() const override { return benefit_; }
44+
45+
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
46+
paddle::drr::SourcePattern pat = ctx->SourcePattern();
47+
48+
const auto &matmul = pat.Op(matmul_name_,
49+
{{"transpose_x", pat.Attr("transpose_x")},
50+
{"transpose_y", pat.Attr("transpose_y")}});
51+
matmul({&pat.Tensor("X"), &pat.Tensor("Y")}, {&pat.Tensor("Out")});
52+
53+
const auto &transpose = pat.Op(paddle::dialect::TransposeOp::name(),
54+
{{"perm", pat.Attr("perm")}});
55+
pat.Tensor("transpose_out") = transpose(pat.Tensor("Out"));
56+
57+
const auto &full_int_array = pat.Op(paddle::dialect::FullIntArrayOp::name(),
58+
{{"value", pat.Attr("int_array")}});
59+
pat.Tensor("shape") = full_int_array();
60+
61+
const auto &reshape = pat.Op(paddle::dialect::ReshapeOp::name());
62+
reshape({&pat.Tensor("transpose_out"), &pat.Tensor("shape")},
63+
{&pat.Tensor("reshape_out"), &pat.Tensor("Xshape")});
64+
65+
pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) {
66+
std::set<bool> bool_sets = {true, false};
67+
auto result_x = match_ctx.Attr<bool>("transpose_x");
68+
auto result_y = match_ctx.Attr<bool>("transpose_y");
69+
if (bool_sets.count(result_x) == 0 || bool_sets.count(result_y) == 0) {
70+
return false;
71+
}
72+
return true;
73+
});
74+
75+
pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) {
76+
auto shape = match_ctx.Attr<std::vector<int64_t>>("int_array");
77+
auto perm = match_ctx.Attr<std::vector<int>>("perm");
78+
const std::vector<int> supported_axis{0, 2, 1, 3};
79+
if (perm != supported_axis) return false;
80+
if (shape.size() != 3) return false;
81+
if (std::count(shape.begin(), shape.end(), -1) > 1) return false;
82+
return true;
83+
});
84+
85+
paddle::drr::ResultPattern res = pat.ResultPattern();
86+
87+
std::unordered_map<std::string, paddle::drr::Attribute> fused_attrs{
88+
{"trans_x", pat.Attr("transpose_x")},
89+
{"trans_y", pat.Attr("transpose_y")},
90+
{"matmul_alpha", res.Float32Attr(1.0f)},
91+
{"fuse_activation", res.StrAttr("")},
92+
{"fuse_alpha", res.Float32Attr(0.0f)},
93+
{"fuse_beta", res.Float32Attr(0.0f)},
94+
{"fused_output_scale", res.Float32Attr(1.0f)},
95+
{"fused_reshape_x", res.VectorInt32Attr({})},
96+
{"fused_transpose_x", res.VectorInt32Attr({})},
97+
{"fused_reshape_y", res.VectorInt32Attr({})},
98+
{"fused_transpose_y", res.VectorInt32Attr({})},
99+
{"mkldnn_data_type", res.StrAttr("float32")},
100+
{"scale_x", res.Float32Attr(1.0f)},
101+
{"scale_y", res.Float32Attr(1.0f)},
102+
{"scale_in_eltwise", res.Float32Attr(0.0f)},
103+
{"scale_out", res.Float32Attr(1.0f)},
104+
{"force_fp32_output", res.BoolAttr(false)}};
105+
106+
const auto &fused_reshape_attr = res.ComputeAttr(
107+
[](const paddle::drr::MatchContext &match_ctx) -> std::vector<int> {
108+
std::vector<int> int_array_value;
109+
auto shape = match_ctx.Attr<std::vector<int64_t>>("int_array");
110+
for (auto i : shape) {
111+
int_array_value.emplace_back(static_cast<int>(i));
112+
}
113+
return int_array_value;
114+
});
115+
116+
fused_attrs.emplace("fused_reshape_out", fused_reshape_attr);
117+
fused_attrs.emplace("fused_transpose_out", pat.Attr("perm"));
118+
119+
const auto &fused_matmul = res.Op(fused_matmul_name_, fused_attrs);
120+
121+
fused_matmul({&res.Tensor("X"), &res.Tensor("Y"), &res.InputNoneTensor()},
122+
{&res.Tensor("reshape_out")});
123+
}
124+
};
125+
126+
class FusedMatmulTransposeReshapeFusePattern
127+
: public paddle::drr::DrrPatternBase {
128+
private:
129+
std::string matmul_name_;
130+
std::string fused_matmul_name_;
131+
uint32_t benefit_;
132+
133+
public:
134+
FusedMatmulTransposeReshapeFusePattern(const std::string &matmul_name,
135+
const std::string &fused_matmul_name,
136+
uint32_t benefit)
137+
: matmul_name_(matmul_name),
138+
fused_matmul_name_(fused_matmul_name),
139+
benefit_(benefit) {}
140+
141+
std::string name() const override {
142+
return "FusedMatmulTransposeReshapeFusePattern";
143+
}
144+
145+
uint32_t benefit() const override { return benefit_; }
146+
147+
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
148+
paddle::drr::SourcePattern pat = ctx->SourcePattern();
149+
150+
const auto &matmul =
151+
pat.Op(matmul_name_,
152+
{{"trans_x", pat.Attr("transpose_x")},
153+
{"trans_y", pat.Attr("transpose_y")},
154+
{"matmul_alpha", pat.Attr("matmul_alpha")},
155+
{"fuse_activation", pat.Attr("fuse_activation")},
156+
{"fuse_alpha", pat.Attr("fuse_alpha")},
157+
{"fuse_beta", pat.Attr("fuse_beta")},
158+
{"fused_output_scale", pat.Attr("fused_output_scale")},
159+
{"fused_reshape_x", pat.Attr("fused_reshape_x")},
160+
{"fused_transpose_x", pat.Attr("fused_transpose_x")},
161+
{"fused_reshape_y", pat.Attr("fused_reshape_y")},
162+
{"fused_transpose_y", pat.Attr("fused_transpose_y")},
163+
{"fused_reshape_out", pat.Attr("fused_reshape_out")},
164+
{"fused_transpose_out", pat.Attr("fused_transpose_out")},
165+
{"mkldnn_data_type", pat.Attr("mkldnn_data_type")},
166+
{"scale_x", pat.Attr("scale_x")},
167+
{"scale_y", pat.Attr("scale_y")},
168+
{"scale_in_eltwise", pat.Attr("scale_in_eltwise")},
169+
{"scale_out", pat.Attr("scale_out")},
170+
{"force_fp32_output", pat.Attr("force_fp32_output")}});
171+
172+
matmul({&pat.Tensor("X"), &pat.Tensor("Y"), &pat.Tensor("residual")},
173+
{&pat.Tensor("Out")});
174+
175+
const auto &transpose = pat.Op(paddle::dialect::TransposeOp::name(),
176+
{{"perm", pat.Attr("perm")}});
177+
pat.Tensor("transpose_out") = transpose(pat.Tensor("Out"));
178+
179+
const auto &full_int_array = pat.Op(paddle::dialect::FullIntArrayOp::name(),
180+
{{"value", pat.Attr("int_array")}});
181+
pat.Tensor("shape") = full_int_array();
182+
183+
const auto &reshape = pat.Op(paddle::dialect::ReshapeOp::name());
184+
reshape({&pat.Tensor("transpose_out"), &pat.Tensor("shape")},
185+
{&pat.Tensor("reshape_out"), &pat.Tensor("Xshape")});
186+
187+
pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) {
188+
auto shape = match_ctx.Attr<std::vector<int64_t>>("int_array");
189+
auto perm = match_ctx.Attr<std::vector<int>>("perm");
190+
const std::vector<int> supported_axis{0, 2, 1, 3};
191+
if (perm != supported_axis) return false;
192+
if (shape.size() != 3) return false;
193+
if (std::count(shape.begin(), shape.end(), -1) > 1) return false;
194+
195+
return true;
196+
});
197+
198+
pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) {
199+
if (!(match_ctx.Attr<std::vector<int>>("fused_reshape_out").empty()))
200+
return false;
201+
return true;
202+
});
203+
204+
paddle::drr::ResultPattern res = pat.ResultPattern();
205+
206+
std::unordered_map<std::string, paddle::drr::Attribute> fused_attrs{
207+
{"trans_x", pat.Attr("transpose_x")},
208+
{"trans_y", pat.Attr("transpose_y")},
209+
{"matmul_alpha", pat.Attr("matmul_alpha")},
210+
{"fuse_activation", pat.Attr("fuse_activation")},
211+
{"fuse_alpha", pat.Attr("fuse_alpha")},
212+
{"fuse_beta", pat.Attr("fuse_beta")},
213+
{"fused_output_scale", pat.Attr("fused_output_scale")},
214+
{"fused_reshape_x", pat.Attr("fused_reshape_x")},
215+
{"fused_transpose_x", pat.Attr("fused_transpose_x")},
216+
{"fused_reshape_y", pat.Attr("fused_reshape_y")},
217+
{"fused_transpose_y", pat.Attr("fused_transpose_y")},
218+
{"mkldnn_data_type", pat.Attr("mkldnn_data_type")},
219+
{"scale_x", pat.Attr("scale_x")},
220+
{"scale_y", pat.Attr("scale_y")},
221+
{"scale_in_eltwise", pat.Attr("scale_in_eltwise")},
222+
{"scale_out", pat.Attr("scale_out")},
223+
{"force_fp32_output", pat.Attr("force_fp32_output")}};
224+
225+
const auto &fused_reshape_attr = res.ComputeAttr(
226+
[](const paddle::drr::MatchContext &match_ctx) -> std::vector<int> {
227+
std::vector<int> int_array_value;
228+
auto shape = match_ctx.Attr<std::vector<int64_t>>("int_array");
229+
for (auto i : shape) {
230+
int_array_value.emplace_back(static_cast<int>(i));
231+
}
232+
return int_array_value;
233+
});
234+
235+
fused_attrs.emplace("fused_reshape_out", fused_reshape_attr);
236+
fused_attrs.emplace("fused_transpose_out", pat.Attr("perm"));
237+
238+
const auto &fused_matmul = res.Op(fused_matmul_name_, fused_attrs);
239+
240+
fused_matmul({&res.Tensor("X"), &res.Tensor("Y"), &res.Tensor("residual")},
241+
{&res.Tensor("reshape_out")});
242+
}
243+
};
244+
245+
class MatmulTransposeReshapeFusePass : public pir::PatternRewritePass {
246+
public:
247+
MatmulTransposeReshapeFusePass()
248+
: pir::PatternRewritePass("matmul_transpose_reshape_fuse_pass", 3) {}
249+
250+
pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
251+
pir::RewritePatternSet ps(context);
252+
std::vector<bool> bool_set = {false, true};
253+
int benefit_idx = 1;
254+
ps.Add(paddle::drr::Create<MatmulTransposeReshapeFusePattern>(
255+
context,
256+
paddle::dialect::MatmulOp::name(),
257+
paddle::onednn::dialect::FusedMatmulOp::name(),
258+
benefit_idx++));
259+
260+
ps.Add(paddle::drr::Create<FusedMatmulTransposeReshapeFusePattern>(
261+
context,
262+
paddle::onednn::dialect::FusedMatmulOp::name(),
263+
paddle::onednn::dialect::FusedMatmulOp::name(),
264+
benefit_idx++));
265+
return ps;
266+
}
267+
};
268+
269+
} // namespace
270+
271+
namespace pir {
272+
273+
std::unique_ptr<Pass> CreateMatmulTransposeReshapeFusePass() {
274+
// pd_op.matmul + pd_op.transpose + pd_op.reshape -> onednn_op.fused_matmul
275+
// pd_op.fused_matmul + pd_op.transpose + pd_op.reshape ->
276+
// onednn_op.fused_matmul
277+
return std::make_unique<MatmulTransposeReshapeFusePass>();
278+
}
279+
} // namespace pir
280+
281+
REGISTER_IR_PASS(matmul_transpose_reshape_fuse_pass,
282+
MatmulTransposeReshapeFusePass);
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <memory>
18+
#include "paddle/pir/include/core/dll_decl.h"
19+
20+
namespace pir {
21+
22+
class Pass;
23+
24+
IR_API std::unique_ptr<Pass> CreateMatmulTransposeReshapeFusePass();
25+
26+
} // namespace pir

paddle/fluid/pir/transforms/passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ USE_PIR_PASS(conv2d_bias_fuse_pass);
4747
USE_PIR_PASS(conv2d_transpose_bias_fuse_pass);
4848
USE_PIR_PASS(conv3d_bias_fuse_pass);
4949
USE_PIR_PASS(reshape_transpose_matmul_fuse_pass);
50+
USE_PIR_PASS(matmul_transpose_reshape_fuse_pass);
5051
USE_PIR_PASS(matmul_elementwise_add_fuse_pass);
5152
USE_PIR_PASS(matmul_activation_fuse_pass);
5253
USE_PIR_PASS(conv_elementwise_add_onednn_fuse_pass);

0 commit comments

Comments
 (0)