From 56c7e62b011642c428995cc1cb4b2a4532671940 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Thu, 28 Dec 2023 05:26:56 +0000 Subject: [PATCH] [DRR] change namespace pir::drr:: to paddle::drr:: --- .../operator/transforms/pd_to_cinn_pass.cc | 40 +-- .../op_generator/op_creator_drr_gen.py | 8 +- paddle/fluid/pir/drr/README.md | 24 +- paddle/fluid/pir/drr/README_cn.md | 24 +- paddle/fluid/pir/drr/api/drr_pattern_base.h | 6 +- .../fluid/pir/drr/api/drr_pattern_context.cc | 5 +- .../fluid/pir/drr/api/drr_pattern_context.h | 4 +- paddle/fluid/pir/drr/api/match_context.cc | 4 +- paddle/fluid/pir/drr/api/match_context.h | 4 +- paddle/fluid/pir/drr/api/tensor_interface.cc | 4 +- paddle/fluid/pir/drr/api/tensor_interface.h | 4 +- paddle/fluid/pir/drr/attr_type_uilts.h | 20 +- paddle/fluid/pir/drr/drr_rewrite_pattern.cc | 42 +-- paddle/fluid/pir/drr/drr_rewrite_pattern.h | 11 +- paddle/fluid/pir/drr/ir_operation.h | 4 +- paddle/fluid/pir/drr/ir_operation_factory.cc | 24 +- paddle/fluid/pir/drr/ir_operation_factory.h | 8 +- paddle/fluid/pir/drr/ir_value.h | 8 +- paddle/fluid/pir/drr/match_context_impl.h | 4 +- paddle/fluid/pir/drr/pattern_graph.cc | 4 +- paddle/fluid/pir/drr/pattern_graph.h | 4 +- .../transforms/fusion/attention_fuse_pass.cc | 50 ++-- .../transforms/fusion/conv2d_add_fuse_pass.cc | 18 +- .../fc_elementwise_layernorm_fuse_pass.cc | 32 ++- .../pir/transforms/fusion/fc_fuse_pass.cc | 33 +-- .../fusion/fc_with_special_op_fuse_pass.cc | 68 +++-- .../fused_dot_product_attention_pass.cc | 250 ++++++++++-------- .../fusion/fused_dropout_add_pass.cc | 16 +- .../fusion/fused_gemm_epilogue_pass.cc | 75 +++--- .../fused_linear_param_grad_add_pass.cc | 132 +++++---- .../fusion/fused_weight_only_linear_pass.cc | 64 ++--- .../fusion/matmul_scale_fuse_pass.cc | 24 +- .../pir/transforms/identity_op_clean_pass.cc | 62 ++--- .../drr_same_type_binding_test.cc | 8 +- test/cpp/pir/pattern_rewrite/drr_test.cc | 38 +-- 35 files changed, 597 insertions(+), 529 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc index 295c50b0eae00e..352fd9fdde322b 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc @@ -31,11 +31,11 @@ namespace cinn { namespace dialect { namespace ir { -class SumOpPattern : public pir::drr::DrrPatternBase { +class SumOpPattern : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source Pattern - pir::drr::SourcePattern pattern = ctx->SourcePattern(); + paddle::drr::SourcePattern pattern = ctx->SourcePattern(); const auto &full_int_array = pattern.Op(paddle::dialect::FullIntArrayOp::name(), {{"value", pattern.Attr("axis_info")}, @@ -48,7 +48,7 @@ class SumOpPattern : public pir::drr::DrrPatternBase { pattern.Tensor("ret") = sum(pattern.Tensor("arg0"), full_int_array()); // Result patterns - pir::drr::ResultPattern res = pattern.ResultPattern(); + paddle::drr::ResultPattern res = pattern.ResultPattern(); const auto &cinn_reduce_sum = res.Op(cinn::dialect::ReduceSumOp::name(), {{"dim", pattern.Attr("axis_info")}, @@ -57,11 +57,11 @@ class SumOpPattern : public pir::drr::DrrPatternBase { } }; -class MaxOpPattern : public pir::drr::DrrPatternBase { +class MaxOpPattern : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) 
const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source Pattern - pir::drr::SourcePattern pattern = ctx->SourcePattern(); + paddle::drr::SourcePattern pattern = ctx->SourcePattern(); const auto &full_int_array = pattern.Op(paddle::dialect::FullIntArrayOp::name(), {{"value", pattern.Attr("axis_info")}, @@ -73,7 +73,7 @@ class MaxOpPattern : public pir::drr::DrrPatternBase { pattern.Tensor("ret") = pd_max(pattern.Tensor("arg0"), full_int_array()); // Result patterns - pir::drr::ResultPattern res = pattern.ResultPattern(); + paddle::drr::ResultPattern res = pattern.ResultPattern(); const auto &cinn_reduce_max = res.Op(cinn::dialect::ReduceMaxOp::name(), {{"dim", pattern.Attr("axis_info")}, @@ -82,11 +82,11 @@ class MaxOpPattern : public pir::drr::DrrPatternBase { } }; -class MinOpPattern : public pir::drr::DrrPatternBase { +class MinOpPattern : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source Pattern - pir::drr::SourcePattern pattern = ctx->SourcePattern(); + paddle::drr::SourcePattern pattern = ctx->SourcePattern(); const auto &full_int_array = pattern.Op(paddle::dialect::FullIntArrayOp::name(), {{"value", pattern.Attr("axis_info")}, @@ -98,7 +98,7 @@ class MinOpPattern : public pir::drr::DrrPatternBase { pattern.Tensor("ret") = pd_max(pattern.Tensor("arg0"), full_int_array()); // Result patterns - pir::drr::ResultPattern res = pattern.ResultPattern(); + paddle::drr::ResultPattern res = pattern.ResultPattern(); const auto &cinn_reduce_max = res.Op(cinn::dialect::ReduceMinOp::name(), {{"dim", pattern.Attr("axis_info")}, @@ -107,11 +107,11 @@ class MinOpPattern : public pir::drr::DrrPatternBase { } }; -class ProdOpPattern : public pir::drr::DrrPatternBase { +class ProdOpPattern : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source Pattern - pir::drr::SourcePattern pattern = ctx->SourcePattern(); + paddle::drr::SourcePattern pattern = ctx->SourcePattern(); const auto &full_int_array = pattern.Op(paddle::dialect::FullIntArrayOp::name(), {{"value", pattern.Attr("axis_info")}, @@ -123,7 +123,7 @@ class ProdOpPattern : public pir::drr::DrrPatternBase { pattern.Tensor("ret") = pd_max(pattern.Tensor("arg0"), full_int_array()); // Result patterns - pir::drr::ResultPattern res = pattern.ResultPattern(); + paddle::drr::ResultPattern res = pattern.ResultPattern(); const auto &cinn_reduce_max = res.Op(cinn::dialect::ReduceProdOp::name(), {{"dim", pattern.Attr("axis_info")}, @@ -552,11 +552,11 @@ class SplitWithNumOpPattern } }; -class UniformOpPattern : public pir::drr::DrrPatternBase { +class UniformOpPattern : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source Pattern - pir::drr::SourcePattern pattern = ctx->SourcePattern(); + paddle::drr::SourcePattern pattern = ctx->SourcePattern(); const auto &full_int_array = pattern.Op(paddle::dialect::FullIntArrayOp::name(), {{"value", pattern.Attr("axis_info")}, @@ -585,7 +585,7 @@ class UniformOpPattern : public pir::drr::DrrPatternBase { // int64_t[] shape, float min, float max, int seed, DataType dtype, int // diag_num, int diag_step, float diag_val) // Result patterns - 
pir::drr::ResultPattern res = pattern.ResultPattern(); + paddle::drr::ResultPattern res = pattern.ResultPattern(); const auto &cinn_uniform = res.Op(cinn::dialect::UniformRandomOp::name(), {{"shape", pattern.Attr("axis_info")}, diff --git a/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py b/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py index 9a40f74429e52b..18dc70f9fa7a7c 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py @@ -27,7 +27,7 @@ {op_header} #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" -namespace pir {{ +namespace paddle {{ namespace drr {{ void OperationFactory::Register{dialect}GeneratedOpCreator() {{ @@ -35,14 +35,14 @@ }} }} // namespace drr -}} // namespace pir +}} // namespace paddle """ NORMAL_FUNCTION_TEMPLATE = """ RegisterOperationCreator( "{op_name}", - [](const std::vector& inputs, + [](const std::vector& inputs, const pir::AttributeMap& attrs, pir::PatternRewriter& rewriter) {{ return rewriter.Build<{namespace}::{op_class_name}>( @@ -53,7 +53,7 @@ MUTABLE_ATTR_FUNCTION_TEMPLATE = """ RegisterOperationCreator( "{op_name}", - [](const std::vector& inputs, + [](const std::vector& inputs, const pir::AttributeMap& attrs, pir::PatternRewriter& rewriter) {{ // mutable_attr is tensor diff --git a/paddle/fluid/pir/drr/README.md b/paddle/fluid/pir/drr/README.md index 4abdbb1b647179..6fbac0756ae865 100644 --- a/paddle/fluid/pir/drr/README.md +++ b/paddle/fluid/pir/drr/README.md @@ -10,9 +10,9 @@ Taking PASS to eliminate redundant CastOp as an example, the code example develo ~~~ c++ // 1. Inherit specialized template class from DrPatternBase class RemoveRedundentCastPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { // 2. Overload operator() - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // 3. Define a SourcePattern containing two consecutive CastOps using Op, Tensor, and Attribute auto pat = ctx->SourcePattern(); @@ -55,7 +55,7 @@ Developers only need to define `SourcePattern`, `Constrains` and `ResultPattern` DrrPatternBase
 virtual void operator()(
-        pir::drr::DrrPatternContext* ctx) const 
+ paddle::drr::DrrPatternContext* ctx) const Implement the entry function of DRR PASS ctx: Context parameters required to create Patten @@ -165,11 +165,11 @@ Attribute Attr(const AttrComputeFunc& attr_compute_func) const ## 3 Example Example 1: Matmul + Add -> FusedGemmEpilogue ~~~ c++ -class FusedLinearPattern : public pir::drr::DrrPatternBase { +class FusedLinearPattern : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Define SourcePattern - pir::drr::SourcePattern pat = ctx->SourcePattern(); + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("trans_x")}, {"transpose_y", pat.Attr("trans_y")}}); @@ -179,10 +179,10 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase { pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias")); // Define ResultPattern - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); // Define Constrain const auto &act_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { return "none"; }); const auto &fused_gemm_epilogue = res.Op(paddle::dialect::FusedGemmEpilogueOp::name(), @@ -199,11 +199,11 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase { Example 2: Full + Expand -> Full ~~~ c++ class FoldExpandToConstantPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Define SourcePattern - pir::drr::SourcePattern pat = ctx->SourcePattern(); + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &full1 = pat.Op(paddle::dialect::FullOp::name(), {{"shape", pat.Attr("shape_1")}, {"value", pat.Attr("value_1")}, @@ -218,7 +218,7 @@ class FoldExpandToConstantPattern pat.Tensor("ret") = expand(full1(), full_int_array1()); // Define ResultPattern - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &full2 = res.Op(paddle::dialect::FullOp::name(), {{"shape", pat.Attr("expand_shape_value")}, {"value", pat.Attr("value_1")}, diff --git a/paddle/fluid/pir/drr/README_cn.md b/paddle/fluid/pir/drr/README_cn.md index 456bf7921414bf..1291bec2954c48 100644 --- a/paddle/fluid/pir/drr/README_cn.md +++ b/paddle/fluid/pir/drr/README_cn.md @@ -10,9 +10,9 @@ DRR ( Declarative Rewrite Rule ) 是来处理这种 DAG-to-DAG 类型的一套 P ~~~ c++ // 1. 继承 DrrPatternBase 的特化模板类 class RemoveRedundentCastPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { // 2. 重载 operator() - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // 3. 使用 Op、Tensor 和 Attribute 定义一个包含两个连续 CastOp 的 SourcePattern auto pat = ctx->SourcePattern(); @@ -56,7 +56,7 @@ DRR PASS 包含以下三个部分: DrrPatternBase
 virtual void operator()(
-        pir::drr::DrrPatternContext* ctx) const 
+ paddle::drr::DrrPatternContext* ctx) const 实现 DRR PASS 的入口函数 ctx: 创建 Patten 所需要的 Context 参数 @@ -168,11 +168,11 @@ Attribute Attr(const AttrComputeFunc& attr_compute_func) const ## 3 使用示例 Example 1: Matmul + Add -> FusedGemmEpilogue ~~~ c++ -class FusedLinearPattern : public pir::drr::DrrPatternBase { +class FusedLinearPattern : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // 定义 Source Pattern - pir::drr::SourcePattern pat = ctx->SourcePattern(); + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("trans_x")}, {"transpose_y", pat.Attr("trans_y")}}); @@ -182,10 +182,10 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase { pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias")); // 定义 Result Pattern - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); // 定义 Constrain const auto &act_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { return "none"; }); const auto &fused_gemm_epilogue = res.Op(paddle::dialect::FusedGemmEpilogueOp::name(), @@ -202,11 +202,11 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase { Example 2: Full + Expand -> Full ~~~ c++ class FoldExpandToConstantPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // 定义 Source Pattern - pir::drr::SourcePattern pat = ctx->SourcePattern(); + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &full1 = pat.Op(paddle::dialect::FullOp::name(), {{"shape", pat.Attr("shape_1")}, {"value", pat.Attr("value_1")}, @@ -221,7 +221,7 @@ class FoldExpandToConstantPattern pat.Tensor("ret") = expand(full1(), full_int_array1()); // 定义 Result Pattern Constrains: 本 Pass 无额外约束规则 - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &full2 = res.Op(paddle::dialect::FullOp::name(), {{"shape", pat.Attr("expand_shape_value")}, {"value", pat.Attr("value_1")}, diff --git a/paddle/fluid/pir/drr/api/drr_pattern_base.h b/paddle/fluid/pir/drr/api/drr_pattern_base.h index 1a84c42800373b..18252d536869f7 100644 --- a/paddle/fluid/pir/drr/api/drr_pattern_base.h +++ b/paddle/fluid/pir/drr/api/drr_pattern_base.h @@ -17,7 +17,7 @@ #include "paddle/fluid/pir/drr/api/drr_pattern_context.h" #include "paddle/fluid/pir/drr/drr_rewrite_pattern.h" -namespace pir { +namespace paddle { namespace drr { template @@ -26,7 +26,7 @@ class DrrPatternBase { virtual ~DrrPatternBase() = default; // Define the Drr Pattern. 
- virtual void operator()(pir::drr::DrrPatternContext* ctx) const = 0; + virtual void operator()(paddle::drr::DrrPatternContext* ctx) const = 0; std::unique_ptr Build( pir::IrContext* ir_context, pir::PatternBenefit benefit = 1) const { @@ -39,4 +39,4 @@ class DrrPatternBase { }; } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/api/drr_pattern_context.cc b/paddle/fluid/pir/drr/api/drr_pattern_context.cc index 50e94c3458265c..7f98f0b34cbeb7 100644 --- a/paddle/fluid/pir/drr/api/drr_pattern_context.cc +++ b/paddle/fluid/pir/drr/api/drr_pattern_context.cc @@ -17,7 +17,7 @@ #include "paddle/fluid/pir/drr/pattern_graph.h" #include "paddle/phi/core/enforce.h" -namespace pir { +namespace paddle { namespace drr { DrrPatternContext::DrrPatternContext() { @@ -28,6 +28,7 @@ DrrPatternContext::DrrPatternContext() { drr::SourcePattern DrrPatternContext::SourcePattern() { return drr::SourcePattern(this); } + const Op& DrrPatternContext::SourceOpPattern( const std::string& op_type, const std::unordered_map& attributes) { @@ -167,4 +168,4 @@ void Tensor::operator=(const Tensor& other) const { // NOLINT } } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/api/drr_pattern_context.h b/paddle/fluid/pir/drr/api/drr_pattern_context.h index 5c235215dd19ba..feb0e988aa8822 100644 --- a/paddle/fluid/pir/drr/api/drr_pattern_context.h +++ b/paddle/fluid/pir/drr/api/drr_pattern_context.h @@ -24,7 +24,7 @@ #include "paddle/fluid/pir/drr/api/match_context.h" -namespace pir { +namespace paddle { namespace drr { class Op; @@ -334,4 +334,4 @@ class SourcePattern { }; } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/api/match_context.cc b/paddle/fluid/pir/drr/api/match_context.cc index 35b28db13254ed..e5f15adf72e75e 100644 --- a/paddle/fluid/pir/drr/api/match_context.cc +++ b/paddle/fluid/pir/drr/api/match_context.cc @@ -19,7 +19,7 @@ #include "paddle/fluid/pir/drr/ir_operation.h" #include "paddle/fluid/pir/drr/match_context_impl.h" -namespace pir { +namespace paddle { namespace drr { MatchContext::MatchContext(std::shared_ptr impl) @@ -46,4 +46,4 @@ template std::vector MatchContext::Attr>( const std::string&) const; } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/api/match_context.h b/paddle/fluid/pir/drr/api/match_context.h index a1699ccb5bddf6..762c86cf8a8e60 100644 --- a/paddle/fluid/pir/drr/api/match_context.h +++ b/paddle/fluid/pir/drr/api/match_context.h @@ -20,7 +20,7 @@ #include "paddle/fluid/pir/drr/api/tensor_interface.h" #include "paddle/fluid/pir/drr/ir_operation.h" -namespace pir { +namespace paddle { namespace drr { class TensorInterface; @@ -40,4 +40,4 @@ class MatchContext final { }; } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/api/tensor_interface.cc b/paddle/fluid/pir/drr/api/tensor_interface.cc index 03a35031f0d917..335f95214887a9 100644 --- a/paddle/fluid/pir/drr/api/tensor_interface.cc +++ b/paddle/fluid/pir/drr/api/tensor_interface.cc @@ -15,7 +15,7 @@ #include "paddle/fluid/pir/drr/api/tensor_interface.h" #include "paddle/fluid/pir/drr/ir_value.h" -namespace pir { +namespace paddle { namespace drr { bool ShapeInterface::operator==(const ShapeInterface& other) const { @@ -33,4 +33,4 @@ bool DtypeInterface::operator==(const DtypeInterface& other) const { IrDtype DtypeInterface::get() const { return *(this->dtype_); } } // namespace drr -} // namespace pir +} // 
namespace paddle diff --git a/paddle/fluid/pir/drr/api/tensor_interface.h b/paddle/fluid/pir/drr/api/tensor_interface.h index 4684beba4ad844..24774f00d5a298 100644 --- a/paddle/fluid/pir/drr/api/tensor_interface.h +++ b/paddle/fluid/pir/drr/api/tensor_interface.h @@ -16,7 +16,7 @@ #include -namespace pir { +namespace paddle { namespace drr { class IrValue; @@ -60,4 +60,4 @@ class TensorInterface { }; } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/attr_type_uilts.h b/paddle/fluid/pir/drr/attr_type_uilts.h index 4043aa3c643835..8904ed0e9ff6a7 100644 --- a/paddle/fluid/pir/drr/attr_type_uilts.h +++ b/paddle/fluid/pir/drr/attr_type_uilts.h @@ -19,7 +19,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/pir/core/builtin_attribute.h" -namespace pir { +namespace paddle { namespace drr { template @@ -32,11 +32,11 @@ struct CppTypeToIrAttribute; using type = ir_attr_type; \ }; -PD_SPECIALIZE_CppTypeToIrAttribute(bool, BoolAttribute); -PD_SPECIALIZE_CppTypeToIrAttribute(int32_t, Int32Attribute); -PD_SPECIALIZE_CppTypeToIrAttribute(int64_t, Int64Attribute); -PD_SPECIALIZE_CppTypeToIrAttribute(float, FloatAttribute); -PD_SPECIALIZE_CppTypeToIrAttribute(std::string, StrAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(bool, pir::BoolAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(int32_t, pir::Int32Attribute); +PD_SPECIALIZE_CppTypeToIrAttribute(int64_t, pir::Int64Attribute); +PD_SPECIALIZE_CppTypeToIrAttribute(float, pir::FloatAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(std::string, pir::StrAttribute); PD_SPECIALIZE_CppTypeToIrAttribute(phi::DataType, paddle::dialect::DataTypeAttribute); PD_SPECIALIZE_CppTypeToIrAttribute(phi::Place, paddle::dialect::PlaceAttribute); @@ -61,7 +61,8 @@ struct IrAttrbuteCreator> { std::vector attr_vec; attr_vec.reserve(obj.size()); for (int32_t x : obj) { - attr_vec.push_back(Int32Attribute::get(pir::IrContext::Instance(), x)); + attr_vec.push_back( + pir::Int32Attribute::get(pir::IrContext::Instance(), x)); } return pir::ArrayAttribute::get(pir::IrContext::Instance(), attr_vec); } @@ -73,7 +74,8 @@ struct IrAttrbuteCreator> { std::vector attr_vec; attr_vec.reserve(obj.size()); for (float x : obj) { - attr_vec.push_back(FloatAttribute::get(pir::IrContext::Instance(), x)); + attr_vec.push_back( + pir::FloatAttribute::get(pir::IrContext::Instance(), x)); } return pir::ArrayAttribute::get(pir::IrContext::Instance(), attr_vec); } @@ -140,4 +142,4 @@ struct IrAttrTypeCast> { }; } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/drr_rewrite_pattern.cc b/paddle/fluid/pir/drr/drr_rewrite_pattern.cc index d0c0d71a3feaab..d408c1aab13490 100644 --- a/paddle/fluid/pir/drr/drr_rewrite_pattern.cc +++ b/paddle/fluid/pir/drr/drr_rewrite_pattern.cc @@ -14,12 +14,12 @@ #include "paddle/fluid/pir/drr/drr_rewrite_pattern.h" -namespace pir { +namespace paddle { namespace drr { bool DrrRewritePattern::MatchAndRewrite( pir::Operation* op, - PatternRewriter& rewriter) const { // NOLINT + pir::PatternRewriter& rewriter) const { // NOLINT std::shared_ptr src_match_ctx = std::make_shared(); if (PatternGraphMatch(op, src_match_ctx.get())) { @@ -41,8 +41,8 @@ bool DrrRewritePattern::PatternGraphMatch( return false; } std::vector drr_output_sequence; - std::vector ir_output_sequence; - std::unordered_map output_op_map; + std::vector ir_output_sequence; + std::unordered_map output_op_map; for (const auto& pair : bind_map) { drr_output_sequence.push_back(pair.first); } @@ 
-50,8 +50,8 @@ bool DrrRewritePattern::PatternGraphMatch( auto permute = [&](auto&& permute, size_t index) -> bool { if (index == drr_output_sequence.size()) { // avoiding duplicate binding of ir op - std::unordered_set ir_output_set; - for (Operation* op : ir_output_sequence) { + std::unordered_set ir_output_set; + for (pir::Operation* op : ir_output_sequence) { auto pr = ir_output_set.insert(op); if (pr.second == false) { return false; @@ -64,7 +64,7 @@ bool DrrRewritePattern::PatternGraphMatch( drr_output_sequence.end(), ir_output_sequence.begin(), std::inserter(output_op_map, output_op_map.end()), - [](const OpCall* drr_op, Operation* ir_op) { + [](const OpCall* drr_op, pir::Operation* ir_op) { return std::make_pair(drr_op, ir_op); }); if (MatchFromOutputToInput( @@ -214,12 +214,12 @@ void DrrRewritePattern::DfsVisitor( } bool DrrRewritePattern::MatchFromOutputToInput( - std::unordered_map output_op_map, + std::unordered_map output_op_map, const SourcePatternGraph& source_pattern_graph, const std::shared_ptr& source_pattern_match_ctx) const { VLOG(6) << "MatchFromOutputToInput Start"; std::unordered_set drr_visited; - std::unordered_set ir_visited; + std::unordered_set ir_visited; std::queue drr_q; std::queue ir_q; bool matched = true; @@ -385,8 +385,8 @@ MatchContextImpl DrrRewritePattern::CreateOperations( } } - std::vector> temp_program; - std::unordered_map op_2_temp_program_index; + std::vector> temp_program; + std::unordered_map op_2_temp_program_index; for (auto& op : *rewriter.block()) { op_2_temp_program_index[&op] = temp_program.size(); temp_program.push_back({&op}); @@ -397,14 +397,14 @@ MatchContextImpl DrrRewritePattern::CreateOperations( graph_topo_visit.WalkGraphNodesTopoOrder([&](const OpCall& op_call) { // set insert point size_t max_input_op_index = 0; - Operation* max_index_op = nullptr; + pir::Operation* max_index_op = nullptr; for (const Tensor* input : op_call.inputs()) { if (input->is_none()) { continue; } auto ir_val = res_match_ctx.GetIrValue(input->name()); if (ir_val) { - Operation* ir_input_op = ir_val.dyn_cast().owner(); + pir::Operation* ir_input_op = ir_val.dyn_cast().owner(); if (op_2_temp_program_index.count(ir_input_op) == 0) { max_input_op_index = 0UL; } else if (max_input_op_index < @@ -431,7 +431,7 @@ MatchContextImpl DrrRewritePattern::CreateOperations( } if (max_input_op_index == 0UL) { VLOG(6) << "Not found producer op for (" << op_call.name() << ")"; - Operation* source_patter_first_op = + pir::Operation* source_patter_first_op = src_match_ctx.Operation(source_pattern_graph.owned_op_call()[0].get()) .get(); max_input_op_index = op_2_temp_program_index[source_patter_first_op]; @@ -440,7 +440,7 @@ MatchContextImpl DrrRewritePattern::CreateOperations( rewriter.SetInsertionPointAfter(max_index_op); } - Operation* new_op = + pir::Operation* new_op = CreateOperation(op_call, src_match_ctx, rewriter, &res_match_ctx); op_2_temp_program_index[new_op] = max_input_op_index + 1; if (max_input_op_index + 1 >= temp_program.size()) { @@ -487,11 +487,11 @@ void DrrRewritePattern::DeleteSourcePatternOp( const ResultPatternGraph& result_pattern_graph, const MatchContextImpl& src_match_ctx, pir::PatternRewriter& rewriter) const { // NOLINT - std::queue delete_ops_que; - std::unordered_set delete_ops_set; + std::queue delete_ops_que; + std::unordered_set delete_ops_set; GraphTopo graph_topo_visit(&source_pattern_graph); graph_topo_visit.WalkGraphNodesTopoOrder([&](const OpCall& op_call) { - Operation* op = src_match_ctx.Operation(&op_call).get(); + pir::Operation* 
op = src_match_ctx.Operation(&op_call).get(); VLOG(5) << "DRR delete op: " << op->name() << " pointer: " << op; if (delete_ops_set.count(op) == 0 && op->use_empty()) { delete_ops_que.push(op); @@ -500,9 +500,9 @@ void DrrRewritePattern::DeleteSourcePatternOp( }); while (!delete_ops_que.empty()) { - Operation* op = delete_ops_que.front(); + pir::Operation* op = delete_ops_que.front(); delete_ops_que.pop(); - std::vector inputs = op->operands_source(); + std::vector inputs = op->operands_source(); VLOG(5) << "Delete (" << op->name() << " @" << op << ") in source_pattern_graph."; rewriter.EraseOp(op); @@ -517,4 +517,4 @@ void DrrRewritePattern::DeleteSourcePatternOp( } } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/drr_rewrite_pattern.h b/paddle/fluid/pir/drr/drr_rewrite_pattern.h index 5d20a5947f13b0..6163c6d9d0193e 100644 --- a/paddle/fluid/pir/drr/drr_rewrite_pattern.h +++ b/paddle/fluid/pir/drr/drr_rewrite_pattern.h @@ -31,7 +31,7 @@ #include "paddle/pir/core/type_name.h" #include "paddle/pir/pattern_rewrite/pattern_match.h" -namespace pir { +namespace paddle { namespace drr { class DrrRewritePattern : public pir::RewritePattern { @@ -57,8 +57,9 @@ class DrrRewritePattern : public pir::RewritePattern { "source pattern definition code.")); } - bool MatchAndRewrite(pir::Operation* op, - PatternRewriter& rewriter) const override; // // NOLINT + bool MatchAndRewrite( + pir::Operation* op, + pir::PatternRewriter& rewriter) const override; // // NOLINT private: bool PatternGraphMatch(pir::Operation* op, @@ -78,7 +79,7 @@ class DrrRewritePattern : public pir::RewritePattern { output_op_bind_map) const; bool MatchFromOutputToInput( - std::unordered_map output_op_map, + std::unordered_map output_op_map, const SourcePatternGraph& source_pattern_graph, const std::shared_ptr& source_pattern_match_ctx) const; @@ -113,4 +114,4 @@ class DrrRewritePattern : public pir::RewritePattern { }; } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/ir_operation.h b/paddle/fluid/pir/drr/ir_operation.h index 2764bc92454170..a88bb3bfff97cf 100644 --- a/paddle/fluid/pir/drr/ir_operation.h +++ b/paddle/fluid/pir/drr/ir_operation.h @@ -16,7 +16,7 @@ #include "paddle/pir/core/operation.h" -namespace pir { +namespace paddle { namespace drr { class IrOperation { @@ -30,4 +30,4 @@ class IrOperation { }; } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/ir_operation_factory.cc b/paddle/fluid/pir/drr/ir_operation_factory.cc index 6644026fabde01..bbc31e9df7c25b 100644 --- a/paddle/fluid/pir/drr/ir_operation_factory.cc +++ b/paddle/fluid/pir/drr/ir_operation_factory.cc @@ -24,13 +24,13 @@ #include "paddle/pir/core/operation.h" #include "paddle/pir/core/value.h" -namespace pir { +namespace paddle { namespace drr { void OperationFactory::RegisterManualOpCreator() { RegisterOperationCreator( "pd_op.fused_gemm_epilogue", - [](const std::vector& inputs, + [](const std::vector& inputs, const pir::AttributeMap& attrs, pir::PatternRewriter& rewriter) { return rewriter.Build( @@ -41,7 +41,7 @@ void OperationFactory::RegisterManualOpCreator() { }); RegisterOperationCreator( "pd_op.fused_gemm_epilogue_grad", - [](const std::vector& inputs, + [](const std::vector& inputs, const pir::AttributeMap& attrs, pir::PatternRewriter& rewriter) { return rewriter.Build( @@ -52,14 +52,14 @@ void OperationFactory::RegisterManualOpCreator() { attrs); }); RegisterOperationCreator("builtin.combine", - [](const 
std::vector& inputs, + [](const std::vector& inputs, const pir::AttributeMap& attrs, pir::PatternRewriter& rewriter) { return rewriter.Build(inputs); }); RegisterOperationCreator( "pd_op.scale", - [](const std::vector& inputs, + [](const std::vector& inputs, const pir::AttributeMap& attrs, pir::PatternRewriter& rewriter) { return rewriter.Build( @@ -130,18 +130,18 @@ pir::AttributeMap CreateAttributeMap(const OpCall& op_call, return attr_map; } -Value GetIrValueByDrrTensor(const Tensor& tensor, - const MatchContextImpl& res_match_ctx) { +pir::Value GetIrValueByDrrTensor(const Tensor& tensor, + const MatchContextImpl& res_match_ctx) { if (tensor.is_none()) { - return Value{}; + return pir::Value{}; } return res_match_ctx.GetIrValue(tensor.name()).get(); } -std::vector GetIrValuesByDrrTensors( +std::vector GetIrValuesByDrrTensors( const std::vector& tensors, const MatchContextImpl& res_match_ctx) { - std::vector ir_values; + std::vector ir_values; ir_values.reserve(tensors.size()); for (const auto* tensor : tensors) { ir_values.push_back(GetIrValueByDrrTensor(*tensor, res_match_ctx)); @@ -167,7 +167,7 @@ pir::Operation* CreateOperation(const OpCall& op_call, MatchContextImpl* res_match_ctx) { VLOG(6) << "Drr create [" << op_call.name() << "] op..."; const auto& inputs = op_call.inputs(); - std::vector ir_values = + std::vector ir_values = GetIrValuesByDrrTensors(inputs, *res_match_ctx); pir::Operation* op = OperationFactory::Instance().CreateOperation( op_call.name(), @@ -180,4 +180,4 @@ pir::Operation* CreateOperation(const OpCall& op_call, } } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/ir_operation_factory.h b/paddle/fluid/pir/drr/ir_operation_factory.h index adc76efb99b2de..40682904df62a8 100644 --- a/paddle/fluid/pir/drr/ir_operation_factory.h +++ b/paddle/fluid/pir/drr/ir_operation_factory.h @@ -20,7 +20,7 @@ #include "paddle/fluid/pir/drr/match_context_impl.h" #include "paddle/pir/pattern_rewrite/pattern_match.h" -namespace pir { +namespace paddle { namespace drr { class OperationFactory { @@ -31,7 +31,7 @@ class OperationFactory { } using operation_create_fn = - std::function&, + std::function&, const pir::AttributeMap&, pir::PatternRewriter&)>; @@ -42,7 +42,7 @@ class OperationFactory { pir::Operation* CreateOperation( const std::string& op_name, - const std::vector& inputs, + const std::vector& inputs, const pir::AttributeMap& attrs, pir::PatternRewriter& rewriter) const { // NOLINT auto iter = op_creator_map.find(op_name); @@ -79,4 +79,4 @@ pir::Operation* CreateOperation(const OpCall& op_call, MatchContextImpl* res_match_ctx); } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/ir_value.h b/paddle/fluid/pir/drr/ir_value.h index 125f198dcc74c5..ae99fd8c1964e2 100644 --- a/paddle/fluid/pir/drr/ir_value.h +++ b/paddle/fluid/pir/drr/ir_value.h @@ -21,7 +21,7 @@ #include "paddle/pir/core/type.h" #include "paddle/pir/core/value.h" -namespace pir { +namespace paddle { namespace drr { class IrShape { @@ -101,10 +101,10 @@ class IrValue : public TensorInterface { } // Don't use it in drr pass! 
- const Value& get() const { return value_; } + const pir::Value& get() const { return value_; } private: - const Value value_; + const pir::Value value_; const IrShape shape_; const IrDtype dtype_; }; @@ -112,4 +112,4 @@ class IrValue : public TensorInterface { class IrAttr; } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/match_context_impl.h b/paddle/fluid/pir/drr/match_context_impl.h index 37b06914cd2bdf..b1234d81299360 100644 --- a/paddle/fluid/pir/drr/match_context_impl.h +++ b/paddle/fluid/pir/drr/match_context_impl.h @@ -25,7 +25,7 @@ #include "paddle/fluid/pir/drr/ir_value.h" #include "paddle/pir/core/builtin_attribute.h" -namespace pir { +namespace paddle { namespace drr { class MatchContextImpl final { @@ -131,4 +131,4 @@ class MatchContextImpl final { }; } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/pattern_graph.cc b/paddle/fluid/pir/drr/pattern_graph.cc index 7d732b6576f68c..58c79c65acf2f6 100644 --- a/paddle/fluid/pir/drr/pattern_graph.cc +++ b/paddle/fluid/pir/drr/pattern_graph.cc @@ -19,7 +19,7 @@ #include "paddle/fluid/pir/drr/api/drr_pattern_context.h" #include "paddle/phi/core/enforce.h" -namespace pir { +namespace paddle { namespace drr { const drr::OpCall &PatternGraph::AddOpCall( @@ -238,4 +238,4 @@ std::ostream &operator<<(std::ostream &os, const PatternGraph &pattern_graph) { } } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/pattern_graph.h b/paddle/fluid/pir/drr/pattern_graph.h index 63bd60eadf17f3..e5cd74b2fa2176 100644 --- a/paddle/fluid/pir/drr/pattern_graph.h +++ b/paddle/fluid/pir/drr/pattern_graph.h @@ -21,7 +21,7 @@ #include #include -namespace pir { +namespace paddle { namespace drr { class Constraint; @@ -105,4 +105,4 @@ class GraphTopo { }; } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc index fbabf835390018..ab19247de4b26a 100644 --- a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc @@ -22,13 +22,13 @@ namespace { class MultiHeadMatmulFusePattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // // Source Pattern. // - pir::drr::SourcePattern src = ctx->SourcePattern(); + paddle::drr::SourcePattern src = ctx->SourcePattern(); // The first path to matmul with scale (q). const auto &matmul_1 = src.Op("pd_op.matmul", @@ -115,7 +115,8 @@ class MultiHeadMatmulFusePattern // // Constraints. // - src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool { + src.RequireNativeCall([](const paddle::drr::MatchContext &match_ctx) + -> bool { const auto &softmax_axis = match_ctx.Attr("softmax_axis"); if (softmax_axis != -1 && softmax_axis != 3) return false; @@ -145,7 +146,7 @@ class MultiHeadMatmulFusePattern // // Result Pattern. // - pir::drr::ResultPattern res = src.ResultPattern(); + paddle::drr::ResultPattern res = src.ResultPattern(); // W combine. 
const auto &combine_1 = res.Op("builtin.combine"); combine_1({&res.Tensor("matmul_1_in_2"), @@ -153,11 +154,11 @@ class MultiHeadMatmulFusePattern &res.Tensor("matmul_3_in_2")}, {&res.Tensor("combine_1_out")}); const auto &concat_axis = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> int { return 0; }); + [](const paddle::drr::MatchContext &match_ctx) -> int { return 0; }); const auto &concat_1 = res.Op("pd_op.concat", {{"axis", concat_axis}}); res.Tensor("concat_1_out") = concat_1(res.Tensor("combine_1_out")); const auto &reshape_5_shape = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> std::vector { + [](const paddle::drr::MatchContext &match_ctx) -> std::vector { auto matmul_1_in_2 = match_ctx.Tensor("matmul_1_in_2").Shape(); return {-1, 3, matmul_1_in_2.at(1)}; }); @@ -175,7 +176,7 @@ class MultiHeadMatmulFusePattern const auto &concat_2 = res.Op("pd_op.concat", {{"axis", concat_axis}}); res.Tensor("concat_2_out") = concat_2(res.Tensor("combine_2_out")); const auto &reshape_6_shape = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> std::vector { + [](const paddle::drr::MatchContext &match_ctx) -> std::vector { return {3, -1}; }); const auto &reshape_6 = @@ -184,28 +185,31 @@ class MultiHeadMatmulFusePattern {&res.Tensor("reshape_6_out"), &res.NoneTensor()}); const auto &head_number = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> int { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> int { const auto &full_int_array_1_value = match_ctx.Attr>("full_int_array_1_value"); return full_int_array_1_value.at(2); }); const auto &alpha = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { return match_ctx.Attr("full_1_value"); }); - const auto &multihead_matmul = res.Op( - "pd_op.multihead_matmul", - {{"transpose_q", res.Attr([](const pir::drr::MatchContext &match_ctx) { - return false; - })}, - {"transpose_k", res.Attr([](const pir::drr::MatchContext &match_ctx) { - return true; - })}, - {"transpose_v", res.Attr([](const pir::drr::MatchContext &match_ctx) { - return false; - })}, - {"head_number", head_number}, - {"alpha", alpha}}); + const auto &multihead_matmul = + res.Op("pd_op.multihead_matmul", + {{"transpose_q", + res.Attr([](const paddle::drr::MatchContext &match_ctx) { + return false; + })}, + {"transpose_k", + res.Attr([](const paddle::drr::MatchContext &match_ctx) { + return true; + })}, + {"transpose_v", + res.Attr([](const paddle::drr::MatchContext &match_ctx) { + return false; + })}, + {"head_number", head_number}, + {"alpha", alpha}}); multihead_matmul({&res.Tensor("matmul_1_in_1"), &res.Tensor("reshape_5_out"), &res.Tensor("reshape_6_out"), diff --git a/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc index 86846508a519dc..e86dc04037fa01 100644 --- a/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc @@ -29,10 +29,10 @@ namespace { class Conv2dAddFusePattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &conv2d = pat.Op(paddle::dialect::Conv2dOp::name(), {{"strides", pat.Attr("strides")}, @@ -46,7 +46,7 
@@ class Conv2dAddFusePattern {&pat.Tensor("conv2d_out")}); pat.Tensor("add_out") = add(pat.Tensor("conv2d_out"), pat.Tensor("bias")); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &fused_conv2d_add_act = res.Op( paddle::dialect::FusedConv2dAddActOp::name(), @@ -58,21 +58,21 @@ class Conv2dAddFusePattern {"groups", pat.Attr("groups")}, {"data_format", pat.Attr("data_format")}, {"activation", - res.Attr([](const pir::drr::MatchContext &match_ctx) + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::string { return "identity"; })}, {"split_channels", - res.Attr([](const pir::drr::MatchContext &match_ctx) + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::vector { return {}; })}, {"exhaustive_search", - res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { return false; })}, {"workspace_size_MB", - res.Attr([](const pir::drr::MatchContext &match_ctx) -> int { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> int { return 32; })}, {"fuse_alpha", - res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { return 0.0f; })}, }}); diff --git a/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc index fdb4621fb350b6..7e5c4bbe8ea187 100644 --- a/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc @@ -22,10 +22,10 @@ namespace { class FcElementwiseLayerNormFusePattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &fc = pat.Op(paddle::dialect::FcOp::name(), { @@ -47,7 +47,7 @@ class FcElementwiseLayerNormFusePattern &pat.Tensor("layernorm_mean"), &pat.Tensor("layernorm_variance")}); // Constrains the activation is none - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { int64_t layer_norm_x = 1; for (int i = match_ctx.Attr("begin_norm_axis"); i < match_ctx.Tensor("fc_out").Shape().size(); @@ -60,12 +60,16 @@ class FcElementwiseLayerNormFusePattern return false; }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &x_num_col_dims_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> std::any { return 1; }); - const auto &false_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &x_num_col_dims_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { + return 1; + }); + const auto &false_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &fused_fc_elementwise_op = res.Op(paddle::dialect::FusedFcElementwiseLayernormOp::name(), @@ -88,10 +92,10 @@ class FcElementwiseLayerNormFusePattern }; class FcElementwiseLayerNormFuse2Pattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void 
operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &fc = pat.Op(paddle::dialect::FcOp::name(), { @@ -113,7 +117,7 @@ class FcElementwiseLayerNormFuse2Pattern &pat.Tensor("layernorm_mean"), &pat.Tensor("layernorm_variance")}); // Constrains the activation is none - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { int64_t layer_norm_x = 1; for (int i = match_ctx.Attr("begin_norm_axis"); i < match_ctx.Tensor("fc_out").Shape().size(); @@ -126,7 +130,7 @@ class FcElementwiseLayerNormFuse2Pattern return false; }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &fused_fc_elementwise_op = res.Op(paddle::dialect::FusedFcElementwiseLayernormOp::name(), diff --git a/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc index 2a320b75d6cc31..b49ab9ff4ac77b 100644 --- a/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc @@ -21,10 +21,10 @@ namespace { -class MatmulAddPattern : public pir::drr::DrrPatternBase { +class MatmulAddPattern : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("transpose_x")}, {"transpose_y", pat.Attr("transpose_y")}}); @@ -32,7 +32,7 @@ class MatmulAddPattern : public pir::drr::DrrPatternBase { matmul({&pat.Tensor("x"), &pat.Tensor("w")}, {&pat.Tensor("matmul_out")}); pat.Tensor("add_out") = add(pat.Tensor("matmul_out"), pat.Tensor("y")); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { if (match_ctx.Tensor("w").Shape().size() != 2 || match_ctx.Tensor("x").Shape().size() < 2) { return false; @@ -56,21 +56,23 @@ class MatmulAddPattern : public pir::drr::DrrPatternBase { return false; }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &in_num_col_dims_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { return match_ctx.Tensor("x").Shape().size() - 1; }); - const auto &false_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &false_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &fc = res.Op(paddle::dialect::FcOp::name(), {{ {"in_num_col_dims", in_num_col_dims_attr}, {"activation_type", - res.Attr([](const pir::drr::MatchContext &match_ctx) + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::string { return ""; })}, {"padding_weights", false_attr}, }}); @@ -79,10 +81,11 @@ class MatmulAddPattern : public pir::drr::DrrPatternBase { } }; -class FcWithReluPattern : public pir::drr::DrrPatternBase { +class FcWithReluPattern + : public paddle::drr::DrrPatternBase { public: - void 
operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &fc = pat.Op(paddle::dialect::FcOp::name(), {{ @@ -96,18 +99,18 @@ class FcWithReluPattern : public pir::drr::DrrPatternBase { relu({&pat.Tensor("fc_out")}, {&pat.Tensor("relu_out")}); // Constrains the activation is none - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return match_ctx.Attr("activation_type").empty(); }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &fc_with_relu = res.Op(paddle::dialect::FcOp::name(), {{ {"in_num_col_dims", pat.Attr("in_num_col_dims")}, {"activation_type", - res.Attr([](const pir::drr::MatchContext &match_ctx) + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::string { return "relu"; })}, {"padding_weights", pat.Attr("padding_weights")}, }}); diff --git a/paddle/fluid/pir/transforms/fusion/fc_with_special_op_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/fc_with_special_op_fuse_pass.cc index 6bb2b3a6d512db..74dd21a0828fe9 100644 --- a/paddle/fluid/pir/transforms/fusion/fc_with_special_op_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fc_with_special_op_fuse_pass.cc @@ -31,10 +31,10 @@ namespace { class SqueezeFcFusePattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &squeeze_op = pat.Op(paddle::dialect::SqueezeOp::name()); const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("transpose_x")}, @@ -46,7 +46,7 @@ class SqueezeFcFusePattern {&pat.Tensor("matmul_out")}); pat.Tensor("add_out") = add(pat.Tensor("matmul_out"), pat.Tensor("bias")); // Constrains the activation is none - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { auto axis_type = match_ctx.Tensor("axis").Dtype().get(); if (axis_type.isa() && axis_type.dyn_cast().size() != 2) { @@ -87,19 +87,23 @@ class SqueezeFcFusePattern return false; }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &in_num_col_dims_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> std::any { return 1; }); - const auto &false_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &in_num_col_dims_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { + return 1; + }); + const auto &false_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &fc = res.Op(paddle::dialect::FcOp::name(), {{ {"in_num_col_dims", in_num_col_dims_attr}, {"activation_type", - res.Attr([](const pir::drr::MatchContext &match_ctx) + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::string { return ""; })}, {"padding_weights", false_attr}, }}); @@ -109,10 +113,10 @@ class SqueezeFcFusePattern }; class ReshapeFcFusePattern - : public 
pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &reshape_op = pat.Op(paddle::dialect::ReshapeOp::name()); const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("transpose_x")}, @@ -124,7 +128,7 @@ class ReshapeFcFusePattern {&pat.Tensor("matmul_out")}); add({&pat.Tensor("matmul_out"), &pat.Tensor("bias")}, {&pat.Tensor("add_out")}); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { if (match_ctx.Tensor("w").Shape().size() != 2 || match_ctx.Attr("transpose_x") == true || match_ctx.Attr("transpose_y") == true) { @@ -212,10 +216,10 @@ class ReshapeFcFusePattern } return true; }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &in_num_col_dims_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { int i = match_ctx.Tensor("x").Shape().size() - 1; int target = match_ctx.Tensor("reshape_out") @@ -228,15 +232,17 @@ class ReshapeFcFusePattern } return i; }); - const auto &false_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &false_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &fc = res.Op(paddle::dialect::FcOp::name(), {{ {"in_num_col_dims", in_num_col_dims_attr}, {"activation_type", - res.Attr([](const pir::drr::MatchContext &match_ctx) + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::string { return ""; })}, {"padding_weights", false_attr}, }}); @@ -246,10 +252,10 @@ class ReshapeFcFusePattern }; class FlattenFcFusePattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &flatten_op = pat.Op(paddle::dialect::FlattenOp::name(), {{"start_axis", pat.Attr("start_axis")}, {"stop_axis", pat.Attr("stop_axis")}}); @@ -263,7 +269,7 @@ class FlattenFcFusePattern {&pat.Tensor("matmul_out")}); pat.Tensor("add_out") = add(pat.Tensor("matmul_out"), pat.Tensor("bias")); // Constrains the activation is none - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { bool flatten_flag = false; if (match_ctx.Tensor("x").Shape().size() == 4 && @@ -295,19 +301,23 @@ class FlattenFcFusePattern return false; }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &in_num_col_dims_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> std::any { return 1; }); - const auto &false_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &in_num_col_dims_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { + return 1; + }); + const auto &false_attr = + res.Attr([](const 
paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &fc = res.Op(paddle::dialect::FcOp::name(), {{ {"in_num_col_dims", in_num_col_dims_attr}, {"activation_type", - res.Attr([](const pir::drr::MatchContext &match_ctx) + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::string { return ""; })}, {"padding_weights", false_attr}, }}); diff --git a/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc index 639c0e0e4b4140..9b2e7f2f3f2e74 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc @@ -22,10 +22,10 @@ namespace { class FusedDotProductAttentionPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern src = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern src = ctx->SourcePattern(); // q[b, s, head, head_dim] -> transpose -> q[b, head, s, head_dim] -> scale const auto &q_transpose = src.Op("pd_op.transpose"); @@ -82,40 +82,45 @@ class FusedDotProductAttentionPattern src.Tensor("out") = o_transpose(src.Tensor("context_matmul_out")); // Constraints - src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool { - const auto &softmax_axis = match_ctx.Attr("softmax_axis"); - if (softmax_axis != -1 && softmax_axis != 3) return false; - - bool qk_matmul_transpose_x = - match_ctx.Attr("qk_matmul_transpose_x"); - bool qk_matmul_transpose_y = - match_ctx.Attr("qk_matmul_transpose_y"); - if (qk_matmul_transpose_x || !qk_matmul_transpose_y) return false; - - bool context_matmul_transpose_x = - match_ctx.Attr("context_matmul_transpose_x"); - bool context_matmul_transpose_y = - match_ctx.Attr("context_matmul_transpose_y"); - if (context_matmul_transpose_x || context_matmul_transpose_y) - return false; - - return true; - }); + src.RequireNativeCall( + [](const paddle::drr::MatchContext &match_ctx) -> bool { + const auto &softmax_axis = match_ctx.Attr("softmax_axis"); + if (softmax_axis != -1 && softmax_axis != 3) return false; + + bool qk_matmul_transpose_x = + match_ctx.Attr("qk_matmul_transpose_x"); + bool qk_matmul_transpose_y = + match_ctx.Attr("qk_matmul_transpose_y"); + if (qk_matmul_transpose_x || !qk_matmul_transpose_y) return false; + + bool context_matmul_transpose_x = + match_ctx.Attr("context_matmul_transpose_x"); + bool context_matmul_transpose_y = + match_ctx.Attr("context_matmul_transpose_y"); + if (context_matmul_transpose_x || context_matmul_transpose_y) + return false; + + return true; + }); // Result pattern - pir::drr::ResultPattern res = src.ResultPattern(); + paddle::drr::ResultPattern res = src.ResultPattern(); const auto &scaling_factor = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { return match_ctx.Attr("q_scale_value"); }); const auto &dropout_prob = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { return static_cast(0.0); }); - const auto &is_training = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); - const auto &is_causal_masking = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + 
const auto &is_training = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return true; + }); + const auto &is_causal_masking = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &dot_product_attention = res.Op(paddle::dialect::FusedDotProductAttentionOp::name(), @@ -135,10 +140,10 @@ class FusedDotProductAttentionPattern }; class FusedDotProductAttentionGradPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern src = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern src = ctx->SourcePattern(); // q[b, s, head, head_dim] -> transpose -> q[b, head, s, head_dim] -> scale const auto &q_transpose = src.Op("pd_op.transpose"); @@ -239,40 +244,45 @@ class FusedDotProductAttentionGradPattern {&src.Tensor("k_grad")}); // Constraints - src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool { - const auto &softmax_axis = match_ctx.Attr("softmax_axis"); - if (softmax_axis != -1 && softmax_axis != 3) return false; - - bool qk_matmul_transpose_x = - match_ctx.Attr("qk_matmul_transpose_x"); - bool qk_matmul_transpose_y = - match_ctx.Attr("qk_matmul_transpose_y"); - if (qk_matmul_transpose_x || !qk_matmul_transpose_y) return false; - - bool context_matmul_transpose_x = - match_ctx.Attr("context_matmul_transpose_x"); - bool context_matmul_transpose_y = - match_ctx.Attr("context_matmul_transpose_y"); - if (context_matmul_transpose_x || context_matmul_transpose_y) - return false; - - return true; - }); + src.RequireNativeCall( + [](const paddle::drr::MatchContext &match_ctx) -> bool { + const auto &softmax_axis = match_ctx.Attr("softmax_axis"); + if (softmax_axis != -1 && softmax_axis != 3) return false; + + bool qk_matmul_transpose_x = + match_ctx.Attr("qk_matmul_transpose_x"); + bool qk_matmul_transpose_y = + match_ctx.Attr("qk_matmul_transpose_y"); + if (qk_matmul_transpose_x || !qk_matmul_transpose_y) return false; + + bool context_matmul_transpose_x = + match_ctx.Attr("context_matmul_transpose_x"); + bool context_matmul_transpose_y = + match_ctx.Attr("context_matmul_transpose_y"); + if (context_matmul_transpose_x || context_matmul_transpose_y) + return false; + + return true; + }); // Result pattern - pir::drr::ResultPattern res = src.ResultPattern(); + paddle::drr::ResultPattern res = src.ResultPattern(); const auto &scaling_factor = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { return match_ctx.Attr("q_scale_value"); }); const auto &dropout_prob = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { return static_cast(0.0); }); - const auto &is_training = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); - const auto &is_causal_masking = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &is_training = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return true; + }); + const auto &is_causal_masking = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &dot_product_attention = res.Op(paddle::dialect::FusedDotProductAttentionOp::name(), @@ -307,11 +317,11 @@ class 
FusedDotProductAttentionGradPattern }; class FusedDotProductAttentionWithDropoutPattern - : public pir::drr::DrrPatternBase< + : public paddle::drr::DrrPatternBase< FusedDotProductAttentionWithDropoutPattern> { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern src = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern src = ctx->SourcePattern(); // q[b, s, head, head_dim] -> transpose -> q[b, head, s, head_dim] -> scale const auto &q_transpose = src.Op("pd_op.transpose"); @@ -376,40 +386,45 @@ class FusedDotProductAttentionWithDropoutPattern src.Tensor("out") = o_transpose(src.Tensor("context_matmul_out")); // Constraints - src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool { - const auto &softmax_axis = match_ctx.Attr("softmax_axis"); - if (softmax_axis != -1 && softmax_axis != 3) return false; - - bool qk_matmul_transpose_x = - match_ctx.Attr("qk_matmul_transpose_x"); - bool qk_matmul_transpose_y = - match_ctx.Attr("qk_matmul_transpose_y"); - if (qk_matmul_transpose_x || !qk_matmul_transpose_y) return false; - - bool context_matmul_transpose_x = - match_ctx.Attr("context_matmul_transpose_x"); - bool context_matmul_transpose_y = - match_ctx.Attr("context_matmul_transpose_y"); - if (context_matmul_transpose_x || context_matmul_transpose_y) - return false; - - return true; - }); + src.RequireNativeCall( + [](const paddle::drr::MatchContext &match_ctx) -> bool { + const auto &softmax_axis = match_ctx.Attr("softmax_axis"); + if (softmax_axis != -1 && softmax_axis != 3) return false; + + bool qk_matmul_transpose_x = + match_ctx.Attr("qk_matmul_transpose_x"); + bool qk_matmul_transpose_y = + match_ctx.Attr("qk_matmul_transpose_y"); + if (qk_matmul_transpose_x || !qk_matmul_transpose_y) return false; + + bool context_matmul_transpose_x = + match_ctx.Attr("context_matmul_transpose_x"); + bool context_matmul_transpose_y = + match_ctx.Attr("context_matmul_transpose_y"); + if (context_matmul_transpose_x || context_matmul_transpose_y) + return false; + + return true; + }); // Result pattern - pir::drr::ResultPattern res = src.ResultPattern(); + paddle::drr::ResultPattern res = src.ResultPattern(); const auto &scaling_factor = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { return match_ctx.Attr("q_scale_value"); }); const auto &dropout_prob = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { return static_cast(0.0); }); - const auto &is_training = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); - const auto &is_causal_masking = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &is_training = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return true; + }); + const auto &is_causal_masking = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &dot_product_attention = res.Op(paddle::dialect::FusedDotProductAttentionOp::name(), @@ -429,11 +444,11 @@ class FusedDotProductAttentionWithDropoutPattern }; class FusedDotProductAttentionGradWithDropoutPattern - : public pir::drr::DrrPatternBase< + : public paddle::drr::DrrPatternBase< FusedDotProductAttentionGradWithDropoutPattern> { public: - void operator()(pir::drr::DrrPatternContext *ctx) const 
override { - pir::drr::SourcePattern src = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern src = ctx->SourcePattern(); // q[b, s, head, head_dim] -> transpose -> q[b, head, s, head_dim] -> scale const auto &q_transpose = src.Op("pd_op.transpose"); @@ -548,36 +563,41 @@ class FusedDotProductAttentionGradWithDropoutPattern {&src.Tensor("k_grad")}); // Constraints - src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool { - const auto &softmax_axis = match_ctx.Attr("softmax_axis"); - if (softmax_axis != -1 && softmax_axis != 3) return false; - - bool qk_matmul_transpose_x = - match_ctx.Attr("qk_matmul_transpose_x"); - bool qk_matmul_transpose_y = - match_ctx.Attr("qk_matmul_transpose_y"); - if (qk_matmul_transpose_x || !qk_matmul_transpose_y) return false; - - bool context_matmul_transpose_x = - match_ctx.Attr("context_matmul_transpose_x"); - bool context_matmul_transpose_y = - match_ctx.Attr("context_matmul_transpose_y"); - if (context_matmul_transpose_x || context_matmul_transpose_y) - return false; - - return true; - }); + src.RequireNativeCall( + [](const paddle::drr::MatchContext &match_ctx) -> bool { + const auto &softmax_axis = match_ctx.Attr("softmax_axis"); + if (softmax_axis != -1 && softmax_axis != 3) return false; + + bool qk_matmul_transpose_x = + match_ctx.Attr("qk_matmul_transpose_x"); + bool qk_matmul_transpose_y = + match_ctx.Attr("qk_matmul_transpose_y"); + if (qk_matmul_transpose_x || !qk_matmul_transpose_y) return false; + + bool context_matmul_transpose_x = + match_ctx.Attr("context_matmul_transpose_x"); + bool context_matmul_transpose_y = + match_ctx.Attr("context_matmul_transpose_y"); + if (context_matmul_transpose_x || context_matmul_transpose_y) + return false; + + return true; + }); // Result pattern - pir::drr::ResultPattern res = src.ResultPattern(); + paddle::drr::ResultPattern res = src.ResultPattern(); const auto &scaling_factor = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { return match_ctx.Attr("q_scale_value"); }); - const auto &is_training = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); - const auto &is_causal_masking = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &is_training = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return true; + }); + const auto &is_causal_masking = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &dot_product_attention = res.Op(paddle::dialect::FusedDotProductAttentionOp::name(), diff --git a/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc index 35079d4f2cf1ca..df8b39cfc8676d 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc @@ -22,10 +22,10 @@ namespace { class FusedDropoutAddPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &dropout = pat.Op(paddle::dialect::DropoutOp::name(), {{"p", pat.Attr("p")}, {"is_test", 
pat.Attr("is_test")}, @@ -38,7 +38,7 @@ class FusedDropoutAddPattern {&pat.Tensor("dropout_out"), &pat.Tensor("mask")}); pat.Tensor("add_out") = add(pat.Tensor("dropout_out"), pat.Tensor("y")); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &fused_dropout_add = res.Op(paddle::dialect::FusedDropoutAddOp::name(), {{{"p", pat.Attr("p")}, @@ -53,10 +53,10 @@ class FusedDropoutAddPattern }; class FusedDropoutGradAddGradPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &dropout = pat.Op(paddle::dialect::DropoutOp::name(), {{"p", pat.Attr("p")}, {"is_test", pat.Attr("is_test")}, @@ -81,7 +81,7 @@ class FusedDropoutGradAddGradPattern dropout_grad({&pat.Tensor("mask"), &pat.Tensor("dropout_out_grad")}, {&pat.Tensor("x_grad")}); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &fused_dropout_add = res.Op(paddle::dialect::FusedDropoutAddOp::name(), {{{"p", pat.Attr("p")}, diff --git a/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc index 6bc15234efd31b..02a6b4744cdcb8 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc @@ -21,10 +21,11 @@ namespace { -class FusedLinearPattern : public pir::drr::DrrPatternBase { +class FusedLinearPattern + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("trans_x")}, {"transpose_y", pat.Attr("trans_y")}}); @@ -33,15 +34,15 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase { pat.Tensor("tmp") = matmul(pat.Tensor("x"), pat.Tensor("w")); pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias")); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return (match_ctx.Tensor("w").Shape().size() == 2 && match_ctx.Tensor("x").Shape().size() >= 2 && match_ctx.Tensor("bias").Shape().size() == 1); }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &act_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { return "none"; }); const auto &fused_gemm_epilogue = @@ -56,10 +57,10 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase { }; class FusedLinearGradPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), 
{{"transpose_x", pat.Attr("trans_x")}, {"transpose_y", pat.Attr("trans_y")}}); @@ -76,15 +77,15 @@ class FusedLinearGradPattern matmul_grad({&pat.Tensor("x"), &pat.Tensor("w"), &pat.Tensor("tmp_grad")}, {&pat.Tensor("x_grad"), &pat.Tensor("w_grad")}); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return (match_ctx.Tensor("w").Shape().size() == 2 && match_ctx.Tensor("x").Shape().size() >= 2 && match_ctx.Tensor("bias").Shape().size() == 1); }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &act_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { return "none"; }); const auto &fused_gemm_epilogue = @@ -111,10 +112,10 @@ class FusedLinearGradPattern }; class FusedLinearGeluPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); // Source pattern const auto &fused_gemm_epilogue = pat.Op(paddle::dialect::FusedGemmEpilogueOp::name(), @@ -128,14 +129,14 @@ class FusedLinearGeluPattern pat.Tensor("out") = gelu(pat.Tensor("fuse_out")); // Constrains the activation is none - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return (match_ctx.Attr("act") == "none"); }); // Result pattern - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &act_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { return "gelu"; }); const auto &fused_gemm_epilogue_gelu = @@ -149,10 +150,10 @@ class FusedLinearGeluPattern } }; class FusedLinearReluPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); // Source pattern const auto &fused_gemm_epilogue = pat.Op(paddle::dialect::FusedGemmEpilogueOp::name(), @@ -166,14 +167,14 @@ class FusedLinearReluPattern pat.Tensor("out") = relu(pat.Tensor("fuse_out")); // Constrains the activation is none - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return (match_ctx.Attr("act") == "none"); }); // Result pattern - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &act_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { return "relu"; }); const auto &fused_gemm_epilogue_relu = @@ -188,10 +189,10 @@ class FusedLinearReluPattern }; class FusedLinearGeluGradPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern 
pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &fused_gemm_epilogue = pat.Op(paddle::dialect::FusedGemmEpilogueOp::name(), {{{"trans_x", pat.Attr("trans_x1")}, @@ -218,14 +219,14 @@ class FusedLinearGeluGradPattern pat.Tensor("gelu_dx") = pat.Op(paddle::dialect::GeluGradOp::name())( pat.Tensor("fuse_out"), pat.Tensor("x1_grad")); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return match_ctx.Attr("act1") == "none" && match_ctx.Attr("act2") == "none"; }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &act_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { return "gelu"; }); const auto &fused_gemm_epilogue_new = @@ -234,7 +235,7 @@ class FusedLinearGeluGradPattern {"trans_y", pat.Attr("trans_y1")}, {"activation", act_attr}}}); const auto &act_grad_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { return "gelu_grad"; }); const auto &fused_gemm_epilogue_grad_new = @@ -256,10 +257,10 @@ class FusedLinearGeluGradPattern }; class FusedLinearReluGradPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &fused_gemm_epilogue = pat.Op(paddle::dialect::FusedGemmEpilogueOp::name(), {{{"trans_x", pat.Attr("trans_x1")}, @@ -297,14 +298,14 @@ class FusedLinearReluGradPattern &pat.Tensor("w_grad"), &pat.Tensor("bias_grad")}); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return match_ctx.Attr("act1") == "relu" && match_ctx.Attr("act3") == "none"; }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &act_grad_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { return "relu_grad"; }); const auto &res_fused_gemm_epilogue_grad1 = diff --git a/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc index 7a3afec65f33fc..8c93ff98226754 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc @@ -22,10 +22,10 @@ namespace { // add_grad + matmul_grad + add_ -> matmul + fused_liner_param_gard_add class FusedMatmulAddGradAddPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &matmul0 = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("trans_x")}, {"transpose_y", 
pat.Attr("trans_y")}}); @@ -48,7 +48,7 @@ class FusedMatmulAddGradAddPattern pat.Tensor("add_out") = add_(pat.Tensor("dweight"), pat.Tensor("weight_grad")); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { const auto &x_trans = match_ctx.Attr("trans_x"); const auto &y_trans = match_ctx.Attr("trans_y"); return (match_ctx.Tensor("weight_grad").Shape() == @@ -58,17 +58,21 @@ class FusedMatmulAddGradAddPattern x_trans == false && y_trans == false); }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { return !(match_ctx.Tensor("dweight").Dtype() == match_ctx.Tensor("weight_grad").Dtype()); }); - const auto &true_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); - const auto &false_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &true_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return true; + }); + const auto &false_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &matmul = res.Op(paddle::dialect::MatmulOp::name(), @@ -89,10 +93,10 @@ class FusedMatmulAddGradAddPattern // matmul_grad + add_ -> matmul + fused_liner_param_gard_add class FusedMatmulGradAddPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &matmul_grad = pat.Op(paddle::dialect::MatmulGradOp::name(), {{"transpose_x", pat.Attr("trans_x")}, {"transpose_y", pat.Attr("trans_y")}}); @@ -104,7 +108,7 @@ class FusedMatmulGradAddPattern pat.Tensor("add_out") = add_(pat.Tensor("dweight"), pat.Tensor("weight_grad")); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { const auto &x_trans = match_ctx.Attr("trans_x"); const auto &y_trans = match_ctx.Attr("trans_y"); return (match_ctx.Tensor("weight_grad").Shape() == @@ -112,18 +116,22 @@ class FusedMatmulGradAddPattern x_trans == false && y_trans == false); }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { return !(match_ctx.Tensor("dweight").Dtype() == match_ctx.Tensor("weight_grad").Dtype()); }); - const auto &true_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); - const auto &false_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &true_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return true; + }); + const auto &false_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &matmul = res.Op(paddle::dialect::MatmulOp::name(), @@ -145,10 +153,10 @@ class FusedMatmulGradAddPattern // matmul + 0 = 
add_(0,1) -> fused_liner_param_gard_add class FusedMatmulAddaPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("trans_x")}, {"transpose_y", pat.Attr("trans_y")}}); @@ -159,22 +167,26 @@ class FusedMatmulAddaPattern pat.Tensor("add_out") = add_(pat.Tensor("dweight"), pat.Tensor("weight_grad")); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return (match_ctx.Tensor("weight_grad").Shape() == match_ctx.Tensor("dweight").Shape()); }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { return !(match_ctx.Tensor("dweight").Dtype() == match_ctx.Tensor("weight_grad").Dtype()); }); - const auto &true_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); - const auto &false_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &true_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return true; + }); + const auto &false_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &fused_linear_param_grad_add = res.Op( paddle::dialect::FusedLinearParamGradAddOp::name(), @@ -190,10 +202,10 @@ class FusedMatmulAddaPattern // matmul + 1 = add_(1,0) -> fused_liner_param_gard_add class FusedMatmulAddbPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("trans_x")}, {"transpose_y", pat.Attr("trans_y")}}); @@ -204,22 +216,26 @@ class FusedMatmulAddbPattern pat.Tensor("add_out") = add_(pat.Tensor("weight_grad"), pat.Tensor("dweight")); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return (match_ctx.Tensor("weight_grad").Shape() == match_ctx.Tensor("dweight").Shape()); }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { return !(match_ctx.Tensor("dweight").Dtype() == match_ctx.Tensor("weight_grad").Dtype()); }); - const auto &true_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); - const auto &false_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &true_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return true; + }); + const auto 
&false_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &fused_linear_param_grad_add = res.Op( paddle::dialect::FusedLinearParamGradAddOp::name(), @@ -235,10 +251,10 @@ class FusedMatmulAddbPattern // add_grad + matmul + 0 = add_(0,1) -> fused_liner_param_gard_add class FusedMatmulAddGradAddaPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("trans_x")}, {"transpose_y", pat.Attr("trans_y")}}); @@ -261,21 +277,23 @@ class FusedMatmulAddGradAddaPattern pat.Tensor("dweight_out") = add_(pat.Tensor("dweight"), pat.Tensor("weight_grad")); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return (match_ctx.Tensor("weight_grad").Shape() == match_ctx.Tensor("dweight").Shape() && match_ctx.Tensor("out").Shape() == match_ctx.Tensor("dadd_out").Shape()); }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { return !(match_ctx.Tensor("dweight").Dtype() == match_ctx.Tensor("weight_grad").Dtype()); }); - const auto &true_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); + const auto &true_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return true; + }); const auto &fused_linear_param_grad_add = res.Op( paddle::dialect::FusedLinearParamGradAddOp::name(), {{{"multi_precision", muti_precision_attr}, {"has_bias", true_attr}}}); @@ -290,10 +308,10 @@ class FusedMatmulAddGradAddaPattern // add_grad + matmul + 1 = add_(1,0) -> fused_liner_param_gard_add class FusedMatmulAddGradAddbPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("trans_x")}, {"transpose_y", pat.Attr("trans_y")}}); @@ -316,21 +334,23 @@ class FusedMatmulAddGradAddbPattern pat.Tensor("dweight_out") = add_(pat.Tensor("weight_grad"), pat.Tensor("dweight")); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return (match_ctx.Tensor("weight_grad").Shape() == match_ctx.Tensor("dweight").Shape() && match_ctx.Tensor("out").Shape() == match_ctx.Tensor("dadd_out").Shape()); }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { return !(match_ctx.Tensor("dweight").Dtype() == match_ctx.Tensor("weight_grad").Dtype()); }); - const auto &true_attr 
= res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); + const auto &true_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return true; + }); const auto &fused_linear_param_grad_add = res.Op( paddle::dialect::FusedLinearParamGradAddOp::name(), {{{"multi_precision", muti_precision_attr}, {"has_bias", true_attr}}}); diff --git a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc index fa83418e562baf..82864f3d80e88f 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc @@ -36,13 +36,13 @@ int getSMVersion() { } class FusedWeightOnlyLinearPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // // Source Pattern. // - pir::drr::SourcePattern src = ctx->SourcePattern(); + paddle::drr::SourcePattern src = ctx->SourcePattern(); const auto &matmul = src.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", src.Attr("matmul_transpose_x")}, @@ -57,47 +57,49 @@ class FusedWeightOnlyLinearPattern // // Constraints. // - src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool { - bool matmul_trans_x = match_ctx.Attr("matmul_transpose_x"); - bool matmul_trans_y = match_ctx.Attr("matmul_transpose_y"); - if (matmul_trans_x || matmul_trans_y) return false; - - if (!(match_ctx.Tensor("w").Shape().size() == 2 && - match_ctx.Tensor("x").Shape().size() >= 2 && - match_ctx.Tensor("bias").Shape().size() == 1)) { - return false; - } - - auto w_dims = match_ctx.Tensor("w").Shape(); - if (w_dims.at(0) % 64 != 0 || w_dims.at(1) % 16 != 0) return false; - - auto w_dtype = match_ctx.Tensor("w").Dtype().get(); - if (!w_dtype.isa() && !w_dtype.isa()) - return false; - - auto x_dims = match_ctx.Tensor("x").Shape(); - if (x_dims.at(x_dims.size() - 1) != w_dims.at(1)) return false; - - return true; - }); + src.RequireNativeCall( + [](const paddle::drr::MatchContext &match_ctx) -> bool { + bool matmul_trans_x = match_ctx.Attr("matmul_transpose_x"); + bool matmul_trans_y = match_ctx.Attr("matmul_transpose_y"); + if (matmul_trans_x || matmul_trans_y) return false; + + if (!(match_ctx.Tensor("w").Shape().size() == 2 && + match_ctx.Tensor("x").Shape().size() >= 2 && + match_ctx.Tensor("bias").Shape().size() == 1)) { + return false; + } + + auto w_dims = match_ctx.Tensor("w").Shape(); + if (w_dims.at(0) % 64 != 0 || w_dims.at(1) % 16 != 0) return false; + + auto w_dtype = match_ctx.Tensor("w").Dtype().get(); + if (!w_dtype.isa() && + !w_dtype.isa()) + return false; + + auto x_dims = match_ctx.Tensor("x").Shape(); + if (x_dims.at(x_dims.size() - 1) != w_dims.at(1)) return false; + + return true; + }); // // Result Pattern. 
// - pir::drr::ResultPattern res = src.ResultPattern(); + paddle::drr::ResultPattern res = src.ResultPattern(); // quantize weight const auto &weight_only_int8_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { return "weight_only_int8"; }); const auto &arch_attr = - res.Attr([&](const pir::drr::MatchContext &match_ctx) -> int { + res.Attr([&](const paddle::drr::MatchContext &match_ctx) -> int { return getSMVersion(); }); const auto &group_size_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> int { return -1; }); + [](const paddle::drr::MatchContext &match_ctx) -> int { return -1; }); const auto &weight_quantize = res.Op(paddle::dialect::WeightQuantizeOp::name(), @@ -109,7 +111,7 @@ class FusedWeightOnlyLinearPattern &res.Tensor("weight_scale_tensor")}); const auto &weight_dtype_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { return "int8"; }); diff --git a/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc index 627c1cd516cc85..0bced0b8ec823f 100644 --- a/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc @@ -28,10 +28,10 @@ namespace { class MatmulScaleFusePattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &matmul_op = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("transpose_x")}, {"transpose_y", pat.Attr("transpose_y")}}); @@ -50,23 +50,23 @@ class MatmulScaleFusePattern scale_op({&pat.Tensor("matmul_out"), &full_op()}, {&pat.Tensor("scale_out")}); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return std::abs(match_ctx.Attr("bias")) <= 1e-6; }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &full_op_res = res.Op(paddle::dialect::FullOp::name(), {{"shape", pat.Attr("shape")}, {"value", pat.Attr("value")}, {"dtype", pat.Attr("dtype")}, {"place", pat.Attr("place")}}); - const auto &scale_op_res = - res.Op(paddle::dialect::ScaleOp::name(), - {{"bias", - res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { - return 0.0; - })}, - {"bias_after_scale", pat.Attr("bias_after_scale")}}); + const auto &scale_op_res = res.Op( + paddle::dialect::ScaleOp::name(), + {{"bias", + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { + return 0.0; + })}, + {"bias_after_scale", pat.Attr("bias_after_scale")}}); const auto &matmul_op_res = res.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("transpose_x")}, diff --git a/paddle/fluid/pir/transforms/identity_op_clean_pass.cc b/paddle/fluid/pir/transforms/identity_op_clean_pass.cc index 377610196bf963..ac49d494d1c731 100644 --- a/paddle/fluid/pir/transforms/identity_op_clean_pass.cc +++ b/paddle/fluid/pir/transforms/identity_op_clean_pass.cc @@ -32,10 +32,10 @@ namespace { class RemoveUselessScalePattern - : public pir::drr::DrrPatternBase { + : public 
paddle::drr::DrrPatternBase<RemoveUselessScalePattern> {
  public:
-  void operator()(pir::drr::DrrPatternContext *ctx) const override {
-    pir::drr::SourcePattern pat = ctx->SourcePattern();
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
+    paddle::drr::SourcePattern pat = ctx->SourcePattern();
     const auto &full_op = pat.Op(paddle::dialect::FullOp::name(),
                                  {{"shape", pat.Attr("shape")},
                                   {"value", pat.Attr("value")},
@@ -47,21 +47,21 @@ class RemoveUselessScalePattern
                                   {"bias_after_scale", pat.Attr("bias_after_scale")}});
     scale_op({&pat.Tensor("x"), &full_op()}, {&pat.Tensor("scale_out")});
 
-    pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) {
+    pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) {
       return (match_ctx.Attr<float>("value") == 1.0 &&
               match_ctx.Attr<float>("bias") == 0.0);
     });
 
