Skip to content

Commit ea846ed

Browse files
committed
support drr result pattern many simple attr
1 parent 0f7ec3c commit ea846ed

16 files changed

Lines changed: 880 additions & 367 deletions

paddle/fluid/pir/drr/README.md

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,14 +182,10 @@ class FusedLinearPattern : public paddle::drr::DrrPatternBase {
182182
// Define ResultPattern
183183
paddle::drr::ResultPattern res = pat.ResultPattern();
184184
// Define Constraint
185-
const auto &act_attr =
186-
res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any {
187-
return "none";
188-
});
189185
const auto &fused_gemm_epilogue = res.Op(paddle::dialect::FusedGemmEpilogueOp::name(),
190186
{{{"trans_x", pat.Attr("trans_x")},
191187
{"trans_y", pat.Attr("trans_y")},
192-
{"activation", act_attr}}});
188+
{"activation", res.StrAttr("none")}}});
193189
fused_gemm_epilogue(
194190
{&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")},
195191
{&res.Tensor("out")});

paddle/fluid/pir/drr/README_cn.md

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,14 +185,10 @@ class FusedLinearPattern : public paddle::drr::DrrPatternBase {
185185
// 定义 Result Pattern
186186
paddle::drr::ResultPattern res = pat.ResultPattern();
187187
// 定义 Constraint
188-
const auto &act_attr =
189-
res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any {
190-
return "none";
191-
});
192188
const auto &fused_gemm_epilogue = res.Op(paddle::dialect::FusedGemmEpilogueOp::name(),
193189
{{{"trans_x", pat.Attr("trans_x")},
194190
{"trans_y", pat.Attr("trans_y")},
195-
{"activation", act_attr}}});
191+
{"activation", res.StrAttr("none")}}});
196192
fused_gemm_epilogue(
197193
{&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")},
198194
{&res.Tensor("out")});

paddle/fluid/pir/drr/include/drr_pattern_context.h

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#pragma once
1616

1717
#include <any>
18+
#include <cstdint>
1819
#include <functional>
1920
#include <memory>
2021
#include <string>
@@ -280,10 +281,46 @@ class ResultPattern {
280281
return ctx_->ResultTensorPattern(Tensor::NONE_TENSOR_NAME);
281282
}
282283

283-
Attribute Attr(const std::string& attr_name) const {
284-
return NormalAttribute(attr_name);
284+
Attribute StrAttr(const std::string& value) const {
285+
return ComputeAttr(
286+
[&](const MatchContext& match_ctx) -> std::string { return value; });
287+
}
288+
289+
Attribute BoolAttr(bool value) const {
290+
return ComputeAttr(
291+
[&](const MatchContext& match_ctx) -> bool { return value; });
292+
}
293+
294+
Attribute Int32Attr(int32_t value) const {
295+
return ComputeAttr(
296+
[&](const MatchContext& match_ctx) -> int32_t { return value; });
297+
}
298+
299+
Attribute Int64Attr(int64_t value) const {
300+
return ComputeAttr(
301+
[&](const MatchContext& match_ctx) -> int64_t { return value; });
285302
}
286-
Attribute Attr(const AttrComputeFunc& attr_compute_func) const {
303+
304+
Attribute Float32Attr(float value) const {
305+
return ComputeAttr(
306+
[&](const MatchContext& match_ctx) -> float { return value; });
307+
}
308+
309+
Attribute VectorInt64Attr(const std::vector<int64_t>& value) const {
310+
return ComputeAttr(
311+
[&](const MatchContext& match_ctx) -> std::vector<int64_t> {
312+
return value;
313+
});
314+
}
315+
316+
Attribute VectorInt32Attr(const std::vector<int32_t>& value) const {
317+
return ComputeAttr(
318+
[&](const MatchContext& match_ctx) -> std::vector<int32_t> {
319+
return value;
320+
});
321+
}
322+
323+
Attribute ComputeAttr(const AttrComputeFunc& attr_compute_func) const {
287324
return ComputeAttribute(attr_compute_func);
288325
}
289326

paddle/fluid/pir/drr/ir_operation_factory.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
2020
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
2121
#include "paddle/fluid/pir/drr/attr_type_uilts.h"
22+
#include "paddle/fluid/pir/drr/include/drr_pattern_context.h"
2223
#include "paddle/phi/core/enforce.h"
2324
#include "paddle/pir/core/builtin_op.h"
2425
#include "paddle/pir/core/operation.h"

paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc

Lines changed: 30 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ class MultiHeadMatmulFuseNoBiasQKPattern : public paddle::drr::DrrPatternBase {
170170
paddle::drr::ResultPattern res = src.ResultPattern();
171171

172172
// W reshape.
173-
const auto &reshape_w_shape_attr = res.Attr(
173+
const auto &reshape_w_shape_attr = res.ComputeAttr(
174174
[](const paddle::drr::MatchContext &match_ctx) -> std::vector<int64_t> {
175175
auto matmul_1_in_2 =
176176
pir::GetShapeFromValue(match_ctx.Tensor("matmul_1_in_2"));
@@ -195,14 +195,12 @@ class MultiHeadMatmulFuseNoBiasQKPattern : public paddle::drr::DrrPatternBase {
195195
&res.Tensor("reshape_6_out"),
196196
&res.Tensor("reshape_7_out")},
197197
{&res.Tensor("combine_1_out")});
198-
const auto &concat_1_axis_attr = res.Attr(
199-
[](const paddle::drr::MatchContext &match_ctx) -> int { return 1; });
200-
const auto &concat_1 =
201-
res.Op("pd_op.concat", {{"axis", concat_1_axis_attr}});
198+
199+
const auto &concat_1 = res.Op("pd_op.concat", {{"axis", res.Int32Attr(1)}});
202200
res.Tensor("concat_1_out") = concat_1(res.Tensor("combine_1_out"));
203201

204202
// Bias reshape.
205-
const auto &reshape_b_shape_attr = res.Attr(
203+
const auto &reshape_b_shape_attr = res.ComputeAttr(
206204
[](const paddle::drr::MatchContext &match_ctx) -> std::vector<int64_t> {
207205
auto add_1_in_2 =
208206
pir::GetShapeFromValue(match_ctx.Tensor("add_1_in_2"));
@@ -227,38 +225,26 @@ class MultiHeadMatmulFuseNoBiasQKPattern : public paddle::drr::DrrPatternBase {
227225
&res.Tensor("reshape_9_out"),
228226
&res.Tensor("reshape_10_out")},
229227
{&res.Tensor("combine_2_out")});
230-
const auto &concat_2_axis_attr = res.Attr(
231-
[](const paddle::drr::MatchContext &match_ctx) -> int { return 0; });
232-
const auto &concat_2 =
233-
res.Op("pd_op.concat", {{"axis", concat_2_axis_attr}});
228+
229+
const auto &concat_2 = res.Op("pd_op.concat", {{"axis", res.Int32Attr(0)}});
234230
res.Tensor("concat_2_out") = concat_2(res.Tensor("combine_2_out"));
235231

236232
const auto &head_number =
237-
res.Attr([](const paddle::drr::MatchContext &match_ctx) -> int {
233+
res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> int {
238234
const auto &full_int_array_1_value =
239235
match_ctx.Attr<std::vector<int64_t>>("full_int_array_1_value");
240236
return full_int_array_1_value.at(2);
241237
});
242-
const auto &alpha =
243-
res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float {
238+
const auto &alpha = res.ComputeAttr(
239+
[](const paddle::drr::MatchContext &match_ctx) -> float {
244240
return match_ctx.Attr<float>("full_1_value");
245241
});
246-
const auto &multihead_matmul =
247-
res.Op("pd_op.multihead_matmul",
248-
{{"transpose_q",
249-
res.Attr([](const paddle::drr::MatchContext &match_ctx) {
250-
return false;
251-
})},
252-
{"transpose_k",
253-
res.Attr([](const paddle::drr::MatchContext &match_ctx) {
254-
return true;
255-
})},
256-
{"transpose_v",
257-
res.Attr([](const paddle::drr::MatchContext &match_ctx) {
258-
return false;
259-
})},
260-
{"head_number", head_number},
261-
{"alpha", alpha}});
242+
const auto &multihead_matmul = res.Op("pd_op.multihead_matmul",
243+
{{"transpose_q", res.BoolAttr(false)},
244+
{"transpose_k", res.BoolAttr(true)},
245+
{"transpose_v", res.BoolAttr(false)},
246+
{"head_number", head_number},
247+
{"alpha", alpha}});
262248
multihead_matmul({&res.Tensor("matmul_1_in_1"),
263249
&res.Tensor("concat_1_out"),
264250
&res.Tensor("concat_2_out"),
@@ -423,7 +409,7 @@ class MultiHeadMatmulFuseWithBiasQKPattern
423409
paddle::drr::ResultPattern res = src.ResultPattern();
424410

425411
// W reshape.
426-
const auto &reshape_w_shape_attr = res.Attr(
412+
const auto &reshape_w_shape_attr = res.ComputeAttr(
427413
[](const paddle::drr::MatchContext &match_ctx) -> std::vector<int64_t> {
428414
auto matmul_1_in_2 =
429415
pir::GetShapeFromValue(match_ctx.Tensor("matmul_1_in_2"));
@@ -448,14 +434,12 @@ class MultiHeadMatmulFuseWithBiasQKPattern
448434
&res.Tensor("reshape_6_out"),
449435
&res.Tensor("reshape_7_out")},
450436
{&res.Tensor("combine_1_out")});
451-
const auto &concat_1_axis_attr = res.Attr(
452-
[](const paddle::drr::MatchContext &match_ctx) -> int { return 1; });
453-
const auto &concat_1 =
454-
res.Op("pd_op.concat", {{"axis", concat_1_axis_attr}});
437+
438+
const auto &concat_1 = res.Op("pd_op.concat", {{"axis", res.Int32Attr(1)}});
455439
res.Tensor("concat_1_out") = concat_1(res.Tensor("combine_1_out"));
456440

457441
// Bias reshape.
458-
const auto &reshape_b_shape_attr = res.Attr(
442+
const auto &reshape_b_shape_attr = res.ComputeAttr(
459443
[](const paddle::drr::MatchContext &match_ctx) -> std::vector<int64_t> {
460444
auto add_1_in_2 =
461445
pir::GetShapeFromValue(match_ctx.Tensor("add_1_in_2"));
@@ -480,38 +464,26 @@ class MultiHeadMatmulFuseWithBiasQKPattern
480464
&res.Tensor("reshape_9_out"),
481465
&res.Tensor("reshape_10_out")},
482466
{&res.Tensor("combine_2_out")});
483-
const auto &concat_2_axis_attr = res.Attr(
484-
[](const paddle::drr::MatchContext &match_ctx) -> int { return 0; });
485-
const auto &concat_2 =
486-
res.Op("pd_op.concat", {{"axis", concat_2_axis_attr}});
467+
468+
const auto &concat_2 = res.Op("pd_op.concat", {{"axis", res.Int32Attr(0)}});
487469
res.Tensor("concat_2_out") = concat_2(res.Tensor("combine_2_out"));
488470

489471
const auto &head_number =
490-
res.Attr([](const paddle::drr::MatchContext &match_ctx) -> int {
472+
res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> int {
491473
const auto &full_int_array_1_value =
492474
match_ctx.Attr<std::vector<int64_t>>("full_int_array_1_value");
493475
return full_int_array_1_value.at(2);
494476
});
495-
const auto &alpha =
496-
res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float {
477+
const auto &alpha = res.ComputeAttr(
478+
[](const paddle::drr::MatchContext &match_ctx) -> float {
497479
return match_ctx.Attr<float>("full_1_value");
498480
});
499-
const auto &multihead_matmul =
500-
res.Op("pd_op.multihead_matmul",
501-
{{"transpose_q",
502-
res.Attr([](const paddle::drr::MatchContext &match_ctx) {
503-
return false;
504-
})},
505-
{"transpose_k",
506-
res.Attr([](const paddle::drr::MatchContext &match_ctx) {
507-
return true;
508-
})},
509-
{"transpose_v",
510-
res.Attr([](const paddle::drr::MatchContext &match_ctx) {
511-
return false;
512-
})},
513-
{"head_number", head_number},
514-
{"alpha", alpha}});
481+
const auto &multihead_matmul = res.Op("pd_op.multihead_matmul",
482+
{{"transpose_q", res.BoolAttr(false)},
483+
{"transpose_k", res.BoolAttr(true)},
484+
{"transpose_v", res.BoolAttr(false)},
485+
{"head_number", head_number},
486+
{"alpha", alpha}});
515487
multihead_matmul({&res.Tensor("matmul_1_in_1"),
516488
&res.Tensor("concat_1_out"),
517489
&res.Tensor("concat_2_out"),

paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc

Lines changed: 15 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -41,34 +41,21 @@ class Conv2dAddFusePattern : public paddle::drr::DrrPatternBase {
4141

4242
paddle::drr::ResultPattern res = pat.ResultPattern();
4343

44-
const auto &fused_conv2d_add_act = res.Op(
45-
paddle::dialect::FusedConv2dAddActOp::name(),
46-
{{
47-
{"strides", pat.Attr("strides")},
48-
{"paddings", pat.Attr("paddings")},
49-
{"padding_algorithm", pat.Attr("padding_algorithm")},
50-
{"dilations", pat.Attr("dilations")},
51-
{"groups", pat.Attr("groups")},
52-
{"data_format", pat.Attr("data_format")},
53-
{"activation",
54-
res.Attr([](const paddle::drr::MatchContext &match_ctx)
55-
-> std::string { return "identity"; })},
56-
{"split_channels",
57-
res.Attr([](const paddle::drr::MatchContext &match_ctx)
58-
-> std::vector<int> { return {}; })},
59-
{"exhaustive_search",
60-
res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool {
61-
return false;
62-
})},
63-
{"workspace_size_MB",
64-
res.Attr([](const paddle::drr::MatchContext &match_ctx) -> int {
65-
return 32;
66-
})},
67-
{"fuse_alpha",
68-
res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float {
69-
return 0.0f;
70-
})},
71-
}});
44+
const auto &fused_conv2d_add_act =
45+
res.Op(paddle::dialect::FusedConv2dAddActOp::name(),
46+
{{
47+
{"strides", pat.Attr("strides")},
48+
{"paddings", pat.Attr("paddings")},
49+
{"padding_algorithm", pat.Attr("padding_algorithm")},
50+
{"dilations", pat.Attr("dilations")},
51+
{"groups", pat.Attr("groups")},
52+
{"data_format", pat.Attr("data_format")},
53+
{"activation", res.StrAttr("identity")},
54+
{"split_channels", res.VectorInt32Attr({})},
55+
{"exhaustive_search", res.BoolAttr(false)},
56+
{"workspace_size_MB", res.Int32Attr(32)},
57+
{"fuse_alpha", res.Float32Attr(0.0f)},
58+
}});
7259

7360
fused_conv2d_add_act({&res.Tensor("input"),
7461
&res.Tensor("filter"),

paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,8 @@ class FcElementwiseLayerNormFusePattern : public paddle::drr::DrrPatternBase {
6565

6666
paddle::drr::ResultPattern res = pat.ResultPattern();
6767

68-
const auto &x_num_col_dims_attr =
69-
res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any {
70-
return 1;
71-
});
72-
const auto &false_attr =
73-
res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool {
74-
return false;
75-
});
68+
const auto &x_num_col_dims_attr = res.Int32Attr(1);
69+
const auto &false_attr = res.BoolAttr(false);
7670

7771
const auto &fused_fc_elementwise_op =
7872
res.Op(paddle::dialect::FusedFcElementwiseLayernormOp::name(),

paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -58,24 +58,18 @@ class MatmulAddPattern : public paddle::drr::DrrPatternBase {
5858
paddle::drr::ResultPattern res = pat.ResultPattern();
5959

6060
const auto &in_num_col_dims_attr =
61-
res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any {
61+
res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> int {
6262
auto x_dims = pir::GetShapeFromValue(match_ctx.Tensor("x"));
6363
return x_dims.size() - 1;
6464
});
65-
const auto &false_attr =
66-
res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool {
67-
return false;
68-
});
69-
70-
const auto &fc =
71-
res.Op(paddle::dialect::FcOp::name(),
72-
{{
73-
{"in_num_col_dims", in_num_col_dims_attr},
74-
{"activation_type",
75-
res.Attr([](const paddle::drr::MatchContext &match_ctx)
76-
-> std::string { return ""; })},
77-
{"padding_weights", false_attr},
78-
}});
65+
const auto &false_attr = res.BoolAttr(false);
66+
67+
const auto &fc = res.Op(paddle::dialect::FcOp::name(),
68+
{{
69+
{"in_num_col_dims", in_num_col_dims_attr},
70+
{"activation_type", res.StrAttr("")},
71+
{"padding_weights", false_attr},
72+
}});
7973
fc({&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("y")},
8074
{&res.Tensor("add_out")});
8175
}
@@ -110,9 +104,7 @@ class FcWithReluPattern : public paddle::drr::DrrPatternBase {
110104
res.Op(paddle::dialect::FcOp::name(),
111105
{{
112106
{"in_num_col_dims", pat.Attr("in_num_col_dims")},
113-
{"activation_type",
114-
res.Attr([](const paddle::drr::MatchContext &match_ctx)
115-
-> std::string { return "relu"; })},
107+
{"activation_type", res.StrAttr("relu")},
116108
{"padding_weights", pat.Attr("padding_weights")},
117109
}});
118110
fc_with_relu({&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("y")},

0 commit comments

Comments
 (0)