From f4a62d15ed30ab8d9a5feba3db5d2a17f95809b6 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Wed, 17 Jan 2024 08:12:41 +0000 Subject: [PATCH 1/8] pir support attention_fuse_pass to fuse a multihead_matmul op --- .../fluid/inference/api/analysis_predictor.cc | 2 + .../pir/transforms/constant_folding_pass.cc | 110 ++++-- .../transforms/fusion/attention_fuse_pass.cc | 367 ++++++++++++++++-- .../params_sync_among_devices_pass.cc | 4 + test/cpp/pir/pattern_rewrite/CMakeLists.txt | 10 +- .../drr_attention_fuse_test.cc | 10 +- 6 files changed, 433 insertions(+), 70 deletions(-) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 995e712b562311..a74d811002bee2 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -107,6 +107,7 @@ #include "paddle/fluid/ir_adaptor/translator/translate.h" #include "paddle/fluid/pir/transforms/constant_folding_pass.h" #include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" +#include "paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h" #include "paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.h" #include "paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.h" #include "paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.h" @@ -804,6 +805,7 @@ bool AnalysisPredictor::PrepareExecutor() { gpu_pm.AddPass(::pir::CreateConv2dBnFusePass()); gpu_pm.AddPass(::pir::CreateConv2dAddActFusePass()); gpu_pm.AddPass(::pir::CreateConv2dAddFusePass()); + gpu_pm.AddPass(::pir::CreateAttentionFusePass()); gpu_pm.AddPass(::pir::CreateFcFusePass()); gpu_pm.AddPass(::pir::CreateFcElementwiseLayerNormFusePass()); gpu_pm.AddPass(::pir::CreateMatmulScaleFusePass()); diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc index cca083f7090252..41e39ef050f321 100644 --- a/paddle/fluid/pir/transforms/constant_folding_pass.cc +++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc @@ -34,6 +34,7 @@ #include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/ir_context.h" #include "paddle/pir/core/op_result.h" #include "paddle/pir/core/op_trait.h" @@ -83,17 +84,36 @@ class ConstantFoldingPattern : public pir::RewritePattern { if (!op->operand_source(i) || !op->operand_source(i).type()) { continue; } - // 2. inputs must come from parameter op or constant op + // 2. inputs must come from ParameterOp/ConstantTensorOp/CombineOp auto* prev_op = pir::GetDefiningOpForInput(op, i); if (!prev_op || !(prev_op->isa() || - prev_op->isa())) { + prev_op->isa() || + prev_op->isa())) { return false; } - // 3. inputs must be a dense tensor type - if (!op->operand_source(i) - .type() - .isa()) { - return false; + if (prev_op->isa()) { + if (prev_op->result(0).use_count() > 1) { + return false; + } + for (uint32_t i = 0; i < prev_op->num_operands(); i++) { + if (!prev_op->operand_source(i) || + !prev_op->operand_source(i).type()) { + continue; + } + if (!prev_op->operand_source(i) + .type() + .isa()) { + return false; + } + } + + } else { + // 3. 
inputs must be a dense tensor type + if (!op->operand_source(i) + .type() + .isa()) { + return false; + } } } @@ -233,7 +253,7 @@ class ConstantFoldingPattern : public pir::RewritePattern { BuildProgramFromOperation(op, &new_program, rewriter); // execute program - for (auto output_var_name : output_var_names) { + for (const auto& output_var_name : output_var_names) { exe_config_->skip_gc_vars.insert(output_var_name); } auto kernel_program = @@ -256,24 +276,52 @@ class ConstantFoldingPattern : public pir::RewritePattern { std::vector op_inputs; for (uint32_t i = 0; i < op->num_operands(); i++) { if (op->operand_source(i)) { - const auto& param_name = - pir::GetParameterNameFromValue(op->operand_source(i)); - auto* param_var = scope_->FindVar(param_name); - PADDLE_ENFORCE_NOT_NULL( - param_var, - phi::errors::InvalidArgument("Parameter var [%s] not in scope.", - param_name)); - - auto parameter_op = builder.Build( - param_name, op->operand_source(i).type()); - if (op->operand_source(i).use_count() <= 1) { - deleted_vars_->push_back(param_name); + auto* prev_op = pir::GetDefiningOpForInput(op, i); + if (prev_op->isa()) { + // prepare combine op inputs + std::vector combine_op_inputs; + for (uint32_t i = 0; i < prev_op->num_operands(); i++) { + const auto& param_name = + pir::GetParameterNameFromValue(prev_op->operand_source(i)); + auto* param_var = scope_->FindVar(param_name); + PADDLE_ENFORCE_NOT_NULL( + param_var, + phi::errors::InvalidArgument("Parameter var [%s] not in scope.", + param_name)); + + auto parameter_op = builder.Build( + param_name, prev_op->operand_source(i).type()); + if (prev_op->operand_source(i).use_count() <= 1) { + deleted_vars_->push_back(param_name); + } else { + parameter_op->set_attribute( + kAttrIsPersisable, + rewriter.array_attr({rewriter.bool_attr(true)})); + } + combine_op_inputs.push_back(parameter_op->result(0)); + } + auto combine_op = builder.Build(combine_op_inputs); + op_inputs.push_back(combine_op->result(0)); } else { - parameter_op->set_attribute( - kAttrIsPersisable, - rewriter.array_attr({rewriter.bool_attr(true)})); + const auto& param_name = + pir::GetParameterNameFromValue(op->operand_source(i)); + auto* param_var = scope_->FindVar(param_name); + PADDLE_ENFORCE_NOT_NULL( + param_var, + phi::errors::InvalidArgument("Parameter var [%s] not in scope.", + param_name)); + + auto parameter_op = builder.Build( + param_name, op->operand_source(i).type()); + if (op->operand_source(i).use_count() <= 1) { + deleted_vars_->push_back(param_name); + } else { + parameter_op->set_attribute( + kAttrIsPersisable, + rewriter.array_attr({rewriter.bool_attr(true)})); + } + op_inputs.push_back(parameter_op->result(0)); } - op_inputs.push_back(parameter_op->result(0)); } else { op_inputs.push_back( op->operand_source(i).dyn_cast() /*nullptr*/); @@ -281,17 +329,17 @@ class ConstantFoldingPattern : public pir::RewritePattern { } // prepare op outputs - std::vector output_types; + std::vector op_output_types; for (uint32_t i = 0; i < op->num_results(); i++) { - output_types.push_back(op->result(i).type()); + op_output_types.push_back(op->result(i).type()); } - auto* temp_op = - builder.Build(op_inputs, op->attributes(), output_types, op->info()); + auto* op_copy = + builder.Build(op_inputs, op->attributes(), op_output_types, op->info()); std::vector output_var_names; - for (uint32_t i = 0; i < op->num_results(); i++) { - if (!temp_op->result(i) || !temp_op->result(i).type()) { + for (uint32_t i = 0; i < op_copy->num_results(); i++) { + if (!op_copy->result(i) || 
!op_copy->result(i).type()) { continue; } std::stringstream ss; @@ -301,7 +349,7 @@ class ConstantFoldingPattern : public pir::RewritePattern { std::string output_var_name = "constant_folding@_" + ss.str() + std::to_string((*suffix_)++); - builder.Build(temp_op->result(i), output_var_name); + builder.Build(op_copy->result(i), output_var_name); output_var_names.push_back(output_var_name); } diff --git a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc index 616ff6f607c588..4d670d92f4aa8d 100644 --- a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc @@ -21,7 +21,7 @@ namespace { -class MultiHeadMatmulFusePattern : public paddle::drr::DrrPatternBase { +class MultiHeadMatmulFuseNoBiasQKPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // @@ -92,12 +92,10 @@ class MultiHeadMatmulFusePattern : public paddle::drr::DrrPatternBase { {"transpose_y", src.Attr("matmul_4_transpose_y")}}); src.Tensor("matmul_4_out") = matmul_4(src.Tensor("scale_out"), src.Tensor("transpose_2_out")); - const auto &add_4 = src.Op("pd_op.add"); - src.Tensor("add_4_out") = - add_4(src.Tensor("matmul_4_out"), src.Tensor("add_4_in_2")); + const auto &softmax = src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_axis")}}); - src.Tensor("softmax_out") = softmax(src.Tensor("add_4_out")); + src.Tensor("softmax_out") = softmax(src.Tensor("matmul_4_out")); const auto &matmul_5 = src.Op("pd_op.matmul", {{"transpose_x", src.Attr("matmul_5_transpose_x")}, @@ -139,6 +137,30 @@ class MultiHeadMatmulFusePattern : public paddle::drr::DrrPatternBase { bool matmul_5_transpose_y = match_ctx.Attr("matmul_5_transpose_y"); if (matmul_5_transpose_x || matmul_5_transpose_y) return false; + auto matmul_1_in_2 = + pir::GetShapeFromValue(match_ctx.Tensor("matmul_1_in_2")); + auto matmul_2_in_2 = + pir::GetShapeFromValue(match_ctx.Tensor("matmul_2_in_2")); + auto matmul_3_in_2 = + pir::GetShapeFromValue(match_ctx.Tensor("matmul_3_in_2")); + if (matmul_1_in_2.size() != 2 || matmul_2_in_2.size() != 2 || + matmul_3_in_2.size() != 2 || + matmul_1_in_2.at(0) != matmul_2_in_2.at(0) || + matmul_1_in_2.at(0) != matmul_3_in_2.at(0) || + matmul_1_in_2.at(1) != matmul_2_in_2.at(1) || + matmul_1_in_2.at(1) != matmul_3_in_2.at(1)) { + return false; + } + + auto add_1_in_2 = pir::GetShapeFromValue(match_ctx.Tensor("add_1_in_2")); + auto add_2_in_2 = pir::GetShapeFromValue(match_ctx.Tensor("add_2_in_2")); + auto add_3_in_2 = pir::GetShapeFromValue(match_ctx.Tensor("add_3_in_2")); + if (add_1_in_2.size() != 1 || add_2_in_2.size() != 1 || + add_3_in_2.size() != 1 || add_1_in_2.at(0) != add_2_in_2.at(0) || + add_1_in_2.at(0) != add_3_in_2.at(0)) { + return false; + } + return true; }); @@ -146,43 +168,323 @@ class MultiHeadMatmulFusePattern : public paddle::drr::DrrPatternBase { // Result Pattern. // paddle::drr::ResultPattern res = src.ResultPattern(); - // W combine. 
- const auto &combine_1 = res.Op("builtin.combine"); - combine_1({&res.Tensor("matmul_1_in_2"), - &res.Tensor("matmul_2_in_2"), - &res.Tensor("matmul_3_in_2")}, - {&res.Tensor("combine_1_out")}); - const auto &concat_axis = res.Attr( - [](const paddle::drr::MatchContext &match_ctx) -> int { return 0; }); - const auto &concat_1 = res.Op("pd_op.concat", {{"axis", concat_axis}}); - res.Tensor("concat_1_out") = concat_1(res.Tensor("combine_1_out")); - const auto &reshape_5_shape = res.Attr( + + // W reshape. + const auto &reshape_w_shape_attr = res.Attr( [](const paddle::drr::MatchContext &match_ctx) -> std::vector { auto matmul_1_in_2 = pir::GetShapeFromValue(match_ctx.Tensor("matmul_1_in_2")); - return {-1, 3, matmul_1_in_2.at(1)}; + return {matmul_1_in_2.at(0), 1, matmul_1_in_2.at(1)}; }); const auto &reshape_5 = - res.Op("pd_op.reshape", {{"shape", reshape_5_shape}}); - reshape_5({&res.Tensor("concat_1_out")}, + res.Op("pd_op.reshape", {{"shape", reshape_w_shape_attr}}); + reshape_5({&res.Tensor("matmul_1_in_2")}, {&res.Tensor("reshape_5_out"), &res.NoneTensor()}); + const auto &reshape_6 = + res.Op("pd_op.reshape", {{"shape", reshape_w_shape_attr}}); + reshape_6({&res.Tensor("matmul_2_in_2")}, + {&res.Tensor("reshape_6_out"), &res.NoneTensor()}); + const auto &reshape_7 = + res.Op("pd_op.reshape", {{"shape", reshape_w_shape_attr}}); + reshape_7({&res.Tensor("matmul_3_in_2")}, + {&res.Tensor("reshape_7_out"), &res.NoneTensor()}); + + // W combine. + const auto &combine_1 = res.Op("builtin.combine"); + combine_1({&res.Tensor("reshape_5_out"), + &res.Tensor("reshape_6_out"), + &res.Tensor("reshape_7_out")}, + {&res.Tensor("combine_1_out")}); + const auto &concat_1_axis_attr = res.Attr( + [](const paddle::drr::MatchContext &match_ctx) -> int { return 1; }); + const auto &concat_1 = + res.Op("pd_op.concat", {{"axis", concat_1_axis_attr}}); + res.Tensor("concat_1_out") = concat_1(res.Tensor("combine_1_out")); + + // Bias reshape. + const auto &reshape_b_shape_attr = res.Attr( + [](const paddle::drr::MatchContext &match_ctx) -> std::vector { + auto add_1_in_2 = + pir::GetShapeFromValue(match_ctx.Tensor("add_1_in_2")); + return {1, add_1_in_2.at(0)}; + }); + const auto &reshape_8 = + res.Op("pd_op.reshape", {{"shape", reshape_b_shape_attr}}); + reshape_8({&res.Tensor("add_1_in_2")}, + {&res.Tensor("reshape_8_out"), &res.NoneTensor()}); + const auto &reshape_9 = + res.Op("pd_op.reshape", {{"shape", reshape_b_shape_attr}}); + reshape_9({&res.Tensor("add_2_in_2")}, + {&res.Tensor("reshape_9_out"), &res.NoneTensor()}); + const auto &reshape_10 = + res.Op("pd_op.reshape", {{"shape", reshape_b_shape_attr}}); + reshape_10({&res.Tensor("add_3_in_2")}, + {&res.Tensor("reshape_10_out"), &res.NoneTensor()}); // Bias combine. 
const auto &combine_2 = res.Op("builtin.combine"); - combine_2({&res.Tensor("add_1_in_2"), - &res.Tensor("add_2_in_2"), - &res.Tensor("add_3_in_2")}, + combine_2({&res.Tensor("reshape_8_out"), + &res.Tensor("reshape_9_out"), + &res.Tensor("reshape_10_out")}, {&res.Tensor("combine_2_out")}); - const auto &concat_2 = res.Op("pd_op.concat", {{"axis", concat_axis}}); + const auto &concat_2_axis_attr = res.Attr( + [](const paddle::drr::MatchContext &match_ctx) -> int { return 0; }); + const auto &concat_2 = + res.Op("pd_op.concat", {{"axis", concat_2_axis_attr}}); res.Tensor("concat_2_out") = concat_2(res.Tensor("combine_2_out")); - const auto &reshape_6_shape = res.Attr( + + const auto &head_number = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> int { + const auto &full_int_array_1_value = + match_ctx.Attr>("full_int_array_1_value"); + return full_int_array_1_value.at(2); + }); + const auto &alpha = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { + return match_ctx.Attr("full_1_value"); + }); + const auto &multihead_matmul = + res.Op("pd_op.multihead_matmul", + {{"transpose_q", + res.Attr([](const paddle::drr::MatchContext &match_ctx) { + return false; + })}, + {"transpose_k", + res.Attr([](const paddle::drr::MatchContext &match_ctx) { + return true; + })}, + {"transpose_v", + res.Attr([](const paddle::drr::MatchContext &match_ctx) { + return false; + })}, + {"head_number", head_number}, + {"alpha", alpha}}); + multihead_matmul({&res.Tensor("matmul_1_in_1"), + &res.Tensor("concat_1_out"), + &res.Tensor("concat_2_out"), + &res.NoneTensor()}, + {&res.Tensor("reshape_4_out")}); + } + + std::string name() const override { + return "MultiHeadMatmulFuseNoBiasQKPattern"; + } +}; + +class MultiHeadMatmulFuseWithBiasQKPattern + : public paddle::drr::DrrPatternBase { + public: + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + // + // Source Pattern. + // + paddle::drr::SourcePattern src = ctx->SourcePattern(); + // The first path to matmul with scale (q). + const auto &matmul_1 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_1_transpose_x")}, + {"transpose_y", src.Attr("matmul_1_transpose_y")}}); + src.Tensor("matmul_1_out") = + matmul_1(src.Tensor("matmul_1_in_1"), src.Tensor("matmul_1_in_2")); + const auto &add_1 = src.Op("pd_op.add"); + src.Tensor("add_1_out") = + add_1(src.Tensor("matmul_1_out"), src.Tensor("add_1_in_2")); + const auto &full_int_array_1 = + src.Op("pd_op.full_int_array", + {{"value", src.Attr("full_int_array_1_value")}}); + const auto &reshape_1 = src.Op("pd_op.reshape"); + reshape_1({&src.Tensor("add_1_out"), &full_int_array_1()}, + {&src.Tensor("reshape_1_out"), &src.Tensor("reshape_1_xshape")}); + const auto &transpose_1 = src.Op("pd_op.transpose"); + src.Tensor("transpose_1_out") = transpose_1(src.Tensor("reshape_1_out")); + const auto &full_1 = + src.Op("pd_op.full", {{"value", src.Attr("full_1_value")}}); + const auto &scale = src.Op("pd_op.scale"); + src.Tensor("scale_out") = scale(src.Tensor("transpose_1_out"), full_1()); + + // The second path to matmul (k). 
+ const auto &matmul_2 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_2_transpose_x")}, + {"transpose_y", src.Attr("matmul_2_transpose_y")}}); + src.Tensor("matmul_2_out") = + matmul_2(src.Tensor("matmul_1_in_1"), src.Tensor("matmul_2_in_2")); + const auto &add_2 = src.Op("pd_op.add"); + src.Tensor("add_2_out") = + add_2(src.Tensor("matmul_2_out"), src.Tensor("add_2_in_2")); + const auto &full_int_array_2 = src.Op("pd_op.full_int_array"); + const auto &reshape_2 = src.Op("pd_op.reshape"); + reshape_2({&src.Tensor("add_2_out"), &full_int_array_2()}, + {&src.Tensor("reshape_2_out"), &src.Tensor("reshape_2_xshape")}); + const auto &transpose_2 = src.Op("pd_op.transpose"); + src.Tensor("transpose_2_out") = transpose_2(src.Tensor("reshape_2_out")); + + // The third path to matmul (v). + const auto &matmul_3 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_3_transpose_x")}, + {"transpose_y", src.Attr("matmul_3_transpose_y")}}); + src.Tensor("matmul_3_out") = + matmul_3(src.Tensor("matmul_1_in_1"), src.Tensor("matmul_3_in_2")); + const auto &add_3 = src.Op("pd_op.add"); + src.Tensor("add_3_out") = + add_3(src.Tensor("matmul_3_out"), src.Tensor("add_3_in_2")); + const auto &full_int_array_3 = src.Op("pd_op.full_int_array"); + const auto &reshape_3 = src.Op("pd_op.reshape"); + reshape_3({&src.Tensor("add_3_out"), &full_int_array_3()}, + {&src.Tensor("reshape_3_out"), &src.Tensor("reshape_3_xshape")}); + const auto &transpose_3 = src.Op("pd_op.transpose"); + src.Tensor("transpose_3_out") = transpose_3(src.Tensor("reshape_3_out")); + + // softmax(qk)v + const auto &matmul_4 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_4_transpose_x")}, + {"transpose_y", src.Attr("matmul_4_transpose_y")}}); + src.Tensor("matmul_4_out") = + matmul_4(src.Tensor("scale_out"), src.Tensor("transpose_2_out")); + const auto &add_4 = src.Op("pd_op.add"); + src.Tensor("add_4_out") = + add_4(src.Tensor("matmul_4_out"), src.Tensor("add_4_in_2")); + const auto &softmax = + src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_axis")}}); + src.Tensor("softmax_out") = softmax(src.Tensor("add_4_out")); + const auto &matmul_5 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_5_transpose_x")}, + {"transpose_y", src.Attr("matmul_5_transpose_y")}}); + src.Tensor("matmul_5_out") = + matmul_5(src.Tensor("softmax_out"), src.Tensor("transpose_3_out")); + const auto &transpose_4 = src.Op("pd_op.transpose"); + src.Tensor("transpose_4_out") = transpose_4(src.Tensor("matmul_5_out")); + const auto &full_int_array_4 = src.Op("pd_op.full_int_array"); + const auto &reshape_4 = src.Op("pd_op.reshape"); + reshape_4({&src.Tensor("transpose_4_out"), &full_int_array_4()}, + {&src.Tensor("reshape_4_out"), &src.Tensor("reshape_4_xshape")}); + + // + // Constraints. 
+ // + src.RequireNativeCall([](const paddle::drr::MatchContext &match_ctx) + -> bool { + const auto &softmax_axis = match_ctx.Attr("softmax_axis"); + if (softmax_axis != -1 && softmax_axis != 3) return false; + + bool matmul_1_transpose_x = match_ctx.Attr("matmul_1_transpose_x"); + bool matmul_1_transpose_y = match_ctx.Attr("matmul_1_transpose_y"); + if (matmul_1_transpose_x || matmul_1_transpose_y) return false; + + bool matmul_2_transpose_x = match_ctx.Attr("matmul_2_transpose_x"); + bool matmul_2_transpose_y = match_ctx.Attr("matmul_2_transpose_y"); + if (matmul_2_transpose_x || matmul_2_transpose_y) return false; + + bool matmul_3_transpose_x = match_ctx.Attr("matmul_3_transpose_x"); + bool matmul_3_transpose_y = match_ctx.Attr("matmul_3_transpose_y"); + if (matmul_3_transpose_x || matmul_3_transpose_y) return false; + + bool matmul_4_transpose_x = match_ctx.Attr("matmul_4_transpose_x"); + bool matmul_4_transpose_y = match_ctx.Attr("matmul_4_transpose_y"); + if (matmul_4_transpose_x || !matmul_4_transpose_y) return false; + + bool matmul_5_transpose_x = match_ctx.Attr("matmul_5_transpose_x"); + bool matmul_5_transpose_y = match_ctx.Attr("matmul_5_transpose_y"); + if (matmul_5_transpose_x || matmul_5_transpose_y) return false; + + auto matmul_1_in_2 = + pir::GetShapeFromValue(match_ctx.Tensor("matmul_1_in_2")); + auto matmul_2_in_2 = + pir::GetShapeFromValue(match_ctx.Tensor("matmul_2_in_2")); + auto matmul_3_in_2 = + pir::GetShapeFromValue(match_ctx.Tensor("matmul_3_in_2")); + if (matmul_1_in_2.size() != 2 || matmul_2_in_2.size() != 2 || + matmul_3_in_2.size() != 2 || + matmul_1_in_2.at(0) != matmul_2_in_2.at(0) || + matmul_1_in_2.at(0) != matmul_3_in_2.at(0) || + matmul_1_in_2.at(1) != matmul_2_in_2.at(1) || + matmul_1_in_2.at(1) != matmul_3_in_2.at(1)) { + return false; + } + + auto add_1_in_2 = pir::GetShapeFromValue(match_ctx.Tensor("add_1_in_2")); + auto add_2_in_2 = pir::GetShapeFromValue(match_ctx.Tensor("add_2_in_2")); + auto add_3_in_2 = pir::GetShapeFromValue(match_ctx.Tensor("add_3_in_2")); + if (add_1_in_2.size() != 1 || add_2_in_2.size() != 1 || + add_3_in_2.size() != 1 || add_1_in_2.at(0) != add_2_in_2.at(0) || + add_1_in_2.at(0) != add_3_in_2.at(0)) { + return false; + } + + return true; + }); + + // + // Result Pattern. + // + paddle::drr::ResultPattern res = src.ResultPattern(); + + // W reshape. + const auto &reshape_w_shape_attr = res.Attr( [](const paddle::drr::MatchContext &match_ctx) -> std::vector { - return {3, -1}; + auto matmul_1_in_2 = + pir::GetShapeFromValue(match_ctx.Tensor("matmul_1_in_2")); + return {matmul_1_in_2.at(0), 1, matmul_1_in_2.at(1)}; }); + const auto &reshape_5 = + res.Op("pd_op.reshape", {{"shape", reshape_w_shape_attr}}); + reshape_5({&res.Tensor("matmul_1_in_2")}, + {&res.Tensor("reshape_5_out"), &res.NoneTensor()}); const auto &reshape_6 = - res.Op("pd_op.reshape", {{"shape", reshape_6_shape}}); - reshape_6({&res.Tensor("concat_2_out")}, + res.Op("pd_op.reshape", {{"shape", reshape_w_shape_attr}}); + reshape_6({&res.Tensor("matmul_2_in_2")}, {&res.Tensor("reshape_6_out"), &res.NoneTensor()}); + const auto &reshape_7 = + res.Op("pd_op.reshape", {{"shape", reshape_w_shape_attr}}); + reshape_7({&res.Tensor("matmul_3_in_2")}, + {&res.Tensor("reshape_7_out"), &res.NoneTensor()}); + + // W combine. 
+ const auto &combine_1 = res.Op("builtin.combine"); + combine_1({&res.Tensor("reshape_5_out"), + &res.Tensor("reshape_6_out"), + &res.Tensor("reshape_7_out")}, + {&res.Tensor("combine_1_out")}); + const auto &concat_1_axis_attr = res.Attr( + [](const paddle::drr::MatchContext &match_ctx) -> int { return 1; }); + const auto &concat_1 = + res.Op("pd_op.concat", {{"axis", concat_1_axis_attr}}); + res.Tensor("concat_1_out") = concat_1(res.Tensor("combine_1_out")); + + // Bias reshape. + const auto &reshape_b_shape_attr = res.Attr( + [](const paddle::drr::MatchContext &match_ctx) -> std::vector { + auto add_1_in_2 = + pir::GetShapeFromValue(match_ctx.Tensor("add_1_in_2")); + return {1, add_1_in_2.at(0)}; + }); + const auto &reshape_8 = + res.Op("pd_op.reshape", {{"shape", reshape_b_shape_attr}}); + reshape_8({&res.Tensor("add_1_in_2")}, + {&res.Tensor("reshape_8_out"), &res.NoneTensor()}); + const auto &reshape_9 = + res.Op("pd_op.reshape", {{"shape", reshape_b_shape_attr}}); + reshape_9({&res.Tensor("add_2_in_2")}, + {&res.Tensor("reshape_9_out"), &res.NoneTensor()}); + const auto &reshape_10 = + res.Op("pd_op.reshape", {{"shape", reshape_b_shape_attr}}); + reshape_10({&res.Tensor("add_3_in_2")}, + {&res.Tensor("reshape_10_out"), &res.NoneTensor()}); + + // Bias combine. + const auto &combine_2 = res.Op("builtin.combine"); + combine_2({&res.Tensor("reshape_8_out"), + &res.Tensor("reshape_9_out"), + &res.Tensor("reshape_10_out")}, + {&res.Tensor("combine_2_out")}); + const auto &concat_2_axis_attr = res.Attr( + [](const paddle::drr::MatchContext &match_ctx) -> int { return 0; }); + const auto &concat_2 = + res.Op("pd_op.concat", {{"axis", concat_2_axis_attr}}); + res.Tensor("concat_2_out") = concat_2(res.Tensor("combine_2_out")); const auto &head_number = res.Attr([](const paddle::drr::MatchContext &match_ctx) -> int { @@ -211,13 +513,15 @@ class MultiHeadMatmulFusePattern : public paddle::drr::DrrPatternBase { {"head_number", head_number}, {"alpha", alpha}}); multihead_matmul({&res.Tensor("matmul_1_in_1"), - &res.Tensor("reshape_5_out"), - &res.Tensor("reshape_6_out"), + &res.Tensor("concat_1_out"), + &res.Tensor("concat_2_out"), &res.Tensor("add_4_in_2")}, {&res.Tensor("reshape_4_out")}); } - std::string name() const override { return "MultiHeadMatmulFusePattern"; } + std::string name() const override { + return "MultiHeadMatmulFuseWithBiasQKPattern"; + } }; class AttentionFusePass : public pir::PatternRewritePass { @@ -226,7 +530,8 @@ class AttentionFusePass : public pir::PatternRewritePass { pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { pir::RewritePatternSet ps(context); - ps.Add(MultiHeadMatmulFusePattern().Build(context)); + ps.Add(MultiHeadMatmulFuseNoBiasQKPattern().Build(context)); + ps.Add(MultiHeadMatmulFuseWithBiasQKPattern().Build(context)); // Add other attention variant fuse pattern. 
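    // Summary (based on the two DRR patterns defined above): both rules match
    // the multi-head attention subgraph -- Q/K/V matmul + bias add + reshape +
    // transpose, a scaled QK matmul, softmax, the QK*V matmul, and the final
    // transpose + reshape -- and rewrite it into a single pd_op.multihead_matmul.
    // The NoBiasQK variant passes a NoneTensor as the fourth operand, while the
    // WithBiasQK variant additionally matches the attention-mask add before the
    // softmax and forwards that mask tensor ("add_4_in_2") as the fourth operand.
    // In both cases head_number is taken from the third element of the Q reshape
    // shape attribute and alpha from the scale value.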
return ps; diff --git a/paddle/fluid/pir/transforms/params_sync_among_devices_pass.cc b/paddle/fluid/pir/transforms/params_sync_among_devices_pass.cc index 41e51d00ef7048..73f00cebb91c62 100644 --- a/paddle/fluid/pir/transforms/params_sync_among_devices_pass.cc +++ b/paddle/fluid/pir/transforms/params_sync_among_devices_pass.cc @@ -89,6 +89,10 @@ class ParamsSyncAmongDevicesPass : public pir::Pass { param_tensor->clear(); paddle::framework::TensorCopySync(temp_tensor, place_, param_tensor); num_rewrites_++; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "params_sync_among_devices_pass only support DenseTensor type of " + "parameter var.")); } } } diff --git a/test/cpp/pir/pattern_rewrite/CMakeLists.txt b/test/cpp/pir/pattern_rewrite/CMakeLists.txt index b06577552d52b3..359950e796d155 100644 --- a/test/cpp/pir/pattern_rewrite/CMakeLists.txt +++ b/test/cpp/pir/pattern_rewrite/CMakeLists.txt @@ -23,7 +23,9 @@ cc_test( SRCS drr_fuse_linear_param_grad_add_test.cc DEPS pir_transforms drr gtest op_dialect_vjp pir) -cc_test( - drr_attention_fuse_test - SRCS drr_attention_fuse_test.cc - DEPS pir_transforms drr gtest op_dialect_vjp pir) +if(WITH_GPU) + cc_test( + drr_attention_fuse_test + SRCS drr_attention_fuse_test.cc + DEPS pir_transforms drr gtest op_dialect_vjp pir) +endif() diff --git a/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc b/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc index 8485f493c794c2..4361fd03a306f3 100644 --- a/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc +++ b/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc @@ -21,8 +21,10 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/transforms/constant_folding_pass.h" +#include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" #include "paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h" +#include "paddle/phi/common/place.h" #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/pass/pass_manager.h" @@ -147,13 +149,13 @@ TEST(DrrTest, AttentionFuse) { pm.AddPass(pir::CreateAttentionFusePass()); std::unique_ptr constant_folding_pass = pir::CreateConstantFoldingPass(); - phi::Place place = phi::CPUPlace(); - constant_folding_pass->SetNotOwned(pir::kPlaceAttr, &place); + constant_folding_pass->Set(pir::kPlaceAttr, new phi::Place{phi::GPUPlace{}}); constant_folding_pass->Set(pir::kParamScopeAttr, - new paddle::framework::Scope()); + new paddle::framework::Scope{}); pm.AddPass(std::move(constant_folding_pass)); + pm.AddPass(pir::CreateDeadCodeEliminationPass()); pm.EnableIRPrinting(); CHECK_EQ(pm.Run(&program), true); - EXPECT_EQ(program.block()->size(), 20u); + EXPECT_EQ(program.block()->size(), 2u); } From ae0efed06448765af7fc4ad78786ee62c098500c Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Wed, 17 Jan 2024 08:49:18 +0000 Subject: [PATCH 2/8] update --- .../pir/transforms/constant_folding_pass.cc | 82 +++++++++++++++---- 1 file changed, 65 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc index 41e39ef050f321..7a7a908bc0f67a 100644 --- a/paddle/fluid/pir/transforms/constant_folding_pass.cc +++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc @@ -28,6 +28,7 @@ #include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" #include "paddle/fluid/pir/transforms/transform_general_functions.h" +#include "paddle/common/errors.h" #include "paddle/phi/common/place.h" 
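// Note on the hunks below: BuildProgramFromOperation is extended so that a
// folded op's inputs may come from pir::ConstantTensorOp as well as
// pir::ParameterOp, either directly or through a pir::CombineOp; any other
// producer now triggers PADDLE_THROW(phi::errors::Fatal(...)).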
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" @@ -281,28 +282,53 @@ class ConstantFoldingPattern : public pir::RewritePattern { // prepare combine op inputs std::vector combine_op_inputs; for (uint32_t i = 0; i < prev_op->num_operands(); i++) { - const auto& param_name = - pir::GetParameterNameFromValue(prev_op->operand_source(i)); - auto* param_var = scope_->FindVar(param_name); - PADDLE_ENFORCE_NOT_NULL( - param_var, - phi::errors::InvalidArgument("Parameter var [%s] not in scope.", - param_name)); - - auto parameter_op = builder.Build( - param_name, prev_op->operand_source(i).type()); - if (prev_op->operand_source(i).use_count() <= 1) { - deleted_vars_->push_back(param_name); + auto* combine_prev_op = pir::GetDefiningOpForInput(prev_op, i); + if (combine_prev_op->isa()) { + const auto& param_name = + pir::GetParameterNameFromValue(prev_op->operand_source(i)); + auto* param_var = scope_->FindVar(param_name); + PADDLE_ENFORCE_NOT_NULL( + param_var, + phi::errors::InvalidArgument( + "Parameter var [%s] not in scope.", param_name)); + + auto parameter_op = builder.Build( + param_name, prev_op->operand_source(i).type()); + if (prev_op->operand_source(i).use_count() <= 1) { + deleted_vars_->push_back(param_name); + } else { + parameter_op->set_attribute( + kAttrIsPersisable, + rewriter.array_attr({rewriter.bool_attr(true)})); + } + combine_op_inputs.push_back(parameter_op->result(0)); + } else if (combine_prev_op->isa()) { + const auto& tensor_name = + pir::GetParameterNameFromValue(prev_op->operand_source(i)); + auto* tensor_var = scope_->FindVar(tensor_name); + PADDLE_ENFORCE_NOT_NULL( + tensor_var, + phi::errors::InvalidArgument("Tensor var [%s] not in scope.", + tensor_name)); + + auto constant_op = builder.Build( + rewriter.tensor_name_attr(tensor_name), + prev_op->operand_source(i).type()); + if (prev_op->operand_source(i).use_count() <= 1) { + deleted_vars_->push_back(tensor_name); + } else { + constant_op->set_attribute( + kAttrIsPersisable, + rewriter.array_attr({rewriter.bool_attr(true)})); + } + combine_op_inputs.push_back(constant_op->result(0)); } else { - parameter_op->set_attribute( - kAttrIsPersisable, - rewriter.array_attr({rewriter.bool_attr(true)})); + PADDLE_THROW(phi::errors::Fatal("Not Support!")); } - combine_op_inputs.push_back(parameter_op->result(0)); } auto combine_op = builder.Build(combine_op_inputs); op_inputs.push_back(combine_op->result(0)); - } else { + } else if (prev_op->isa()) { const auto& param_name = pir::GetParameterNameFromValue(op->operand_source(i)); auto* param_var = scope_->FindVar(param_name); @@ -321,6 +347,28 @@ class ConstantFoldingPattern : public pir::RewritePattern { rewriter.array_attr({rewriter.bool_attr(true)})); } op_inputs.push_back(parameter_op->result(0)); + } else if (prev_op->isa()) { + const auto& tensor_name = + pir::GetParameterNameFromValue(op->operand_source(i)); + auto* tensor_var = scope_->FindVar(tensor_name); + PADDLE_ENFORCE_NOT_NULL( + tensor_var, + phi::errors::InvalidArgument("Tensor var [%s] not in scope.", + tensor_name)); + + auto constant_op = builder.Build( + rewriter.tensor_name_attr(tensor_name), + op->operand_source(i).type()); + if (op->operand_source(i).use_count() <= 1) { + deleted_vars_->push_back(tensor_name); + } else { + constant_op->set_attribute( + kAttrIsPersisable, + rewriter.array_attr({rewriter.bool_attr(true)})); + } + op_inputs.push_back(constant_op->result(0)); + } else { + PADDLE_THROW(phi::errors::Fatal("Not Support!")); } } else { op_inputs.push_back( From 
7cc52c2967e6066d93e11fabe1e8318ba9c323a3 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Wed, 17 Jan 2024 10:38:27 +0000 Subject: [PATCH 3/8] fix --- .../pir/transforms/constant_folding_pass.cc | 139 +++++++----------- paddle/pir/core/builtin_op.cc | 25 ++-- paddle/pir/core/builtin_op.h | 5 + 3 files changed, 77 insertions(+), 92 deletions(-) diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc index 7a7a908bc0f67a..3bfda8e00d636f 100644 --- a/paddle/fluid/pir/transforms/constant_folding_pass.cc +++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc @@ -33,6 +33,7 @@ #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" +#include "paddle/pir/core/builder.h" #include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/core/builtin_type.h" @@ -101,15 +102,21 @@ class ConstantFoldingPattern : public pir::RewritePattern { !prev_op->operand_source(i).type()) { continue; } + // 3. for combine's prev op, inputs must come from + // ParameterOp/ConstantTensorOp + auto* prev_prev_op = pir::GetDefiningOpForInput(prev_op, i); + if (!prev_prev_op || !(prev_prev_op->isa() || + prev_prev_op->isa())) { + return false; + } if (!prev_op->operand_source(i) .type() .isa()) { return false; } } - } else { - // 3. inputs must be a dense tensor type + // 4. inputs must be a dense tensor type if (!op->operand_source(i) .type() .isa()) { @@ -122,13 +129,13 @@ class ConstantFoldingPattern : public pir::RewritePattern { if (!op->result(i) || !op->result(i).type()) { continue; } - // 4. outputs must be a dense tensor type + // 5. outputs must be a dense tensor type if (!op->result(i).type().isa()) { return false; } } - // 5. maybe affect performence + // 6. 
maybe affect performence if (op->isa()) { auto next_ops = pir::GetUseOpsForOutput(op, 0); for (auto [next_op, _] : next_ops) { @@ -204,7 +211,7 @@ class ConstantFoldingPattern : public pir::RewritePattern { } auto constant_op = rewriter.Build( - rewriter.tensor_name_attr(output_var_name), op->result(i).type()); + output_var_name, op->result(i).type()); constant_op->set_attribute( kAttrIsPersisable, rewriter.array_attr({rewriter.bool_attr(true)})); @@ -266,6 +273,29 @@ class ConstantFoldingPattern : public pir::RewritePattern { return output_var_names; } + template + Op BuildParameterOrConstantTensorOP( + uint32_t index, + pir::Operation* op, + pir::Builder& builder, // NOLINT + pir::PatternRewriter& rewriter) const { // NOLINT + const auto& var_name = + pir::GetParameterNameFromValue(op->operand_source(index)); + auto* var = scope_->FindVar(var_name); + PADDLE_ENFORCE_NOT_NULL(var, + phi::errors::InvalidArgument( + "Persisable var [%s] not in scope.", var_name)); + auto from_op = + builder.Build(var_name, op->operand_source(index).type()); + if (op->operand_source(index).use_count() <= 1) { + deleted_vars_->push_back(var_name); + } else { + from_op->set_attribute(kAttrIsPersisable, + rewriter.array_attr({rewriter.bool_attr(true)})); + } + return from_op; + } + std::vector BuildProgramFromOperation( pir::Operation* op, pir::Program* new_program, @@ -281,94 +311,39 @@ class ConstantFoldingPattern : public pir::RewritePattern { if (prev_op->isa()) { // prepare combine op inputs std::vector combine_op_inputs; - for (uint32_t i = 0; i < prev_op->num_operands(); i++) { - auto* combine_prev_op = pir::GetDefiningOpForInput(prev_op, i); - if (combine_prev_op->isa()) { - const auto& param_name = - pir::GetParameterNameFromValue(prev_op->operand_source(i)); - auto* param_var = scope_->FindVar(param_name); - PADDLE_ENFORCE_NOT_NULL( - param_var, - phi::errors::InvalidArgument( - "Parameter var [%s] not in scope.", param_name)); - - auto parameter_op = builder.Build( - param_name, prev_op->operand_source(i).type()); - if (prev_op->operand_source(i).use_count() <= 1) { - deleted_vars_->push_back(param_name); - } else { - parameter_op->set_attribute( - kAttrIsPersisable, - rewriter.array_attr({rewriter.bool_attr(true)})); - } + for (uint32_t j = 0; j < prev_op->num_operands(); j++) { + auto* prev_prev_op = pir::GetDefiningOpForInput(prev_op, j); + if (prev_prev_op->isa()) { + auto parameter_op = + BuildParameterOrConstantTensorOP( + j, prev_op, builder, rewriter); combine_op_inputs.push_back(parameter_op->result(0)); - } else if (combine_prev_op->isa()) { - const auto& tensor_name = - pir::GetParameterNameFromValue(prev_op->operand_source(i)); - auto* tensor_var = scope_->FindVar(tensor_name); - PADDLE_ENFORCE_NOT_NULL( - tensor_var, - phi::errors::InvalidArgument("Tensor var [%s] not in scope.", - tensor_name)); - - auto constant_op = builder.Build( - rewriter.tensor_name_attr(tensor_name), - prev_op->operand_source(i).type()); - if (prev_op->operand_source(i).use_count() <= 1) { - deleted_vars_->push_back(tensor_name); - } else { - constant_op->set_attribute( - kAttrIsPersisable, - rewriter.array_attr({rewriter.bool_attr(true)})); - } + } else if (prev_prev_op->isa()) { + auto constant_op = + BuildParameterOrConstantTensorOP( + j, prev_op, builder, rewriter); combine_op_inputs.push_back(constant_op->result(0)); } else { - PADDLE_THROW(phi::errors::Fatal("Not Support!")); + PADDLE_THROW(phi::errors::Fatal( + "Not support %s before builtin.combine op!", + prev_prev_op->name())); } } auto combine_op = 
builder.Build(combine_op_inputs); op_inputs.push_back(combine_op->result(0)); } else if (prev_op->isa()) { - const auto& param_name = - pir::GetParameterNameFromValue(op->operand_source(i)); - auto* param_var = scope_->FindVar(param_name); - PADDLE_ENFORCE_NOT_NULL( - param_var, - phi::errors::InvalidArgument("Parameter var [%s] not in scope.", - param_name)); - - auto parameter_op = builder.Build( - param_name, op->operand_source(i).type()); - if (op->operand_source(i).use_count() <= 1) { - deleted_vars_->push_back(param_name); - } else { - parameter_op->set_attribute( - kAttrIsPersisable, - rewriter.array_attr({rewriter.bool_attr(true)})); - } + auto parameter_op = + BuildParameterOrConstantTensorOP( + i, op, builder, rewriter); op_inputs.push_back(parameter_op->result(0)); } else if (prev_op->isa()) { - const auto& tensor_name = - pir::GetParameterNameFromValue(op->operand_source(i)); - auto* tensor_var = scope_->FindVar(tensor_name); - PADDLE_ENFORCE_NOT_NULL( - tensor_var, - phi::errors::InvalidArgument("Tensor var [%s] not in scope.", - tensor_name)); - - auto constant_op = builder.Build( - rewriter.tensor_name_attr(tensor_name), - op->operand_source(i).type()); - if (op->operand_source(i).use_count() <= 1) { - deleted_vars_->push_back(tensor_name); - } else { - constant_op->set_attribute( - kAttrIsPersisable, - rewriter.array_attr({rewriter.bool_attr(true)})); - } + auto constant_op = + BuildParameterOrConstantTensorOP( + i, op, builder, rewriter); op_inputs.push_back(constant_op->result(0)); } else { - PADDLE_THROW(phi::errors::Fatal("Not Support!")); + PADDLE_THROW(phi::errors::Fatal("Not support %s before matched op!", + prev_op->name())); } } else { op_inputs.push_back( @@ -462,7 +437,7 @@ class ConstantFoldingPatternForTrain : public ConstantFoldingPattern { output_var_name)); auto constant_op = rewriter.Build( - rewriter.tensor_name_attr(output_var_name), op->result(i).type()); + output_var_name, op->result(i).type()); constant_op->set_attribute( kAttrIsPersisable, rewriter.array_attr({rewriter.bool_attr(true)})); diff --git a/paddle/pir/core/builtin_op.cc b/paddle/pir/core/builtin_op.cc index c183993d27709b..193d789b53b658 100644 --- a/paddle/pir/core/builtin_op.cc +++ b/paddle/pir/core/builtin_op.cc @@ -111,8 +111,7 @@ void ParameterOp::Build(Builder &builder, OperationArgument &argument, const std::string &name, Type type) { - argument.attributes[attributes_name[0]] = - pir::StrAttribute::get(builder.ir_context(), name); + argument.attributes[attributes_name[0]] = builder.str_attr(name); argument.output_types.emplace_back(type); PassStopGradients(argument); } @@ -151,9 +150,9 @@ void SetParameterOp::Build(Builder &builder, // NOLINT Value parameter, const std::string &name) { argument.AddInput(parameter); - argument.AddAttribute(attributes_name[0], - pir::StrAttribute::get(builder.ir_context(), name)); + argument.AddAttribute(attributes_name[0], builder.str_attr(name)); } + void SetParameterOp::VerifySig() const { VLOG(10) << "Verifying inputs, outputs and attributes for: SetParameterOp."; // Verify inputs: @@ -177,9 +176,9 @@ void ShadowOutputOp::Build(Builder &builder, // NOLINT Value parameter, const std::string &name) { argument.AddInput(parameter); - argument.AddAttribute(attributes_name[0], - pir::StrAttribute::get(builder.ir_context(), name)); + argument.AddAttribute(attributes_name[0], builder.str_attr(name)); } + void ShadowOutputOp::VerifySig() const { VLOG(10) << "Verifying inputs, outputs and attributes for: ShadowOutputOp."; // Verify inputs: @@ -203,8 +202,7 
@@ void CombineOp::Build(Builder &builder, for (size_t idx = 0; idx < inputs.size(); ++idx) { inputs_type[idx] = inputs[idx].type(); } - argument.output_types.emplace_back( - pir::VectorType::get(builder.ir_context(), inputs_type)); + argument.output_types.emplace_back(builder.vec_type(inputs_type)); PassStopGradientsDefaultly(argument); } @@ -249,8 +247,7 @@ void SliceOp::Build(Builder &builder, .data()[static_cast(index)]); PassStopGradients(argument, index); - argument.AddAttribute( - "index", pir::Int32Attribute::get(pir::IrContext::Instance(), index)); + argument.AddAttribute("index", builder.int32_attr(index)); } void SliceOp::PassStopGradients(OperationArgument &argument, int index) { @@ -492,6 +489,14 @@ bool ConstantTensorOp::classof(const Operation *op) { op->attribute("value").isa(); } +void ConstantTensorOp::Build(Builder &builder, + OperationArgument &argument, + const std::string &name, + Type output_type) { + ConstantOp::Build( + builder, argument, builder.tensor_name_attr(name), output_type); +} + std::string ConstantTensorOp::tensor_name() { return value().dyn_cast().data(); } diff --git a/paddle/pir/core/builtin_op.h b/paddle/pir/core/builtin_op.h index d7c1d26c13e6e0..15f0dd62c50d53 100644 --- a/paddle/pir/core/builtin_op.h +++ b/paddle/pir/core/builtin_op.h @@ -223,6 +223,11 @@ class IR_API ConstantTensorOp : public ConstantOp { static ConstantTensorOp dyn_cast(Operation *op); static bool classof(const Operation *op); + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + const std::string &name, + Type output_type); + void VerifySig() const; std::string tensor_name(); From 2b6008c14a7666c5844093ba7df70dd45ece1901 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Wed, 17 Jan 2024 10:46:39 +0000 Subject: [PATCH 4/8] fix --- test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc index 59f10f241f2cd4..7a59d808451e3e 100644 --- a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc +++ b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc @@ -445,10 +445,8 @@ void BuildConstantFoldingProgram(pir::Program *program, paddle::platform::DeviceContextPool::Instance().Get( paddle::platform::CPUPlace()); - auto op1 = builder.Build(builder.tensor_name_attr("a"), - dense_tensor_dtype); - auto op2 = builder.Build(builder.tensor_name_attr("b"), - dense_tensor_dtype); + auto op1 = builder.Build("a", dense_tensor_dtype); + auto op2 = builder.Build("b", dense_tensor_dtype); auto op3 = builder.Build(op1->result(0), op2->result(0)); From 0f7ec3ca6485c6eb75eae5c2acc9cfec20556902 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Wed, 17 Jan 2024 11:40:35 +0000 Subject: [PATCH 5/8] fix ut --- test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc index 7a59d808451e3e..1c676f3a5ee2aa 100644 --- a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc +++ b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc @@ -583,7 +583,7 @@ TEST(constant_folding, ConstantFolding_Combine) { pm.EnableIRPrinting(); CHECK_EQ(pm.Run(&program), true); - EXPECT_EQ(program.block()->size(), 12u); + EXPECT_EQ(program.block()->size(), 2u); } void BuildMultiOutputProgram(pir::Program *program, pir::IrContext *ctx) { From 
aad8b4066192c37789b37687c2763cac4928d416 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Wed, 17 Jan 2024 13:38:12 +0000 Subject: [PATCH 6/8] support drr result pattern many simple attr --- paddle/fluid/pir/drr/README.md | 6 +- paddle/fluid/pir/drr/README_cn.md | 6 +- .../pir/drr/include/drr_pattern_context.h | 43 ++++++- paddle/fluid/pir/drr/ir_operation_factory.cc | 1 + .../transforms/fusion/attention_fuse_pass.cc | 88 +++++--------- .../transforms/fusion/conv2d_add_fuse_pass.cc | 43 +++---- .../fc_elementwise_layernorm_fuse_pass.cc | 10 +- .../pir/transforms/fusion/fc_fuse_pass.cc | 28 ++--- .../fused_dot_product_attention_pass.cc | 88 ++++---------- .../fusion/fused_gemm_epilogue_pass.cc | 45 ++----- .../fused_linear_param_grad_add_pass.cc | 111 ++++++------------ .../fusion/fused_weight_only_linear_pass.cc | 31 +---- .../fusion/matmul_scale_fuse_pass.cc | 11 +- .../pir/transforms/identity_op_clean_pass.cc | 52 ++++---- test/cpp/pir/pattern_rewrite/drr_test.cc | 4 +- 15 files changed, 202 insertions(+), 365 deletions(-) diff --git a/paddle/fluid/pir/drr/README.md b/paddle/fluid/pir/drr/README.md index 9b9790538d48ac..070ca1d907b034 100644 --- a/paddle/fluid/pir/drr/README.md +++ b/paddle/fluid/pir/drr/README.md @@ -182,14 +182,10 @@ class FusedLinearPattern : public paddle::drr::DrrPatternBase { // Define ResultPattern paddle::drr::ResultPattern res = pat.ResultPattern(); // Define Constrain - const auto &act_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { - return "none"; - }); const auto &fused_gemm_epilogue = res.Op(paddle::dialect::FusedGemmEpilogueOp::name(), {{{"trans_x", pat.Attr("trans_x")}, {"trans_y", pat.Attr("trans_y")}, - {"activation", act_attr}}}); + {"activation", res.StrAttr("none")}}}); fused_gemm_epilogue( {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, {&res.Tensor("out")}); diff --git a/paddle/fluid/pir/drr/README_cn.md b/paddle/fluid/pir/drr/README_cn.md index 4051a5e547f315..fd8ae5904a2aeb 100644 --- a/paddle/fluid/pir/drr/README_cn.md +++ b/paddle/fluid/pir/drr/README_cn.md @@ -185,14 +185,10 @@ class FusedLinearPattern : public paddle::drr::DrrPatternBase { // 定义 Result Pattern paddle::drr::ResultPattern res = pat.ResultPattern(); // 定义 Constrain - const auto &act_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { - return "none"; - }); const auto &fused_gemm_epilogue = res.Op(paddle::dialect::FusedGemmEpilogueOp::name(), {{{"trans_x", pat.Attr("trans_x")}, {"trans_y", pat.Attr("trans_y")}, - {"activation", act_attr}}}); + {"activation", res.StrAttr("none")}}}); fused_gemm_epilogue( {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, {&res.Tensor("out")}); diff --git a/paddle/fluid/pir/drr/include/drr_pattern_context.h b/paddle/fluid/pir/drr/include/drr_pattern_context.h index 0539708300ac7c..d656226a627f27 100644 --- a/paddle/fluid/pir/drr/include/drr_pattern_context.h +++ b/paddle/fluid/pir/drr/include/drr_pattern_context.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include @@ -280,10 +281,46 @@ class ResultPattern { return ctx_->ResultTensorPattern(Tensor::NONE_TENSOR_NAME); } - Attribute Attr(const std::string& attr_name) const { - return NormalAttribute(attr_name); + Attribute StrAttr(const std::string& value) const { + return ComputeAttr( + [&](const MatchContext& match_ctx) -> std::string { return value; }); + } + + Attribute BoolAttr(bool value) const { + return ComputeAttr( + [&](const MatchContext& match_ctx) -> bool { return value; }); + } + + 
Attribute Int32Attr(int32_t value) const { + return ComputeAttr( + [&](const MatchContext& match_ctx) -> int32_t { return value; }); + } + + Attribute Int64Attr(int64_t value) const { + return ComputeAttr( + [&](const MatchContext& match_ctx) -> int64_t { return value; }); } - Attribute Attr(const AttrComputeFunc& attr_compute_func) const { + + Attribute Float32Attr(float value) const { + return ComputeAttr( + [&](const MatchContext& match_ctx) -> float { return value; }); + } + + Attribute VectorInt64Attr(const std::vector& value) const { + return ComputeAttr( + [&](const MatchContext& match_ctx) -> std::vector { + return value; + }); + } + + Attribute VectorInt32Attr(const std::vector& value) const { + return ComputeAttr( + [&](const MatchContext& match_ctx) -> std::vector { + return value; + }); + } + + Attribute ComputeAttr(const AttrComputeFunc& attr_compute_func) const { return ComputeAttribute(attr_compute_func); } diff --git a/paddle/fluid/pir/drr/ir_operation_factory.cc b/paddle/fluid/pir/drr/ir_operation_factory.cc index c552550b98c2a7..50623de695380e 100644 --- a/paddle/fluid/pir/drr/ir_operation_factory.cc +++ b/paddle/fluid/pir/drr/ir_operation_factory.cc @@ -19,6 +19,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/attr_type_uilts.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_context.h" #include "paddle/phi/core/enforce.h" #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/core/operation.h" diff --git a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc index 4d670d92f4aa8d..86c61bf3f14e1a 100644 --- a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc @@ -170,7 +170,7 @@ class MultiHeadMatmulFuseNoBiasQKPattern : public paddle::drr::DrrPatternBase { paddle::drr::ResultPattern res = src.ResultPattern(); // W reshape. - const auto &reshape_w_shape_attr = res.Attr( + const auto &reshape_w_shape_attr = res.ComputeAttr( [](const paddle::drr::MatchContext &match_ctx) -> std::vector { auto matmul_1_in_2 = pir::GetShapeFromValue(match_ctx.Tensor("matmul_1_in_2")); @@ -195,14 +195,12 @@ class MultiHeadMatmulFuseNoBiasQKPattern : public paddle::drr::DrrPatternBase { &res.Tensor("reshape_6_out"), &res.Tensor("reshape_7_out")}, {&res.Tensor("combine_1_out")}); - const auto &concat_1_axis_attr = res.Attr( - [](const paddle::drr::MatchContext &match_ctx) -> int { return 1; }); - const auto &concat_1 = - res.Op("pd_op.concat", {{"axis", concat_1_axis_attr}}); + + const auto &concat_1 = res.Op("pd_op.concat", {{"axis", res.Int32Attr(1)}}); res.Tensor("concat_1_out") = concat_1(res.Tensor("combine_1_out")); // Bias reshape. 
- const auto &reshape_b_shape_attr = res.Attr( + const auto &reshape_b_shape_attr = res.ComputeAttr( [](const paddle::drr::MatchContext &match_ctx) -> std::vector { auto add_1_in_2 = pir::GetShapeFromValue(match_ctx.Tensor("add_1_in_2")); @@ -227,38 +225,26 @@ class MultiHeadMatmulFuseNoBiasQKPattern : public paddle::drr::DrrPatternBase { &res.Tensor("reshape_9_out"), &res.Tensor("reshape_10_out")}, {&res.Tensor("combine_2_out")}); - const auto &concat_2_axis_attr = res.Attr( - [](const paddle::drr::MatchContext &match_ctx) -> int { return 0; }); - const auto &concat_2 = - res.Op("pd_op.concat", {{"axis", concat_2_axis_attr}}); + + const auto &concat_2 = res.Op("pd_op.concat", {{"axis", res.Int32Attr(0)}}); res.Tensor("concat_2_out") = concat_2(res.Tensor("combine_2_out")); const auto &head_number = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> int { + res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> int { const auto &full_int_array_1_value = match_ctx.Attr>("full_int_array_1_value"); return full_int_array_1_value.at(2); }); - const auto &alpha = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { + const auto &alpha = res.ComputeAttr( + [](const paddle::drr::MatchContext &match_ctx) -> float { return match_ctx.Attr("full_1_value"); }); - const auto &multihead_matmul = - res.Op("pd_op.multihead_matmul", - {{"transpose_q", - res.Attr([](const paddle::drr::MatchContext &match_ctx) { - return false; - })}, - {"transpose_k", - res.Attr([](const paddle::drr::MatchContext &match_ctx) { - return true; - })}, - {"transpose_v", - res.Attr([](const paddle::drr::MatchContext &match_ctx) { - return false; - })}, - {"head_number", head_number}, - {"alpha", alpha}}); + const auto &multihead_matmul = res.Op("pd_op.multihead_matmul", + {{"transpose_q", res.BoolAttr(false)}, + {"transpose_k", res.BoolAttr(true)}, + {"transpose_v", res.BoolAttr(false)}, + {"head_number", head_number}, + {"alpha", alpha}}); multihead_matmul({&res.Tensor("matmul_1_in_1"), &res.Tensor("concat_1_out"), &res.Tensor("concat_2_out"), @@ -423,7 +409,7 @@ class MultiHeadMatmulFuseWithBiasQKPattern paddle::drr::ResultPattern res = src.ResultPattern(); // W reshape. - const auto &reshape_w_shape_attr = res.Attr( + const auto &reshape_w_shape_attr = res.ComputeAttr( [](const paddle::drr::MatchContext &match_ctx) -> std::vector { auto matmul_1_in_2 = pir::GetShapeFromValue(match_ctx.Tensor("matmul_1_in_2")); @@ -448,14 +434,12 @@ class MultiHeadMatmulFuseWithBiasQKPattern &res.Tensor("reshape_6_out"), &res.Tensor("reshape_7_out")}, {&res.Tensor("combine_1_out")}); - const auto &concat_1_axis_attr = res.Attr( - [](const paddle::drr::MatchContext &match_ctx) -> int { return 1; }); - const auto &concat_1 = - res.Op("pd_op.concat", {{"axis", concat_1_axis_attr}}); + + const auto &concat_1 = res.Op("pd_op.concat", {{"axis", res.Int32Attr(1)}}); res.Tensor("concat_1_out") = concat_1(res.Tensor("combine_1_out")); // Bias reshape. 
- const auto &reshape_b_shape_attr = res.Attr( + const auto &reshape_b_shape_attr = res.ComputeAttr( [](const paddle::drr::MatchContext &match_ctx) -> std::vector { auto add_1_in_2 = pir::GetShapeFromValue(match_ctx.Tensor("add_1_in_2")); @@ -480,38 +464,26 @@ class MultiHeadMatmulFuseWithBiasQKPattern &res.Tensor("reshape_9_out"), &res.Tensor("reshape_10_out")}, {&res.Tensor("combine_2_out")}); - const auto &concat_2_axis_attr = res.Attr( - [](const paddle::drr::MatchContext &match_ctx) -> int { return 0; }); - const auto &concat_2 = - res.Op("pd_op.concat", {{"axis", concat_2_axis_attr}}); + + const auto &concat_2 = res.Op("pd_op.concat", {{"axis", res.Int32Attr(0)}}); res.Tensor("concat_2_out") = concat_2(res.Tensor("combine_2_out")); const auto &head_number = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> int { + res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> int { const auto &full_int_array_1_value = match_ctx.Attr>("full_int_array_1_value"); return full_int_array_1_value.at(2); }); - const auto &alpha = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { + const auto &alpha = res.ComputeAttr( + [](const paddle::drr::MatchContext &match_ctx) -> float { return match_ctx.Attr("full_1_value"); }); - const auto &multihead_matmul = - res.Op("pd_op.multihead_matmul", - {{"transpose_q", - res.Attr([](const paddle::drr::MatchContext &match_ctx) { - return false; - })}, - {"transpose_k", - res.Attr([](const paddle::drr::MatchContext &match_ctx) { - return true; - })}, - {"transpose_v", - res.Attr([](const paddle::drr::MatchContext &match_ctx) { - return false; - })}, - {"head_number", head_number}, - {"alpha", alpha}}); + const auto &multihead_matmul = res.Op("pd_op.multihead_matmul", + {{"transpose_q", res.BoolAttr(false)}, + {"transpose_k", res.BoolAttr(true)}, + {"transpose_v", res.BoolAttr(false)}, + {"head_number", head_number}, + {"alpha", alpha}}); multihead_matmul({&res.Tensor("matmul_1_in_1"), &res.Tensor("concat_1_out"), &res.Tensor("concat_2_out"), diff --git a/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc index fbfa9c6891a55a..7605971ba59ab2 100644 --- a/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc @@ -41,34 +41,21 @@ class Conv2dAddFusePattern : public paddle::drr::DrrPatternBase { paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &fused_conv2d_add_act = res.Op( - paddle::dialect::FusedConv2dAddActOp::name(), - {{ - {"strides", pat.Attr("strides")}, - {"paddings", pat.Attr("paddings")}, - {"padding_algorithm", pat.Attr("padding_algorithm")}, - {"dilations", pat.Attr("dilations")}, - {"groups", pat.Attr("groups")}, - {"data_format", pat.Attr("data_format")}, - {"activation", - res.Attr([](const paddle::drr::MatchContext &match_ctx) - -> std::string { return "identity"; })}, - {"split_channels", - res.Attr([](const paddle::drr::MatchContext &match_ctx) - -> std::vector { return {}; })}, - {"exhaustive_search", - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return false; - })}, - {"workspace_size_MB", - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> int { - return 32; - })}, - {"fuse_alpha", - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { - return 0.0f; - })}, - }}); + const auto &fused_conv2d_add_act = + res.Op(paddle::dialect::FusedConv2dAddActOp::name(), + {{ + {"strides", pat.Attr("strides")}, + {"paddings", 
pat.Attr("paddings")}, + {"padding_algorithm", pat.Attr("padding_algorithm")}, + {"dilations", pat.Attr("dilations")}, + {"groups", pat.Attr("groups")}, + {"data_format", pat.Attr("data_format")}, + {"activation", res.StrAttr("identity")}, + {"split_channels", res.VectorInt32Attr({})}, + {"exhaustive_search", res.BoolAttr(false)}, + {"workspace_size_MB", res.Int32Attr(32)}, + {"fuse_alpha", res.Float32Attr(0.0f)}, + }}); fused_conv2d_add_act({&res.Tensor("input"), &res.Tensor("filter"), diff --git a/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc index e57e9b1bef7278..09e600aa1c4586 100644 --- a/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc @@ -65,14 +65,8 @@ class FcElementwiseLayerNormFusePattern : public paddle::drr::DrrPatternBase { paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &x_num_col_dims_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { - return 1; - }); - const auto &false_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return false; - }); + const auto &x_num_col_dims_attr = res.Int32Attr(1); + const auto &false_attr = res.BoolAttr(false); const auto &fused_fc_elementwise_op = res.Op(paddle::dialect::FusedFcElementwiseLayernormOp::name(), diff --git a/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc index 18200f2e6b4e2a..843402875a098c 100644 --- a/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc @@ -58,24 +58,18 @@ class MatmulAddPattern : public paddle::drr::DrrPatternBase { paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &in_num_col_dims_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { + res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> int { auto x_dims = pir::GetShapeFromValue(match_ctx.Tensor("x")); return x_dims.size() - 1; }); - const auto &false_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return false; - }); - - const auto &fc = - res.Op(paddle::dialect::FcOp::name(), - {{ - {"in_num_col_dims", in_num_col_dims_attr}, - {"activation_type", - res.Attr([](const paddle::drr::MatchContext &match_ctx) - -> std::string { return ""; })}, - {"padding_weights", false_attr}, - }}); + const auto &false_attr = res.BoolAttr(false); + + const auto &fc = res.Op(paddle::dialect::FcOp::name(), + {{ + {"in_num_col_dims", in_num_col_dims_attr}, + {"activation_type", res.StrAttr("")}, + {"padding_weights", false_attr}, + }}); fc({&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("y")}, {&res.Tensor("add_out")}); } @@ -110,9 +104,7 @@ class FcWithReluPattern : public paddle::drr::DrrPatternBase { res.Op(paddle::dialect::FcOp::name(), {{ {"in_num_col_dims", pat.Attr("in_num_col_dims")}, - {"activation_type", - res.Attr([](const paddle::drr::MatchContext &match_ctx) - -> std::string { return "relu"; })}, + {"activation_type", res.StrAttr("relu")}, {"padding_weights", pat.Attr("padding_weights")}, }}); fc_with_relu({&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("y")}, diff --git a/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc index 0b5737ecf69d6d..b6379cb60473ea 100644 --- 
a/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc @@ -105,29 +105,17 @@ class FusedDotProductAttentionPattern : public paddle::drr::DrrPatternBase { // Result pattern paddle::drr::ResultPattern res = src.ResultPattern(); - const auto &scaling_factor = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { + const auto &scaling_factor = res.ComputeAttr( + [](const paddle::drr::MatchContext &match_ctx) -> float { return match_ctx.Attr("q_scale_value"); }); - const auto &dropout_prob = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { - return static_cast(0.0); - }); - const auto &is_training = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return true; - }); - const auto &is_causal_masking = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return false; - }); const auto &dot_product_attention = res.Op(paddle::dialect::FusedDotProductAttentionOp::name(), {{{"scaling_factor", scaling_factor}, - {"dropout_probability", dropout_prob}, - {"is_training", is_training}, - {"is_causal_masking", is_causal_masking}}}); + {"dropout_probability", res.Float32Attr(0.0)}, + {"is_training", res.BoolAttr(true)}, + {"is_causal_masking", res.BoolAttr(false)}}}); dot_product_attention({&res.Tensor("q"), &res.Tensor("k"), @@ -270,29 +258,17 @@ class FusedDotProductAttentionGradPattern : public paddle::drr::DrrPatternBase { // Result pattern paddle::drr::ResultPattern res = src.ResultPattern(); - const auto &scaling_factor = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { + const auto &scaling_factor = res.ComputeAttr( + [](const paddle::drr::MatchContext &match_ctx) -> float { return match_ctx.Attr("q_scale_value"); }); - const auto &dropout_prob = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { - return static_cast(0.0); - }); - const auto &is_training = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return true; - }); - const auto &is_causal_masking = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return false; - }); const auto &dot_product_attention = res.Op(paddle::dialect::FusedDotProductAttentionOp::name(), {{{"scaling_factor", scaling_factor}, - {"dropout_probability", dropout_prob}, - {"is_training", is_training}, - {"is_causal_masking", is_causal_masking}}}); + {"dropout_probability", res.Float32Attr(0.0)}, + {"is_training", res.BoolAttr(true)}, + {"is_causal_masking", res.BoolAttr(false)}}}); dot_product_attention({&res.Tensor("q"), &res.Tensor("k"), @@ -304,8 +280,8 @@ class FusedDotProductAttentionGradPattern : public paddle::drr::DrrPatternBase { const auto &dot_product_attention_grad = res.Op(paddle::dialect::FusedDotProductAttentionGradOp::name(), {{{"scaling_factor", scaling_factor}, - {"dropout_probability", dropout_prob}, - {"is_causal_masking", is_causal_masking}}}); + {"dropout_probability", res.Float32Attr(0.0)}, + {"is_causal_masking", res.BoolAttr(false)}}}); dot_product_attention_grad( {&res.Tensor("q"), &res.Tensor("k"), @@ -415,29 +391,17 @@ class FusedDotProductAttentionWithDropoutPattern // Result pattern paddle::drr::ResultPattern res = src.ResultPattern(); - const auto &scaling_factor = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { + const auto &scaling_factor = res.ComputeAttr( + [](const paddle::drr::MatchContext &match_ctx) -> float { return match_ctx.Attr("q_scale_value"); }); - const auto 
&dropout_prob = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { - return static_cast(0.0); - }); - const auto &is_training = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return true; - }); - const auto &is_causal_masking = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return false; - }); const auto &dot_product_attention = res.Op(paddle::dialect::FusedDotProductAttentionOp::name(), {{{"scaling_factor", scaling_factor}, - {"dropout_probability", src.Attr("dropout_prob")}, - {"is_training", is_training}, - {"is_causal_masking", is_causal_masking}}}); + {"dropout_probability", res.Float32Attr(0.0)}, + {"is_training", res.BoolAttr(true)}, + {"is_causal_masking", res.BoolAttr(false)}}}); dot_product_attention({&res.Tensor("q"), &res.Tensor("k"), @@ -595,25 +559,17 @@ class FusedDotProductAttentionGradWithDropoutPattern // Result pattern paddle::drr::ResultPattern res = src.ResultPattern(); - const auto &scaling_factor = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { + const auto &scaling_factor = res.ComputeAttr( + [](const paddle::drr::MatchContext &match_ctx) -> float { return match_ctx.Attr("q_scale_value"); }); - const auto &is_training = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return true; - }); - const auto &is_causal_masking = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return false; - }); const auto &dot_product_attention = res.Op(paddle::dialect::FusedDotProductAttentionOp::name(), {{{"scaling_factor", scaling_factor}, {"dropout_probability", src.Attr("dropout_prob")}, - {"is_training", is_training}, - {"is_causal_masking", is_causal_masking}}}); + {"is_training", res.BoolAttr(true)}, + {"is_causal_masking", res.BoolAttr(false)}}}); dot_product_attention({&res.Tensor("q"), &res.Tensor("k"), @@ -626,7 +582,7 @@ class FusedDotProductAttentionGradWithDropoutPattern res.Op(paddle::dialect::FusedDotProductAttentionGradOp::name(), {{{"scaling_factor", scaling_factor}, {"dropout_probability", src.Attr("dropout_prob")}, - {"is_causal_masking", is_causal_masking}}}); + {"is_causal_masking", res.BoolAttr(false)}}}); dot_product_attention_grad( {&res.Tensor("q"), &res.Tensor("k"), diff --git a/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc index 6a39c015893e32..3aa0724e3d7f08 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc @@ -44,15 +44,11 @@ class FusedLinearPattern : public paddle::drr::DrrPatternBase { }); paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &act_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { - return "none"; - }); const auto &fused_gemm_epilogue = res.Op(paddle::dialect::FusedGemmEpilogueOp::name(), {{{"trans_x", pat.Attr("trans_x")}, {"trans_y", pat.Attr("trans_y")}, - {"activation", act_attr}}}); + {"activation", res.StrAttr("none")}}}); fused_gemm_epilogue( {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, {&res.Tensor("out")}); @@ -90,20 +86,17 @@ class FusedLinearGradPattern : public paddle::drr::DrrPatternBase { }); paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &act_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { - return "none"; - }); + const auto &fused_gemm_epilogue = res.Op(paddle::dialect::FusedGemmEpilogueOp::name(), {{{"trans_x", 
pat.Attr("trans_x")}, {"trans_y", pat.Attr("trans_y")}, - {"activation", act_attr}}}); + {"activation", res.StrAttr("none")}}}); const auto &fused_gemm_epilogue_grad = res.Op(paddle::dialect::FusedGemmEpilogueGradOp::name(), {{{"trans_x", pat.Attr("trans_x")}, {"trans_y", pat.Attr("trans_y")}, - {"activation_grad", act_attr}}}); + {"activation_grad", res.StrAttr("none")}}}); fused_gemm_epilogue( {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, {&res.Tensor("out")}); @@ -142,15 +135,11 @@ class FusedLinearGeluPattern : public paddle::drr::DrrPatternBase { // Result pattern paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &act_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { - return "gelu"; - }); const auto &fused_gemm_epilogue_gelu = res.Op(paddle::dialect::FusedGemmEpilogueOp::name(), {{{"trans_x", pat.Attr("trans_x")}, {"trans_y", pat.Attr("trans_y")}, - {"activation", act_attr}}}); + {"activation", res.StrAttr("gelu")}}}); fused_gemm_epilogue_gelu( {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, {&res.Tensor("out"), &res.Tensor("reserve_space")}); @@ -182,15 +171,11 @@ class FusedLinearReluPattern : public paddle::drr::DrrPatternBase { // Result pattern paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &act_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { - return "relu"; - }); const auto &fused_gemm_epilogue_relu = res.Op(paddle::dialect::FusedGemmEpilogueOp::name(), {{{"trans_x", pat.Attr("trans_x")}, {"trans_y", pat.Attr("trans_y")}, - {"activation", act_attr}}}); + {"activation", res.StrAttr("relu")}}}); fused_gemm_epilogue_relu( {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, {&res.Tensor("out"), &res.Tensor("reserve_space")}); @@ -235,24 +220,16 @@ class FusedLinearGeluGradPattern : public paddle::drr::DrrPatternBase { }); paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &act_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { - return "gelu"; - }); const auto &fused_gemm_epilogue_new = res.Op(paddle::dialect::FusedGemmEpilogueOp::name(), {{{"trans_x", pat.Attr("trans_x1")}, {"trans_y", pat.Attr("trans_y1")}, - {"activation", act_attr}}}); - const auto &act_grad_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { - return "gelu_grad"; - }); + {"activation", res.StrAttr("gelu")}}}); const auto &fused_gemm_epilogue_grad_new = res.Op(paddle::dialect::FusedGemmEpilogueGradOp::name(), {{{"trans_x", pat.Attr("trans_x2")}, {"trans_y", pat.Attr("trans_y2")}, - {"activation_grad", act_grad_attr}}}); + {"activation_grad", res.StrAttr("gelu_grad")}}}); fused_gemm_epilogue_new( {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, {&res.Tensor("out"), &res.Tensor("reserve_space2")}); @@ -315,15 +292,11 @@ class FusedLinearReluGradPattern : public paddle::drr::DrrPatternBase { }); paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &act_grad_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { - return "relu_grad"; - }); const auto &res_fused_gemm_epilogue_grad1 = res.Op(paddle::dialect::FusedGemmEpilogueGradOp::name(), {{{"trans_x", pat.Attr("trans_x3")}, {"trans_y", pat.Attr("trans_y3")}, - {"activation_grad", act_grad_attr}}}); + {"activation_grad", res.StrAttr("relu_grad")}}}); res_fused_gemm_epilogue_grad1({&res.Tensor("x1"), &res.Tensor("w1"), diff --git a/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc 
b/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc index 1453426cc8df6f..f3c90082225b8a 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc @@ -66,26 +66,18 @@ class FusedMatmulAddGradAddPattern : public paddle::drr::DrrPatternBase { paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> bool { return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad"))); }); - const auto &true_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return true; - }); - const auto &false_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return false; - }); - - const auto &matmul = - res.Op(paddle::dialect::MatmulOp::name(), - {{"transpose_x", false_attr}, {"transpose_y", true_attr}}); - const auto &fused_linear_param_grad_add = res.Op( - paddle::dialect::FusedLinearParamGradAddOp::name(), - {{{"multi_precision", muti_precision_attr}, {"has_bias", true_attr}}}); + const auto &matmul = res.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", res.BoolAttr(false)}, + {"transpose_y", res.BoolAttr(true)}}); + const auto &fused_linear_param_grad_add = + res.Op(paddle::dialect::FusedLinearParamGradAddOp::name(), + {{{"multi_precision", muti_precision_attr}, + {"has_bias", res.BoolAttr(true)}}}); matmul({&res.Tensor("fwd_add_out_grad"), &res.Tensor("weight")}, {&res.Tensor("x_grad")}); @@ -128,26 +120,18 @@ class FusedMatmulGradAddPattern : public paddle::drr::DrrPatternBase { paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> bool { return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad"))); }); - const auto &true_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return true; - }); - const auto &false_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return false; - }); - - const auto &matmul = - res.Op(paddle::dialect::MatmulOp::name(), - {{"transpose_x", false_attr}, {"transpose_y", true_attr}}); - const auto &fused_linear_param_grad_add = res.Op( - paddle::dialect::FusedLinearParamGradAddOp::name(), - {{{"multi_precision", muti_precision_attr}, {"has_bias", false_attr}}}); + const auto &matmul = res.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", res.BoolAttr(false)}, + {"transpose_y", res.BoolAttr(true)}}); + const auto &fused_linear_param_grad_add = + res.Op(paddle::dialect::FusedLinearParamGradAddOp::name(), + {{{"multi_precision", muti_precision_attr}, + {"has_bias", res.BoolAttr(false)}}}); matmul({&res.Tensor("out_grad"), &res.Tensor("weight")}, {&res.Tensor("x_grad")}); @@ -186,23 +170,15 @@ class FusedMatmulAddaPattern : public paddle::drr::DrrPatternBase { paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> bool { return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == 
pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad"))); }); - const auto &true_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return true; - }); - const auto &false_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return false; - }); - - const auto &fused_linear_param_grad_add = res.Op( - paddle::dialect::FusedLinearParamGradAddOp::name(), - {{{"multi_precision", muti_precision_attr}, {"has_bias", false_attr}}}); + const auto &fused_linear_param_grad_add = + res.Op(paddle::dialect::FusedLinearParamGradAddOp::name(), + {{{"multi_precision", muti_precision_attr}, + {"has_bias", res.BoolAttr(false)}}}); fused_linear_param_grad_add( {&res.Tensor("x"), &res.Tensor("out_grad"), @@ -238,23 +214,15 @@ class FusedMatmulAddbPattern : public paddle::drr::DrrPatternBase { paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> bool { return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad"))); }); - const auto &true_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return true; - }); - const auto &false_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return false; - }); - - const auto &fused_linear_param_grad_add = res.Op( - paddle::dialect::FusedLinearParamGradAddOp::name(), - {{{"multi_precision", muti_precision_attr}, {"has_bias", false_attr}}}); + const auto &fused_linear_param_grad_add = + res.Op(paddle::dialect::FusedLinearParamGradAddOp::name(), + {{{"multi_precision", muti_precision_attr}, + {"has_bias", res.BoolAttr(false)}}}); fused_linear_param_grad_add( {&res.Tensor("x"), &res.Tensor("out_grad"), @@ -304,17 +272,15 @@ class FusedMatmulAddGradAddaPattern : public paddle::drr::DrrPatternBase { paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> bool { return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad"))); }); - const auto &true_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return true; - }); - const auto &fused_linear_param_grad_add = res.Op( - paddle::dialect::FusedLinearParamGradAddOp::name(), - {{{"multi_precision", muti_precision_attr}, {"has_bias", true_attr}}}); + + const auto &fused_linear_param_grad_add = + res.Op(paddle::dialect::FusedLinearParamGradAddOp::name(), + {{{"multi_precision", muti_precision_attr}, + {"has_bias", res.BoolAttr(true)}}}); fused_linear_param_grad_add( {&res.Tensor("x"), &res.Tensor("dadd_out"), @@ -364,17 +330,14 @@ class FusedMatmulAddGradAddbPattern : public paddle::drr::DrrPatternBase { paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> bool { return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad"))); }); - const auto &true_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return true; - }); - const auto &fused_linear_param_grad_add = res.Op( - 
paddle::dialect::FusedLinearParamGradAddOp::name(), - {{{"multi_precision", muti_precision_attr}, {"has_bias", true_attr}}}); + const auto &fused_linear_param_grad_add = + res.Op(paddle::dialect::FusedLinearParamGradAddOp::name(), + {{{"multi_precision", muti_precision_attr}, + {"has_bias", res.BoolAttr(true)}}}); fused_linear_param_grad_add( {&res.Tensor("x"), &res.Tensor("dadd_out"), diff --git a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc index df61b1eb25ba27..7a017795d29115 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc @@ -88,39 +88,20 @@ class FusedWeightOnlyLinearPattern : public paddle::drr::DrrPatternBase { // paddle::drr::ResultPattern res = src.ResultPattern(); - // quantize weight - const auto &weight_only_int8_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { - return "weight_only_int8"; - }); - - const auto &arch_attr = - res.Attr([&](const paddle::drr::MatchContext &match_ctx) -> int { - return getSMVersion(); - }); - - const auto &group_size_attr = res.Attr( - [](const paddle::drr::MatchContext &match_ctx) -> int { return -1; }); - const auto &weight_quantize = res.Op(paddle::dialect::WeightQuantizeOp::name(), - {{"algo", weight_only_int8_attr}, - {"arch", arch_attr}, - {"group_size", group_size_attr}}); + {{"algo", res.StrAttr("weight_only_int8")}, + {"arch", res.Int32Attr(getSMVersion())}, + {"group_size", res.Int32Attr(-1)}}); weight_quantize({&res.Tensor("w")}, {&res.Tensor("quanted_weight_tensor"), &res.Tensor("weight_scale_tensor")}); - const auto &weight_dtype_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { - return "int8"; - }); - const auto &weight_only_linear = res.Op(paddle::dialect::WeightOnlyLinearOp::name(), - {{"weight_dtype", weight_dtype_attr}, - {"arch", arch_attr}, - {"group_size", group_size_attr}}); + {{"weight_dtype", res.StrAttr("int8")}, + {"arch", res.Int32Attr(getSMVersion())}, + {"group_size", res.Int32Attr(-1)}}); weight_only_linear({&res.Tensor("x"), &res.Tensor("quanted_weight_tensor"), &res.Tensor("bias"), diff --git a/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc index cabd7a7274cb70..a5e36ec2293d96 100644 --- a/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc @@ -55,13 +55,10 @@ class MatmulScaleFusePattern : public paddle::drr::DrrPatternBase { {"value", pat.Attr("value")}, {"dtype", pat.Attr("dtype")}, {"place", pat.Attr("place")}}); - const auto &scale_op_res = res.Op( - paddle::dialect::ScaleOp::name(), - {{"bias", - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { - return 0.0; - })}, - {"bias_after_scale", pat.Attr("bias_after_scale")}}); + const auto &scale_op_res = + res.Op(paddle::dialect::ScaleOp::name(), + {{"bias", res.Float32Attr(0.0)}, + {"bias_after_scale", pat.Attr("bias_after_scale")}}); const auto &matmul_op_res = res.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("transpose_x")}, diff --git a/paddle/fluid/pir/transforms/identity_op_clean_pass.cc b/paddle/fluid/pir/transforms/identity_op_clean_pass.cc index 52287d0f57d5b3..26d09d5f249a6f 100644 --- a/paddle/fluid/pir/transforms/identity_op_clean_pass.cc +++ b/paddle/fluid/pir/transforms/identity_op_clean_pass.cc @@ -81,8 +81,8 
@@ class RemoveRedundentScalePattern : public paddle::drr::DrrPatternBase { paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &bais_res = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { + const auto &bais_attr = res.ComputeAttr( + [](const paddle::drr::MatchContext &match_ctx) -> float { float res_bias_1 = 0.f; float res_bias_2 = 0.f; if (match_ctx.Attr("bias_after_scale_1")) { @@ -100,8 +100,8 @@ class RemoveRedundentScalePattern : public paddle::drr::DrrPatternBase { } return res_bias_2; }); - const auto &res_scale_input = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { + const auto &res_scale_input = res.ComputeAttr( + [](const paddle::drr::MatchContext &match_ctx) -> float { return match_ctx.Attr("value_1") * match_ctx.Attr("value_2"); }); @@ -111,13 +111,9 @@ class RemoveRedundentScalePattern : public paddle::drr::DrrPatternBase { {"value", res_scale_input}, {"dtype", pat.Attr("dtype_1")}, {"place", pat.Attr("place_1")}}); - const auto &scale_op_res = res.Op( - "pd_op.scale", - {{"bias", bais_res}, - {"bias_after_scale", - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return true; - })}}); + const auto &scale_op_res = + res.Op("pd_op.scale", + {{"bias", bais_attr}, {"bias_after_scale", res.BoolAttr(true)}}); scale_op_res({&res.Tensor("x"), &full_op_res()}, {&res.Tensor("scale_2_out")}); } @@ -216,33 +212,29 @@ class ReplaceDropoutWithScalePattern : public paddle::drr::DrrPatternBase { auto res = pat.ResultPattern(); - const auto &res_scale_input = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { + const auto &res_scale_input = res.ComputeAttr( + [](const paddle::drr::MatchContext &match_ctx) -> float { return 1.f - match_ctx.Attr("p"); }); const auto &full_op_res = res.Op( paddle::dialect::FullOp::name(), {{"shape", - res.Attr([](const paddle::drr::MatchContext &match_ctx) - -> phi::IntArray { return {1}; })}, + res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) + -> phi::IntArray { return {1}; })}, {"value", res_scale_input}, {"dtype", - res.Attr([](const paddle::drr::MatchContext &match_ctx) - -> phi::DataType { return phi::DataType::FLOAT32; })}, + res.ComputeAttr( + [](const paddle::drr::MatchContext &match_ctx) -> phi::DataType { + return phi::DataType::FLOAT32; + })}, {"place", - res.Attr([](const paddle::drr::MatchContext &match_ctx) - -> phi::Place { return phi::CPUPlace{}; })}}); - const auto &scale_op_res = res.Op( - "pd_op.scale", - {{"bias", - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { - return 0; - })}, - {"bias_after_scale", - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return true; - })}}); + res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) + -> phi::Place { return phi::CPUPlace{}; })}}); + const auto &scale_op_res = + res.Op("pd_op.scale", + {{"bias", res.Float32Attr(0)}, + {"bias_after_scale", res.BoolAttr(true)}}); scale_op_res({&res.Tensor("dropout_in"), &full_op_res()}, {&res.Tensor("dropout_out")}); } @@ -262,7 +254,7 @@ class RemoveRedundentTransposePattern : public paddle::drr::DrrPatternBase { pat.Tensor("ret") = transpose2(transpose1(pat.Tensor("arg_transpose"))); paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &new_perm_attr = res.Attr( + const auto &new_perm_attr = res.ComputeAttr( [](const paddle::drr::MatchContext &match_ctx) -> std::vector { const auto &perm1 = match_ctx.Attr>("perm_1"); const auto &perm2 = match_ctx.Attr>("perm_2"); diff --git 
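In RemoveRedundentScalePattern, the merged multiplier is simply value_1 * value_2, and the merged bias comes from composing the two affine maps, which the (partially elided) bais_attr lambda computes per bias_after_scale flag. For the fully visible case where both scales apply the bias after scaling, the algebra is checked by this small standalone snippet (the values are arbitrary):

    #include <cassert>

    // scale op with bias_after_scale = true:  out = value * in + bias
    // composing two of them:
    //   z = v2 * (v1 * x + b1) + b2 = (v1 * v2) * x + (v2 * b1 + b2)
    int main() {
      float x = 3.f, v1 = 2.f, b1 = 0.5f, v2 = 4.f, b2 = 1.f;
      float z_two_ops = v2 * (v1 * x + b1) + b2;         // 27.0f
      float z_merged  = (v1 * v2) * x + (v2 * b1 + b2);  // 27.0f
      assert(z_two_ops == z_merged);
      return 0;
    }

ReplaceDropoutWithScalePattern follows the same idea: the inference-mode dropout is replaced by a single scale by 1 - p with zero bias, which is exactly what the res_scale_input and bias attributes above encode.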
a/test/cpp/pir/pattern_rewrite/drr_test.cc b/test/cpp/pir/pattern_rewrite/drr_test.cc index 6efe87d8ca70c4..735473bf2547a9 100644 --- a/test/cpp/pir/pattern_rewrite/drr_test.cc +++ b/test/cpp/pir/pattern_rewrite/drr_test.cc @@ -65,7 +65,7 @@ class FoldExpandToConstantPattern : public paddle::drr::DrrPatternBase { // Result patterns paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &new_perm_attr = res.Attr( + const auto &new_perm_attr = res.ComputeAttr( [](const paddle::drr::MatchContext &match_ctx) -> phi::IntArray { auto shape = match_ctx.Attr>("expand_shape_value"); @@ -95,7 +95,7 @@ class RemoveRedundentTransposePattern : public paddle::drr::DrrPatternBase { pat.Tensor("ret") = transpose2(transpose1(pat.Tensor("arg_transpose"))); paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &new_perm_attr = res.Attr( + const auto &new_perm_attr = res.ComputeAttr( [](const paddle::drr::MatchContext &match_ctx) -> std::vector { const auto &perm1 = match_ctx.Attr>("perm_1"); const auto &perm2 = match_ctx.Attr>("perm_2"); From 0c35e6482ee7708314d3f01e3da694ca6aee4353 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Wed, 17 Jan 2024 13:41:59 +0000 Subject: [PATCH 7/8] rename attention_fuse_pass to multihead_matmul_fuse_pass --- paddle/fluid/inference/api/analysis_predictor.cc | 4 ++-- ...n_fuse_pass.cc => multihead_matmul_fuse_pass.cc} | 13 +++++++------ ...ion_fuse_pass.h => multihead_matmul_fuse_pass.h} | 2 +- paddle/fluid/pybind/pir.cc | 4 ++-- .../pir/pattern_rewrite/drr_attention_fuse_test.cc | 4 ++-- 5 files changed, 14 insertions(+), 13 deletions(-) rename paddle/fluid/pir/transforms/fusion/{attention_fuse_pass.cc => multihead_matmul_fuse_pass.cc} (98%) rename paddle/fluid/pir/transforms/fusion/{attention_fuse_pass.h => multihead_matmul_fuse_pass.h} (92%) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index a74d811002bee2..72cfd033255c4d 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -107,13 +107,13 @@ #include "paddle/fluid/ir_adaptor/translator/translate.h" #include "paddle/fluid/pir/transforms/constant_folding_pass.h" #include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" -#include "paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h" #include "paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.h" #include "paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.h" #include "paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.h" #include "paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.h" #include "paddle/fluid/pir/transforms/fusion/fc_fuse_pass.h" #include "paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.h" +#include "paddle/fluid/pir/transforms/fusion/multihead_matmul_fuse_pass.h" #include "paddle/fluid/pir/transforms/identity_op_clean_pass.h" #include "paddle/fluid/pir/transforms/inplace_pass.h" #include "paddle/fluid/pir/transforms/params_sync_among_devices_pass.h" @@ -805,7 +805,7 @@ bool AnalysisPredictor::PrepareExecutor() { gpu_pm.AddPass(::pir::CreateConv2dBnFusePass()); gpu_pm.AddPass(::pir::CreateConv2dAddActFusePass()); gpu_pm.AddPass(::pir::CreateConv2dAddFusePass()); - gpu_pm.AddPass(::pir::CreateAttentionFusePass()); + gpu_pm.AddPass(::pir::CreateMultiHeadMatmulFusePass()); gpu_pm.AddPass(::pir::CreateFcFusePass()); gpu_pm.AddPass(::pir::CreateFcElementwiseLayerNormFusePass()); gpu_pm.AddPass(::pir::CreateMatmulScaleFusePass()); diff --git 
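After the rename, the pass is created with pir::CreateMultiHeadMatmulFusePass() and registered under the name multihead_matmul_fuse_pass. A rough sketch of driving it outside the predictor, along the lines of the updated unit test further below; the includes are the ones that test already uses, BuildProgram is that test's helper, and pm.Run's signature follows common pir test usage (an assumption here):

    pir::IrContext *ctx = pir::IrContext::Instance();
    ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();

    pir::Program program(ctx);
    pir::Builder builder(ctx, program.block());
    BuildProgram(builder);  // builds the multi-head attention subgraph, as in the test

    pir::PassManager pm(ctx);
    pm.AddPass(pir::CreateMultiHeadMatmulFusePass());
    pm.AddPass(pir::CreateDeadCodeEliminationPass());
    pm.Run(&program);  // the QKV matmul/softmax chain becomes one pd_op.multihead_matmul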
a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/multihead_matmul_fuse_pass.cc similarity index 98% rename from paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc rename to paddle/fluid/pir/transforms/fusion/multihead_matmul_fuse_pass.cc index 86c61bf3f14e1a..94795d88db10a2 100644 --- a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/multihead_matmul_fuse_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h" +#include "paddle/fluid/pir/transforms/fusion/multihead_matmul_fuse_pass.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/fluid/pir/transforms/transform_general_functions.h" @@ -496,9 +496,10 @@ class MultiHeadMatmulFuseWithBiasQKPattern } }; -class AttentionFusePass : public pir::PatternRewritePass { +class MultiHeadMatmulFusePass : public pir::PatternRewritePass { public: - AttentionFusePass() : pir::PatternRewritePass("attention_fuse_pass", 2) {} + MultiHeadMatmulFusePass() + : pir::PatternRewritePass("multihead_matmul_fuse_pass", 2) {} pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { pir::RewritePatternSet ps(context); @@ -513,9 +514,9 @@ class AttentionFusePass : public pir::PatternRewritePass { } // namespace namespace pir { -std::unique_ptr CreateAttentionFusePass() { - return std::make_unique(); +std::unique_ptr CreateMultiHeadMatmulFusePass() { + return std::make_unique(); } } // namespace pir -REGISTER_IR_PASS(attention_fuse_pass, AttentionFusePass); +REGISTER_IR_PASS(multihead_matmul_fuse_pass, MultiHeadMatmulFusePass); diff --git a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h b/paddle/fluid/pir/transforms/fusion/multihead_matmul_fuse_pass.h similarity index 92% rename from paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h rename to paddle/fluid/pir/transforms/fusion/multihead_matmul_fuse_pass.h index 0c0d2e84952ca4..82486c40ee1ace 100644 --- a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h +++ b/paddle/fluid/pir/transforms/fusion/multihead_matmul_fuse_pass.h @@ -21,6 +21,6 @@ namespace pir { class Pass; -IR_API std::unique_ptr CreateAttentionFusePass(); +IR_API std::unique_ptr CreateMultiHeadMatmulFusePass(); } // namespace pir diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index 103c5c34df9d55..33277e7919f8c6 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -43,7 +43,6 @@ #include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" #include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" -#include "paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h" #include "paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.h" #include "paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.h" #include "paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.h" @@ -55,6 +54,7 @@ #include "paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h" #include "paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h" #include "paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.h" +#include "paddle/fluid/pir/transforms/fusion/multihead_matmul_fuse_pass.h" #include "paddle/fluid/pir/transforms/identity_op_clean_pass.h" #include "paddle/fluid/pir/transforms/inplace_pass.h" #include 
"paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.h" @@ -112,7 +112,7 @@ using pir::Value; using pybind11::return_value_policy; USE_PIR_PASS(dead_code_elimination_pass); -USE_PIR_PASS(attention_fuse_pass); +USE_PIR_PASS(multihead_matmul_fuse_pass); USE_PIR_PASS(fused_gemm_epilogue_pass); USE_PIR_PASS(fused_dropout_add_pass); USE_PIR_PASS(fused_weight_only_linear_pass); diff --git a/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc b/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc index 4361fd03a306f3..373599479434f7 100644 --- a/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc +++ b/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc @@ -22,7 +22,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/transforms/constant_folding_pass.h" #include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" -#include "paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h" +#include "paddle/fluid/pir/transforms/fusion/multihead_matmul_fuse_pass.h" #include "paddle/phi/common/place.h" #include "paddle/pir/core/builtin_dialect.h" @@ -146,7 +146,7 @@ TEST(DrrTest, AttentionFuse) { EXPECT_EQ(program.block()->size(), 33u); pir::PassManager pm(ctx); - pm.AddPass(pir::CreateAttentionFusePass()); + pm.AddPass(pir::CreateMultiHeadMatmulFusePass()); std::unique_ptr constant_folding_pass = pir::CreateConstantFoldingPass(); constant_folding_pass->Set(pir::kPlaceAttr, new phi::Place{phi::GPUPlace{}}); From 58a045ae15bd6fffced29fa14808c9a37319fad7 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Thu, 18 Jan 2024 05:38:45 +0000 Subject: [PATCH 8/8] fix ut --- paddle/fluid/pir/drr/include/drr_pattern_context.h | 14 +++++++------- .../pir/pattern_rewrite/drr_attention_fuse_test.cc | 4 ++++ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/pir/drr/include/drr_pattern_context.h b/paddle/fluid/pir/drr/include/drr_pattern_context.h index d656226a627f27..6a1dff9aaec0c6 100644 --- a/paddle/fluid/pir/drr/include/drr_pattern_context.h +++ b/paddle/fluid/pir/drr/include/drr_pattern_context.h @@ -283,39 +283,39 @@ class ResultPattern { Attribute StrAttr(const std::string& value) const { return ComputeAttr( - [&](const MatchContext& match_ctx) -> std::string { return value; }); + [=](const MatchContext& match_ctx) -> std::string { return value; }); } Attribute BoolAttr(bool value) const { return ComputeAttr( - [&](const MatchContext& match_ctx) -> bool { return value; }); + [=](const MatchContext& match_ctx) -> bool { return value; }); } Attribute Int32Attr(int32_t value) const { return ComputeAttr( - [&](const MatchContext& match_ctx) -> int32_t { return value; }); + [=](const MatchContext& match_ctx) -> int32_t { return value; }); } Attribute Int64Attr(int64_t value) const { return ComputeAttr( - [&](const MatchContext& match_ctx) -> int64_t { return value; }); + [=](const MatchContext& match_ctx) -> int64_t { return value; }); } Attribute Float32Attr(float value) const { return ComputeAttr( - [&](const MatchContext& match_ctx) -> float { return value; }); + [=](const MatchContext& match_ctx) -> float { return value; }); } Attribute VectorInt64Attr(const std::vector& value) const { return ComputeAttr( - [&](const MatchContext& match_ctx) -> std::vector { + [=](const MatchContext& match_ctx) -> std::vector { return value; }); } Attribute VectorInt32Attr(const std::vector& value) const { return ComputeAttr( - [&](const MatchContext& match_ctx) -> std::vector { + [=](const MatchContext& match_ctx) -> std::vector { return 
value; }); } diff --git a/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc b/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc index 373599479434f7..b1bacef15b037c 100644 --- a/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc +++ b/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc @@ -28,6 +28,10 @@ #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/pass/pass_manager.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_DECLARE_KERNEL(multihead_matmul, GPU, ALL_LAYOUT); + void BuildProgram(pir::Builder &builder) { // NOLINT paddle::dialect::FullOp matmul_1_in_1 = builder.Build(std::vector{1, 300, 256},
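// Why the [&] -> [=] change above matters: each shorthand helper returns
// immediately, but the lambda it builds is stored inside the returned Attribute
// and only runs later, when the pattern is applied during rewriting. A
// by-reference capture of the `value` parameter therefore aliases the caller's
// argument (often a temporary such as a string literal or `{}`), which is gone
// by then; capturing by copy keeps the constant alive. A standalone sketch of
// the same hazard with plain std::function, purely illustrative and independent
// of DRR:

#include <functional>
#include <iostream>
#include <string>

// Same shape as the helpers: build a callback now, invoke it much later.
std::function<std::string()> MakeStrAttr(const std::string &value) {
  // return [&] { return value; };  // would dangle once the temporary argument dies
  return [=] { return value; };     // copies the string, so it stays valid
}

int main() {
  auto attr = MakeStrAttr("identity");  // the argument is a temporary std::string
  std::cout << attr() << std::endl;     // safely prints "identity"
  return 0;
}

// The added PD_DECLARE_KERNEL line covers the other half of the "fix ut" commit:
// it appears to pull the GPU multihead_matmul kernel registration into the test
// binary so the op produced by the fuse pass can be resolved at run time.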