-    pir::drr::ResultPattern res = pat.ResultPattern();
+    paddle::drr::ResultPattern res = pat.ResultPattern();
     res.Tensor("scale_out").Assign(res.Tensor("x"));
   }
 };
 
 class RemoveRedundentScalePattern
-    : public pir::drr::DrrPatternBase<RemoveRedundentScalePattern> {
+    : public paddle::drr::DrrPatternBase<RemoveRedundentScalePattern> {
  public:
-  void operator()(pir::drr::DrrPatternContext *ctx) const override {
-    pir::drr::SourcePattern pat = ctx->SourcePattern();
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
+    paddle::drr::SourcePattern pat = ctx->SourcePattern();
     const auto &full_op_1 = pat.Op(paddle::dialect::FullOp::name(),
                                    {{"shape", pat.Attr("shape_1")},
                                     {"value", pat.Attr("value_1")},
@@ -84,10 +84,10 @@ class RemoveRedundentScalePattern
     scale_op_2({&pat.Tensor("scale_1_out"), &full_op_2()},
                {&pat.Tensor("scale_2_out")});
 
-    pir::drr::ResultPattern res = pat.ResultPattern();
+    paddle::drr::ResultPattern res = pat.ResultPattern();
 
     const auto &bais_res =
-        res.Attr([](const pir::drr::MatchContext &match_ctx) -> float {
+        res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float {
           float res_bias_1 = 0.f;
           float res_bias_2 = 0.f;
           if (match_ctx.Attr<bool>("bias_after_scale_1")) {
@@ -106,7 +106,7 @@ class RemoveRedundentScalePattern
           return res_bias_2;
         });
     const auto &res_scale_input =
-        res.Attr([](const pir::drr::MatchContext &match_ctx) -> float {
+        res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float {
           return match_ctx.Attr<float>("value_1") *
                  match_ctx.Attr<float>("value_2");
         });
@@ -116,22 +116,22 @@ class RemoveRedundentScalePattern
                    {"value", res_scale_input},
                    {"dtype", pat.Attr("dtype_1")},
                    {"place", pat.Attr("place_1")}});
-    const auto &scale_op_res =
-        res.Op("pd_op.scale",
-               {{"bias", bais_res},
-                {"bias_after_scale",
-                 res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool {
-                   return true;
-                 })}});
+    const auto &scale_op_res = res.Op(
+        "pd_op.scale",
+        {{"bias", bais_res},
+         {"bias_after_scale",
+          res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool {
+            return true;
+          })}});
     scale_op_res({&res.Tensor("x"), &full_op_res()},
                  {&res.Tensor("scale_2_out")});
   }
 };
 
 class RemoveUselessCastPattern
-    : public pir::drr::DrrPatternBase<RemoveUselessCastPattern> {
+    : public paddle::drr::DrrPatternBase<RemoveUselessCastPattern> {
  public:
-  void operator()(pir::drr::DrrPatternContext *ctx) const override {
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
     auto pat = ctx->SourcePattern();
     pat.Tensor("ret") = pat.Op("pd_op.cast")(pat.Tensor("arg0"));
     pat.RequireEqual(pat.Tensor("ret").dtype(), pat.Tensor("arg0").dtype());
@@ -141,16 +141,16 @@
 };
 
 class RemoveUselessConcatPattern
-    : public pir::drr::DrrPatternBase<RemoveUselessConcatPattern> {
+    : public paddle::drr::DrrPatternBase<RemoveUselessConcatPattern> {
  public:
-  void operator()(pir::drr::DrrPatternContext *ctx) const override {
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
     auto pat = ctx->SourcePattern();
     const auto &combine = pat.Op(pir::CombineOp::name());
     combine({&pat.Tensor("x")}, {&pat.Tensor("combine_out")});
     pat.Tensor("out") = pat.Op(paddle::dialect::ConcatOp::name())(
         pat.Tensor("combine_out"), pat.Tensor("axis"));
-    pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) {
-      auto combine_out = dynamic_cast<const pir::drr::IrValue &>(
+    pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) {
+      auto combine_out = dynamic_cast<const paddle::drr::IrValue &>(
           match_ctx.Tensor("combine_out"));
       return combine_out.type_isa<pir::VectorType>() &&
              combine_out.type_dyn_cast<pir::VectorType>().size() == 1;
@@ -161,8 +161,8 @@
 };
 
 class RemoveRedundentCastPattern
-    : public pir::drr::DrrPatternBase<RemoveRedundentCastPattern> {
-  void operator()(pir::drr::DrrPatternContext *ctx) const override {
+    : public paddle::drr::DrrPatternBase<RemoveRedundentCastPattern> {
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
     auto pat = ctx->SourcePattern();
     pat.Tensor("tmp") = pat.Op(
         "pd_op.cast", {{"dtype", pat.Attr("dtype1")}})(pat.Tensor("arg0"));
@@ -175,10 +175,10 @@
 };
 
 class RemoveRedundentTransposePattern
-    : public pir::drr::DrrPatternBase<RemoveRedundentTransposePattern> {
+    : public paddle::drr::DrrPatternBase<RemoveRedundentTransposePattern> {
  public:
-  void operator()(pir::drr::DrrPatternContext *ctx) const override {
-    pir::drr::SourcePattern pat = ctx->SourcePattern();
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
+    paddle::drr::SourcePattern pat = ctx->SourcePattern();
     const auto &transpose1 =
         pat.Op("pd_op.transpose", {{"perm", pat.Attr("perm_1")}});
     const auto &transpose2 =
@@ -186,9 +186,9 @@ class RemoveRedundentTransposePattern
 
     pat.Tensor("ret") = transpose2(transpose1(pat.Tensor("arg_transpose")));
 
-    pir::drr::ResultPattern res = pat.ResultPattern();
+    paddle::drr::ResultPattern res = pat.ResultPattern();
     const auto &new_perm_attr = res.Attr(
-        [](const pir::drr::MatchContext &match_ctx) -> std::vector<int> {
+        [](const paddle::drr::MatchContext &match_ctx) -> std::vector<int> {
           const auto &perm1 = match_ctx.Attr<std::vector<int>>("perm_1");
           const auto &perm2 = match_ctx.Attr<std::vector<int>>("perm_2");
           std::vector<int> new_perm;
diff --git a/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc b/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc
index b550212ad3654e..1a938e7f600b78 100644
--- a/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc
+++ b/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc
@@ -53,10 +53,10 @@ class SameTypeBindingTestPattern
     // This class is for test cases of the same type of OP.
    // (without considering the computational logic between OPs,
    // only focusing on the process of matching and replacing)
-    : public pir::drr::DrrPatternBase<SameTypeBindingTestPattern> {
+    : public paddle::drr::DrrPatternBase<SameTypeBindingTestPattern> {
  public:
-  void operator()(pir::drr::DrrPatternContext *ctx) const override {
-    pir::drr::SourcePattern src = ctx->SourcePattern();
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
+    paddle::drr::SourcePattern src = ctx->SourcePattern();
 
     // path 1
     const auto &transpose_1 =
@@ -141,7 +141,7 @@ class SameTypeBindingTestPattern
     const auto &relu_2 = src.Op("pd_op.relu");
     src.Tensor("output6") = relu_2(src.Tensor("add_2_out"));
 
-    pir::drr::ResultPattern res = src.ResultPattern();
+    paddle::drr::ResultPattern res = src.ResultPattern();
     const auto &transpose_7 =
         res.Op("pd_op.transpose", {{"perm", src.Attr("perm_4")}});
     res.Tensor("output0") = transpose_7(res.Tensor("input_1"));
diff --git a/test/cpp/pir/pattern_rewrite/drr_test.cc b/test/cpp/pir/pattern_rewrite/drr_test.cc
index fc0e7ae94f05f9..54b5ff2025e49d 100644
--- a/test/cpp/pir/pattern_rewrite/drr_test.cc
+++ b/test/cpp/pir/pattern_rewrite/drr_test.cc
@@ -24,11 +24,11 @@
 #include "paddle/pir/pass/pass_manager.h"
 
 class RemoveRedundentReshapePattern
-    : public pir::drr::DrrPatternBase<RemoveRedundentReshapePattern> {
+    : public paddle::drr::DrrPatternBase<RemoveRedundentReshapePattern> {
  public:
-  void operator()(pir::drr::DrrPatternContext *ctx) const override {
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
     // Source patterns
-    pir::drr::SourcePattern pat = ctx->SourcePattern();
+    paddle::drr::SourcePattern pat = ctx->SourcePattern();
     const auto &reshape1 = pat.Op("pd_op.reshape");
     const auto &reshape2 = pat.Op("pd_op.reshape");
 
@@ -38,18 +38,18 @@ class RemoveRedundentReshapePattern
              {&pat.Tensor("ret"), &pat.Tensor("xshape_1")});
 
     // Result patterns
-    pir::drr::ResultPattern res = pat.ResultPattern();
+    paddle::drr::ResultPattern res = pat.ResultPattern();
     res.Op("pd_op.reshape")({&res.Tensor("arg0"), &res.Tensor("shape1")},
                             {&res.Tensor("ret"), &res.Tensor("xshape_1")});
   }
 };
 
 class FoldExpandToConstantPattern
-    : public pir::drr::DrrPatternBase<FoldExpandToConstantPattern> {
+    : public paddle::drr::DrrPatternBase<FoldExpandToConstantPattern> {
  public:
-  void operator()(pir::drr::DrrPatternContext *ctx) const override {
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
     // Source Pattern
-    pir::drr::SourcePattern pat = ctx->SourcePattern();
+    paddle::drr::SourcePattern pat = ctx->SourcePattern();
     const auto &full1 = pat.Op("pd_op.full",
                                {{"shape", pat.Attr("shape_1")},
                                 {"value", pat.Attr("value_1")},
@@ -64,9 +64,9 @@ class FoldExpandToConstantPattern
     pat.Tensor("ret") = expand(full1(), full_int_array1());
 
     // Result patterns
-    pir::drr::ResultPattern res = pat.ResultPattern();
-    const auto &new_perm_attr =
-        res.Attr([](const pir::drr::MatchContext &match_ctx) -> phi::IntArray {
+    paddle::drr::ResultPattern res = pat.ResultPattern();
+    const auto &new_perm_attr = res.Attr(
+        [](const paddle::drr::MatchContext &match_ctx) -> phi::IntArray {
          auto shape =
              match_ctx.Attr<std::vector<int64_t>>("expand_shape_value");
@@ -82,10 +82,10 @@
 };
 
 class RemoveRedundentTransposePattern
-    : public pir::drr::DrrPatternBase<RemoveRedundentTransposePattern> {
+    : public paddle::drr::DrrPatternBase<RemoveRedundentTransposePattern> {
  public:
-  void operator()(pir::drr::DrrPatternContext *ctx) const override {
-    pir::drr::SourcePattern pat = ctx->SourcePattern();
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
+    paddle::drr::SourcePattern pat = ctx->SourcePattern();
     const auto &transpose1 =
         pat.Op("pd_op.transpose", {{"perm", pat.Attr("perm_1")}});
     const auto &transpose2 =
@@ -93,9 +93,9 @@ class RemoveRedundentTransposePattern
 
     pat.Tensor("ret") = transpose2(transpose1(pat.Tensor("arg_transpose")));
 
-    pir::drr::ResultPattern res = pat.ResultPattern();
+    paddle::drr::ResultPattern res = pat.ResultPattern();
     const auto &new_perm_attr = res.Attr(
-        [](const pir::drr::MatchContext &match_ctx) -> std::vector<int> {
+        [](const paddle::drr::MatchContext &match_ctx) -> std::vector<int> {
          const auto &perm1 = match_ctx.Attr<std::vector<int>>("perm_1");
          const auto &perm2 = match_ctx.Attr<std::vector<int>>("perm_2");
          std::vector<int> new_perm;
@@ -112,8 +112,8 @@
 };
 
 class RemoveRedundentCastPattern
-    : public pir::drr::DrrPatternBase<RemoveRedundentCastPattern> {
-  void operator()(pir::drr::DrrPatternContext *ctx) const override {
+    : public paddle::drr::DrrPatternBase<RemoveRedundentCastPattern> {
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
     auto pat = ctx->SourcePattern();
     pat.Tensor("tmp") = pat.Op(
         "pd_op.cast", {{"dtype", pat.Attr("dtype1")}})(pat.Tensor("arg0"));
@@ -126,9 +126,9 @@
 };
 
 class RemoveUselessCastPattern
-    : public pir::drr::DrrPatternBase<RemoveUselessCastPattern> {
+    : public paddle::drr::DrrPatternBase<RemoveUselessCastPattern> {
  public:
-  void operator()(pir::drr::DrrPatternContext *ctx) const override {
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
     auto pat = ctx->SourcePattern();
     pat.Tensor("ret") = pat.Op("pd_op.cast")(pat.Tensor("arg0"));
     pat.RequireEqual(pat.Tensor("ret").dtype(), pat.Tensor("arg0").dtype());