diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc index a887b035852a3b..185800c623ffcc 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc @@ -20,7 +20,6 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/match_context.h" #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pattern_rewrite/pattern_applicator.h" diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc index 439bba580a6c0d..8494171c676616 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc @@ -25,7 +25,6 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/match_context.h" #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/dialect/shape/utils/shape_utils.h" #include "paddle/pir/pass/pass.h" diff --git a/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc index 7819bc362f5774..5c93f0ecc5cae4 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc @@ -21,7 +21,6 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/match_context.h" #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pattern_rewrite/pattern_applicator.h" diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc index 8c2becde5d990d..9a7db9b7a0a1fb 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc @@ -19,8 +19,7 @@ #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h" #include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" -#include "paddle/fluid/pir/drr/api/match_context.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/pass/pass.h" @@ -31,7 +30,7 @@ namespace cinn { namespace dialect { namespace ir { -class SumOpPattern : public paddle::drr::DrrPatternBase<SumOpPattern> { +class SumOpPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source Pattern @@ -55,9 +54,11 @@ class SumOpPattern : public paddle::drr::DrrPatternBase<SumOpPattern> { {"keep_dim", 
pattern.Attr("keep_dim")}}); res.Tensor("ret") = cinn_reduce_sum(res.Tensor("arg0")); } + + std::string name() const override { return "SumOpPattern"; } }; -class MaxOpPattern : public paddle::drr::DrrPatternBase { +class MaxOpPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source Pattern @@ -80,9 +81,11 @@ class MaxOpPattern : public paddle::drr::DrrPatternBase { {"keep_dim", pattern.Attr("keep_dim")}}); res.Tensor("ret") = cinn_reduce_max(res.Tensor("arg0")); } + + std::string name() const override { return "MaxOpPattern"; } }; -class MinOpPattern : public paddle::drr::DrrPatternBase { +class MinOpPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source Pattern @@ -105,9 +108,11 @@ class MinOpPattern : public paddle::drr::DrrPatternBase { {"keep_dim", pattern.Attr("keep_dim")}}); res.Tensor("ret") = cinn_reduce_max(res.Tensor("arg0")); } + + std::string name() const override { return "MinOpPattern"; } }; -class ProdOpPattern : public paddle::drr::DrrPatternBase { +class ProdOpPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source Pattern @@ -130,6 +135,8 @@ class ProdOpPattern : public paddle::drr::DrrPatternBase { {"keep_dim", pattern.Attr("keep_dim")}}); res.Tensor("ret") = cinn_reduce_max(res.Tensor("arg0")); } + + std::string name() const override { return "ProdOpPattern"; } }; class ScaleOpPattern : public pir::OpRewritePattern { @@ -586,7 +593,7 @@ class ExpandOpPattern } }; -class UniformOpPattern : public paddle::drr::DrrPatternBase { +class UniformOpPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source Pattern @@ -632,6 +639,8 @@ class UniformOpPattern : public paddle::drr::DrrPatternBase { {"diag_val", pattern.Attr("min_value")}}); res.Tensor("ret") = cinn_uniform(); } + + std::string name() const override { return "ProdOpPattern"; } }; PdOpToCinnOpPass::PdOpToCinnOpPass() diff --git a/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.cc index cec66f7c70e2e1..749e042bbf47b8 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.cc @@ -24,7 +24,6 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/match_context.h" #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/dialect/shape/utils/dim_expr.h" #include "paddle/pir/pass/pass.h" diff --git a/paddle/fluid/pir/drr/CMakeLists.txt b/paddle/fluid/pir/drr/CMakeLists.txt index fa43d828d05bc8..d35693c674c61f 100644 --- a/paddle/fluid/pir/drr/CMakeLists.txt +++ b/paddle/fluid/pir/drr/CMakeLists.txt @@ -1,4 +1,4 @@ -file(GLOB DRR_SRCS "*.cc" "api/*.cc") +file(GLOB DRR_SRCS "*.cc" "include/*.cc") set(op_creator_gen_file ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py diff --git a/paddle/fluid/pir/drr/README.md b/paddle/fluid/pir/drr/README.md index 6fbac0756ae865..9b9790538d48ac 100644 --- a/paddle/fluid/pir/drr/README.md +++ b/paddle/fluid/pir/drr/README.md 
@@ -8,9 +8,8 @@ DRR can reduce the development cost of PASS, allowing developers to focus on pro Taking PASS to eliminate redundant CastOp as an example, the code example developed using DRR is as follows: ~~~ c++ -// 1. Inherit specialized template class from DrPatternBase -class RemoveRedundentCastPattern - : public paddle::drr::DrrPatternBase<RemoveRedundentCastPattern> { +// 1. Inherit class from DrrPatternBase +class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase { // 2. Overload operator() void operator()(paddle::drr::DrrPatternContext *ctx) const override { // 3. Define a SourcePattern containing two consecutive CastOps using Op, Tensor, and Attribute @@ -32,6 +31,8 @@ class RemoveRedundentCastPattern res.Op(paddle::dialect::CastOp::name(), {{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0")); } + + std::string name() const override { return "RemoveRedundentCastPattern"; } }; ~~~ @@ -165,7 +166,7 @@ Attribute Attr(const AttrComputeFunc& attr_compute_func) const ## 3 Example Example 1: Matmul + Add -> FusedGemmEpilogue ~~~ c++ -class FusedLinearPattern : public paddle::drr::DrrPatternBase<FusedLinearPattern> { +class FusedLinearPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Define SourcePattern @@ -193,13 +194,14 @@ class FusedLinearPattern : public paddle::drr::DrrPatternBase<FusedLinearPattern> { ... + + std::string name() const override { return "FusedLinearPattern"; } }; ~~~ Example 2: Full + Expand -> Full ~~~ c++ -class FoldExpandToConstantPattern - : public paddle::drr::DrrPatternBase<FoldExpandToConstantPattern> { +class FoldExpandToConstantPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Define SourcePattern @@ -226,5 +228,7 @@ class FoldExpandToConstantPattern {"place", pat.Attr("place_1")}}); res.Tensor("ret") = full2(); } + + std::string name() const override { return "FoldExpandToConstantPattern"; } }; ~~~ diff --git a/paddle/fluid/pir/drr/README_cn.md b/paddle/fluid/pir/drr/README_cn.md index 1291bec2954c48..4051a5e547f315 100644 --- a/paddle/fluid/pir/drr/README_cn.md +++ b/paddle/fluid/pir/drr/README_cn.md @@ -8,9 +8,8 @@ DRR ( Declarative Rewrite Rule ) 是来处理这种 DAG-to-DAG 类型的一套 P 以消除冗余 CastOp 的 PASS 为例,使用 DRR 的代码开发示例如下: ~~~ c++ -// 1. 继承 DrrPatternBase 的特化模板类 -class RemoveRedundentCastPattern - : public paddle::drr::DrrPatternBase<RemoveRedundentCastPattern> { +// 1. 继承 DrrPatternBase 类 +class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase { // 2. 重载 operator() void operator()(paddle::drr::DrrPatternContext *ctx) const override { // 3. 
使用 Op、Tensor 和 Attribute 定义一个包含两个连续 CastOp 的 SourcePattern @@ -32,6 +31,8 @@ class RemoveRedundentCastPattern res.Op(paddle::dialect::CastOp::name(), {{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0")); } + + std::string name() const override { return "RemoveRedundentCastPattern"; } }; ~~~ @@ -168,7 +169,7 @@ Attribute Attr(const AttrComputeFunc& attr_compute_func) const ## 3 使用示例 Example 1: Matmul + Add -> FusedGemmEpilogue ~~~ c++ -class FusedLinearPattern : public paddle::drr::DrrPatternBase<FusedLinearPattern> { +class FusedLinearPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // 定义 Source Pattern @@ -196,13 +197,14 @@ class FusedLinearPattern : public paddle::drr::DrrPatternBase<FusedLinearPattern> { ... + + std::string name() const override { return "FusedLinearPattern"; } }; ~~~ Example 2: Full + Expand -> Full ~~~ c++ -class FoldExpandToConstantPattern - : public paddle::drr::DrrPatternBase<FoldExpandToConstantPattern> { +class FoldExpandToConstantPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // 定义 Source Pattern @@ -229,5 +231,7 @@ class FoldExpandToConstantPattern {"place", pat.Attr("place_1")}}); res.Tensor("ret") = full2(); } + + std::string name() const override { return "FoldExpandToConstantPattern"; } }; ~~~ diff --git a/paddle/fluid/pir/drr/api/tensor_interface.cc b/paddle/fluid/pir/drr/api/tensor_interface.cc deleted file mode 100644 index 335f95214887a9..00000000000000 --- a/paddle/fluid/pir/drr/api/tensor_interface.cc +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/pir/drr/api/tensor_interface.h" -#include "paddle/fluid/pir/drr/ir_value.h" - -namespace paddle { -namespace drr { - -bool ShapeInterface::operator==(const ShapeInterface& other) const { - return *shape_ == *other.shape_; -} - -int ShapeInterface::size() const { return shape_->size(); } - -int64_t ShapeInterface::at(int idx) const { return shape_->at(idx); } - -bool DtypeInterface::operator==(const DtypeInterface& other) const { - return *dtype_ == *other.dtype_; -} - -IrDtype DtypeInterface::get() const { return *(this->dtype_); } - -} // namespace drr -} // namespace paddle diff --git a/paddle/fluid/pir/drr/api/tensor_interface.h b/paddle/fluid/pir/drr/api/tensor_interface.h deleted file mode 100644 index 24774f00d5a298..00000000000000 --- a/paddle/fluid/pir/drr/api/tensor_interface.h +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -namespace paddle { -namespace drr { - -class IrValue; -class IrShape; -class IrDtype; - -class ShapeInterface final { - public: - bool operator==(const ShapeInterface& other) const; - - int size() const; - - int64_t at(int idx) const; - - private: - explicit ShapeInterface(const IrShape* shape) : shape_(shape) {} - - friend class IrValue; - - const IrShape* shape_; -}; - -class DtypeInterface final { - public: - bool operator==(const DtypeInterface& other) const; - - IrDtype get() const; - - private: - explicit DtypeInterface(const IrDtype* dtype) : dtype_(dtype) {} - - friend class IrValue; - - const IrDtype* dtype_; -}; - -class TensorInterface { - public: - virtual ShapeInterface Shape() const = 0; - virtual DtypeInterface Dtype() const = 0; -}; - -} // namespace drr -} // namespace paddle diff --git a/paddle/fluid/pir/drr/api/match_context.h b/paddle/fluid/pir/drr/include/drr_match_context.h similarity index 85% rename from paddle/fluid/pir/drr/api/match_context.h rename to paddle/fluid/pir/drr/include/drr_match_context.h index 762c86cf8a8e60..4339595b710d45 100644 --- a/paddle/fluid/pir/drr/api/match_context.h +++ b/paddle/fluid/pir/drr/include/drr_match_context.h @@ -17,8 +17,9 @@ #include <memory> #include <string> -#include "paddle/fluid/pir/drr/api/tensor_interface.h" -#include "paddle/fluid/pir/drr/ir_operation.h" +namespace pir { +class Value; +} namespace paddle { namespace drr { @@ -30,7 +31,7 @@ class MatchContext final { public: MatchContext(std::shared_ptr<const MatchContextImpl> impl); - const TensorInterface& Tensor(const std::string& tensor_name) const; + const pir::Value& Tensor(const std::string& tensor_name) const; template <typename T> T Attr(const std::string& attr_name) const; diff --git a/paddle/fluid/pir/drr/api/drr_pattern_base.h b/paddle/fluid/pir/drr/include/drr_pattern_base.h similarity index 54% rename from paddle/fluid/pir/drr/api/drr_pattern_base.h rename to paddle/fluid/pir/drr/include/drr_pattern_base.h index 18252d536869f7..e079fed999a13b 100644 --- a/paddle/fluid/pir/drr/api/drr_pattern_base.h +++ b/paddle/fluid/pir/drr/include/drr_pattern_base.h @@ -14,28 +14,38 @@ #pragma once -#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" -#include "paddle/fluid/pir/drr/drr_rewrite_pattern.h" +#include <memory> +#include <string> + +#include "paddle/fluid/pir/drr/include/drr_match_context.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_context.h" +#include "paddle/fluid/pir/drr/include/drr_rewrite_pattern.h" + +namespace pir { +class IrContext; +} namespace paddle { namespace drr { -template <typename DrrPattern> +class DrrRewritePattern; +class DrrPatternContext; + class DrrPatternBase { public: virtual ~DrrPatternBase() = default; - // Define the Drr Pattern. - virtual void operator()(paddle::drr::DrrPatternContext* ctx) const = 0; - - std::unique_ptr<DrrRewritePattern> Build( - pir::IrContext* ir_context, pir::PatternBenefit benefit = 1) const { - DrrPatternContext drr_context; - this->operator()(&drr_context); - std::string pattern_name = pir::get_type_name<DrrPattern>(); - return std::make_unique<DrrRewritePattern>( - pattern_name, drr_context, ir_context, benefit); - } + // Define the drr pattern. + virtual void operator()(drr::DrrPatternContext* ctx) const = 0; + + // Give the drr pattern name. + virtual std::string name() const = 0; + + // Give the drr pattern benefit. + virtual uint32_t benefit() const { return 1; } + + // Build the Drr Pattern. 
+ std::unique_ptr<DrrRewritePattern> Build(pir::IrContext* ir_context) const; }; } // namespace drr diff --git a/paddle/fluid/pir/drr/api/drr_pattern_context.h b/paddle/fluid/pir/drr/include/drr_pattern_context.h similarity index 95% rename from paddle/fluid/pir/drr/api/drr_pattern_context.h rename to paddle/fluid/pir/drr/include/drr_pattern_context.h index feb0e988aa8822..0539708300ac7c 100644 --- a/paddle/fluid/pir/drr/api/drr_pattern_context.h +++ b/paddle/fluid/pir/drr/include/drr_pattern_context.h @@ -21,8 +21,9 @@ #include #include #include +#include -#include "paddle/fluid/pir/drr/api/match_context.h" +#include "paddle/fluid/pir/drr/include/drr_match_context.h" namespace paddle { namespace drr { @@ -85,17 +86,17 @@ class TensorDataType { std::string tensor_name_; }; +using ConstraintFunction = std::function<bool(const MatchContext&)>; class Constraint { public: - explicit Constraint( - const std::function<bool(const MatchContext&)>& constrain_fn) + explicit Constraint(const ConstraintFunction& constrain_fn) : IsContextMatchConstraint_(constrain_fn) {} bool operator()(const MatchContext& match_context) const { return IsContextMatchConstraint_(match_context); } private: - std::function<bool(const MatchContext&)> IsContextMatchConstraint_; + ConstraintFunction IsContextMatchConstraint_; }; class DrrPatternContext { @@ -132,8 +133,7 @@ class DrrPatternContext { // void RequireEqual(const Attribute& first, const Attribute& second); void RequireEqual(const TensorShape& first, const TensorShape& second); void RequireEqual(const TensorDataType& first, const TensorDataType& second); - void RequireNativeCall( - const std::function<bool(const MatchContext&)>& custom_fn); + void RequireNativeCall(const ConstraintFunction& custom_fn); std::shared_ptr<SourcePatternGraph> source_pattern_graph_; std::vector<Constraint> constraints_; @@ -191,8 +191,6 @@ class Tensor { public: static const char NONE_TENSOR_NAME[]; - const std::string& DebugName() const; - TensorShape shape() const { return TensorShape(name()); } TensorDataType dtype() const { return TensorDataType(name()); } @@ -322,8 +320,7 @@ class SourcePattern { ctx_->RequireEqual(first, second); } - void RequireNativeCall( - const std::function<bool(const MatchContext&)>& custom_fn) { + void RequireNativeCall(const ConstraintFunction& custom_fn) { ctx_->RequireNativeCall(custom_fn); } diff --git a/paddle/fluid/pir/drr/drr_rewrite_pattern.h b/paddle/fluid/pir/drr/include/drr_rewrite_pattern.h similarity index 74% rename from paddle/fluid/pir/drr/drr_rewrite_pattern.h rename to paddle/fluid/pir/drr/include/drr_rewrite_pattern.h index 6163c6d9d0193e..11d07b7fca2690 100644 --- a/paddle/fluid/pir/drr/drr_rewrite_pattern.h +++ b/paddle/fluid/pir/drr/include/drr_rewrite_pattern.h @@ -20,42 +20,29 @@ #include #include -#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" -#include "paddle/fluid/pir/drr/api/match_context.h" -#include "paddle/fluid/pir/drr/ir_operation.h" -#include "paddle/fluid/pir/drr/ir_operation_factory.h" -#include "paddle/fluid/pir/drr/match_context_impl.h" -#include "paddle/fluid/pir/drr/pattern_graph.h" -#include "paddle/phi/core/enforce.h" -#include "paddle/pir/core/operation.h" -#include "paddle/pir/core/type_name.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_context.h" #include "paddle/pir/pattern_rewrite/pattern_match.h" +namespace pir { +class IrContext; +} + namespace paddle { namespace drr { +class OpCall; +class Constraint; +class DrrPatternContext; +class MatchContextImpl; +class SourcePatternGraph; +class ResultPatternGraph; + class DrrRewritePattern : public pir::RewritePattern { public: explicit DrrRewritePattern(const std::string& pattern_name, const DrrPatternContext& drr_context, 
pir::IrContext* context, - pir::PatternBenefit benefit = 1) - : pir::RewritePattern( - drr_context.source_pattern_graph()->AnchorNode()->name(), - benefit, - context, - {}), - pattern_name_(pattern_name), - source_pattern_graph_(drr_context.source_pattern_graph()), - constraints_(drr_context.constraints()), - result_pattern_graph_(drr_context.result_pattern_graph()) { - PADDLE_ENFORCE_NE( - source_pattern_graph_->owned_op_call().empty(), - true, - phi::errors::InvalidArgument("Source pattern graph is empty." - "Suggested fix: Please check the DRR " - "source pattern definition code.")); - } + pir::PatternBenefit benefit); bool MatchAndRewrite( pir::Operation* op, diff --git a/paddle/fluid/pir/drr/ir_operation.h b/paddle/fluid/pir/drr/ir_operation.h deleted file mode 100644 index a88bb3bfff97cf..00000000000000 --- a/paddle/fluid/pir/drr/ir_operation.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/pir/core/operation.h" - -namespace paddle { -namespace drr { - -class IrOperation { - public: - explicit IrOperation(pir::Operation* op) : op_(op) {} - - pir::Operation* get() const { return op_; } - - private: - pir::Operation* op_; -}; - -} // namespace drr -} // namespace paddle diff --git a/paddle/fluid/pir/drr/ir_operation_factory.cc b/paddle/fluid/pir/drr/ir_operation_factory.cc index bbc31e9df7c25b..c552550b98c2a7 100644 --- a/paddle/fluid/pir/drr/ir_operation_factory.cc +++ b/paddle/fluid/pir/drr/ir_operation_factory.cc @@ -135,7 +135,7 @@ pir::Value GetIrValueByDrrTensor(const Tensor& tensor, if (tensor.is_none()) { return pir::Value{}; } - return res_match_ctx.GetIrValue(tensor.name()).get(); + return res_match_ctx.GetIrValue(tensor.name()); } std::vector<pir::Value> GetIrValuesByDrrTensors( @@ -153,11 +153,7 @@ void BindIrOutputs(const OpCall& op_call, pir::Operation* op, MatchContextImpl* match_ctx) { for (size_t i = 0; i < op_call.outputs().size(); ++i) { - std::shared_ptr<IrValue> ir_value = nullptr; - if (op->result(i)) { - ir_value = std::make_shared<IrValue>(op->result(i)); - } - match_ctx->BindIrValue(op_call.outputs()[i]->name(), ir_value); + match_ctx->BindIrValue(op_call.outputs()[i]->name(), op->result(i)); } } diff --git a/paddle/fluid/pir/drr/ir_operation_factory.h b/paddle/fluid/pir/drr/ir_operation_factory.h index 40682904df62a8..ac59a0310b63f8 100644 --- a/paddle/fluid/pir/drr/ir_operation_factory.h +++ b/paddle/fluid/pir/drr/ir_operation_factory.h @@ -16,7 +16,7 @@ #include -#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_context.h" #include "paddle/fluid/pir/drr/match_context_impl.h" #include "paddle/pir/pattern_rewrite/pattern_match.h" diff --git a/paddle/fluid/pir/drr/ir_value.h b/paddle/fluid/pir/drr/ir_value.h deleted file mode 100644 index ae99fd8c1964e2..00000000000000 --- a/paddle/fluid/pir/drr/ir_value.h +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. 
All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" -#include "paddle/fluid/pir/drr/api/tensor_interface.h" -#include "paddle/pir/core/type.h" -#include "paddle/pir/core/value.h" - -namespace paddle { -namespace drr { - -class IrShape { - public: - explicit IrShape(const phi::DDim& dims) : dims_(dims) {} - - bool operator==(const IrShape& other) const { return dims_ == other.dims_; } - - int size() const { return dims_.size(); } - - int64_t at(int idx) const { return dims_.at(idx); } - - private: - const phi::DDim dims_; -}; - -class IrDtype { - public: - explicit IrDtype(pir::Type dtype) : dtype_(dtype) {} - - bool operator==(IrDtype other) const { return dtype_ == other.dtype_; } - - template <typename T> - bool isa() const { - return dtype_.isa<T>(); - } - - template <typename T> - T dyn_cast() const { - return dtype_.dyn_cast<T>(); - } - - private: - const pir::Type dtype_; -}; - -class IrValue : public TensorInterface { - public: - explicit IrValue(const pir::Value& value) - : value_(value), - shape_((value && value.type() && - value.type().dyn_cast<paddle::dialect::DenseTensorType>()) - ? value.type() - .dyn_cast<paddle::dialect::DenseTensorType>() - .dims() - : phi::DDim{}), - dtype_((value && value.type() && - value.type().dyn_cast<paddle::dialect::DenseTensorType>()) - ? value.type() - .dyn_cast<paddle::dialect::DenseTensorType>() - .dtype() - : pir::Type{}) {} - - ShapeInterface Shape() const override { return ShapeInterface(&shape_); } - DtypeInterface Dtype() const override { return DtypeInterface(&dtype_); } - - explicit operator bool() const { return value_.operator bool(); } - - template <typename T> - bool isa() const { - return value_.isa<T>(); - } - - template <typename T> - T dyn_cast() const { - return value_.dyn_cast<T>(); - } - - template <typename T> - bool type_isa() const { - return value_.type().isa<T>(); - } - - template <typename T> - T type_dyn_cast() const { - return value_.type().dyn_cast<T>(); - } - - // Don't use it in drr pass! - const pir::Value& get() const { return value_; } - - private: - const pir::Value value_; - const IrShape shape_; - const IrDtype dtype_; -}; - -class IrAttr; - -} // namespace drr -} // namespace paddle diff --git a/paddle/fluid/pir/drr/api/match_context.cc b/paddle/fluid/pir/drr/match_context.cc similarity index 89% rename from paddle/fluid/pir/drr/api/match_context.cc rename to paddle/fluid/pir/drr/match_context.cc index e5f15adf72e75e..3da7b24e5df4a6 100644 --- a/paddle/fluid/pir/drr/api/match_context.cc +++ b/paddle/fluid/pir/drr/match_context.cc @@ -12,11 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/pir/drr/api/match_context.h" - #include -#include "paddle/fluid/pir/drr/ir_operation.h" +#include "paddle/fluid/pir/drr/include/drr_match_context.h" #include "paddle/fluid/pir/drr/match_context_impl.h" namespace paddle { @@ -25,8 +23,7 @@ namespace drr { MatchContext::MatchContext(std::shared_ptr impl) : impl_(impl) {} -const TensorInterface& MatchContext::Tensor( - const std::string& tensor_name) const { +const pir::Value& MatchContext::Tensor(const std::string& tensor_name) const { return impl_->Tensor(tensor_name); } diff --git a/paddle/fluid/pir/drr/match_context_impl.h b/paddle/fluid/pir/drr/match_context_impl.h index b1234d81299360..26c043384069b5 100644 --- a/paddle/fluid/pir/drr/match_context_impl.h +++ b/paddle/fluid/pir/drr/match_context_impl.h @@ -18,12 +18,12 @@ #include #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" -#include "paddle/fluid/pir/drr/api/tensor_interface.h" #include "paddle/fluid/pir/drr/attr_type_uilts.h" -#include "paddle/fluid/pir/drr/ir_operation.h" -#include "paddle/fluid/pir/drr/ir_value.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_context.h" #include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/operation_utils.h" +#include "paddle/pir/core/value.h" namespace paddle { namespace drr { @@ -33,7 +33,7 @@ class MatchContextImpl final { MatchContextImpl() = default; ~MatchContextImpl() = default; - const TensorInterface& Tensor(const std::string& tensor_name) const { + const pir::Value& Tensor(const std::string& tensor_name) const { PADDLE_ENFORCE_NE( tensor_map_.count(tensor_name), 0, @@ -41,10 +41,10 @@ class MatchContextImpl final { "Not found tensor." "The Drr tensor [%s] must exist in pattern graph to be obtained.", tensor_name)); - return *tensor_map_.at(tensor_name); + return tensor_map_.at(tensor_name); } - const IrOperation& Operation(const OpCall* op_call) const { + pir::Operation* IrOperation(const OpCall* op_call) const { PADDLE_ENFORCE_NE( operation_map_.count(op_call), 0, @@ -52,7 +52,7 @@ class MatchContextImpl final { "The Drr operation [%s] must exist in the " "pattern graph to be obtained.", op_call->name())); - return *operation_map_.at(op_call); + return operation_map_.at(op_call); } template @@ -60,7 +60,7 @@ class MatchContextImpl final { return IrAttrTypeCast::To(GetIrAttr(attr_name)); } - const IrValue& GetIrValue(const std::string& tensor_name) const { + pir::Value GetIrValue(const std::string& tensor_name) const { auto iter = tensor_map_.find(tensor_name); PADDLE_ENFORCE_NE( iter, @@ -69,7 +69,7 @@ class MatchContextImpl final { "The Drr tensor [%s] is not found in the map, " "unable to obtain the corresponding IrValue.", tensor_name)); - return *iter->second; + return iter->second; } pir::Attribute GetIrAttr(const std::string& attr_name) const { @@ -84,8 +84,8 @@ class MatchContextImpl final { return iter->second; } - const std::unordered_map>& - operation_map() const { + const std::unordered_map& operation_map() + const { return operation_map_; } @@ -93,18 +93,15 @@ class MatchContextImpl final { return attr_map_; } - const std::unordered_map>& tensor_map() - const { + const std::unordered_map& tensor_map() const { return tensor_map_; } - void BindIrValue(const std::string& value_name, - const std::shared_ptr& value) { + void BindIrValue(const std::string& value_name, const pir::Value& value) { tensor_map_.emplace(value_name, value); } - void BindIrOperation(const 
OpCall* op_call, - const std::shared_ptr<IrOperation>& op) { + void BindIrOperation(const OpCall* op_call, pir::Operation* op) { operation_map_.emplace(op_call, op); const auto& attrs = op_call->attributes(); for (const auto& kv : attrs) { @@ -112,7 +109,7 @@ class MatchContextImpl final { [&](auto&& arg) { if constexpr (std::is_same_v<std::decay_t<decltype(arg)>, NormalAttribute>) { - BindIrAttr(arg.name(), op->get()->attribute(kv.first)); + BindIrAttr(arg.name(), op->attribute(kv.first)); } }, kv.second); @@ -124,9 +121,8 @@ class MatchContextImpl final { attr_map_.emplace(attr_name, attr); } - std::unordered_map<std::string, std::shared_ptr<IrValue>> tensor_map_; - std::unordered_map<const OpCall*, std::shared_ptr<IrOperation>> - operation_map_; + std::unordered_map<std::string, pir::Value> tensor_map_; + std::unordered_map<const OpCall*, pir::Operation*> operation_map_; std::unordered_map<std::string, pir::Attribute> attr_map_; }; diff --git a/paddle/fluid/pir/drr/api/drr_pattern_context.cc b/paddle/fluid/pir/drr/pattern_context.cc similarity index 89% rename from paddle/fluid/pir/drr/api/drr_pattern_context.cc rename to paddle/fluid/pir/drr/pattern_context.cc index 7f98f0b34cbeb7..a3823ab0e18107 100644 --- a/paddle/fluid/pir/drr/api/drr_pattern_context.cc +++ b/paddle/fluid/pir/drr/pattern_context.cc @@ -12,9 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" +#include +#include "paddle/fluid/pir/drr/include/drr_pattern_context.h" #include "paddle/fluid/pir/drr/pattern_graph.h" +#include "paddle/fluid/pir/transforms/transform_general_functions.h" #include "paddle/phi/core/enforce.h" namespace paddle { namespace drr { @@ -60,21 +62,13 @@ std::vector<Constraint> DrrPatternContext::constraints() const { return constraints_; } -// void DrrPatternContext::RequireEqual(const Attribute& first, const Attribute& -// second) { -// auto constrain_fn = [&](const MatchContext& match_context) { -// return match_context.Attr(first.id()) == match_context.Attr(second.id()); -// }; -// constraints_.emplace_back(constrain_fn); -// } - void DrrPatternContext::RequireEqual(const TensorShape& first, const TensorShape& second) { // Note: we capture the datas by value for constrain_fn // because the datas are destructed before running constrain_fn. auto constrain_fn = [=](const MatchContext& match_context) { - return match_context.Tensor(first.tensor_name()).Shape() == - match_context.Tensor(second.tensor_name()).Shape(); + return pir::GetShapeFromValue(match_context.Tensor(first.tensor_name())) == + pir::GetShapeFromValue(match_context.Tensor(second.tensor_name())); }; constraints_.emplace_back(constrain_fn); } @@ -84,14 +78,15 @@ void DrrPatternContext::RequireEqual(const TensorDataType& first, // Note: we capture the datas by value for constrain_fn // because the datas are destructed before running constrain_fn. 
auto constrain_fn = [=](const MatchContext& match_context) { - return match_context.Tensor(first.tensor_name()).Dtype() == - match_context.Tensor(second.tensor_name()).Dtype(); + return pir::GetDataTypeFromValue( + match_context.Tensor(first.tensor_name())) == + pir::GetDataTypeFromValue( + match_context.Tensor(second.tensor_name())); }; constraints_.emplace_back(constrain_fn); } -void DrrPatternContext::RequireNativeCall( - const std::function<bool(const MatchContext&)>& custom_fn) { +void DrrPatternContext::RequireNativeCall(const ConstraintFunction& custom_fn) { constraints_.emplace_back(custom_fn); } diff --git a/paddle/fluid/pir/drr/pattern_graph.cc b/paddle/fluid/pir/drr/pattern_graph.cc index 58c79c65acf2f6..5409133b7480b4 100644 --- a/paddle/fluid/pir/drr/pattern_graph.cc +++ b/paddle/fluid/pir/drr/pattern_graph.cc @@ -16,7 +16,7 @@ #include -#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_context.h" #include "paddle/phi/core/enforce.h" namespace paddle { diff --git a/paddle/fluid/pir/drr/drr_rewrite_pattern.cc b/paddle/fluid/pir/drr/rewrite_pattern.cc similarity index 88% rename from paddle/fluid/pir/drr/drr_rewrite_pattern.cc rename to paddle/fluid/pir/drr/rewrite_pattern.cc index d408c1aab13490..5d3726246b36bb 100644 --- a/paddle/fluid/pir/drr/drr_rewrite_pattern.cc +++ b/paddle/fluid/pir/drr/rewrite_pattern.cc @@ -12,11 +12,39 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/pir/drr/drr_rewrite_pattern.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_rewrite_pattern.h" +#include "paddle/fluid/pir/drr/ir_operation_factory.h" +#include "paddle/fluid/pir/drr/match_context_impl.h" +#include "paddle/fluid/pir/drr/pattern_graph.h" + +#include "paddle/phi/core/enforce.h" +#include "paddle/pir/core/operation.h" namespace paddle { namespace drr { +DrrRewritePattern::DrrRewritePattern(const std::string& pattern_name, + const DrrPatternContext& drr_context, + pir::IrContext* context, + pir::PatternBenefit benefit) + : pir::RewritePattern( + drr_context.source_pattern_graph()->AnchorNode()->name(), + benefit, + context, + {}), + pattern_name_(pattern_name), + source_pattern_graph_(drr_context.source_pattern_graph()), + constraints_(drr_context.constraints()), + result_pattern_graph_(drr_context.result_pattern_graph()) { + PADDLE_ENFORCE_NE( + source_pattern_graph_->owned_op_call().empty(), + true, + phi::errors::InvalidArgument("Source pattern graph is empty." 
+ "Suggested fix: Please check the DRR " + "source pattern definition code.")); +} + bool DrrRewritePattern::MatchAndRewrite( pir::Operation* op, pir::PatternRewriter& rewriter) const { // NOLINT @@ -25,6 +53,7 @@ bool DrrRewritePattern::MatchAndRewrite( if (PatternGraphMatch(op, src_match_ctx.get())) { VLOG(4) << "DRR pattern (" << pattern_name_ << ") is matched in program."; PatternGraphRewrite(*src_match_ctx, rewriter); + VLOG(4) << "DRR pattern (" << pattern_name_ << ") is rewrited in program."; return true; } return false; @@ -260,13 +289,11 @@ bool DrrRewritePattern::MatchFromOutputToInput( << ir_node->num_results() << ")."; break; } - source_pattern_match_ctx->BindIrOperation( - drr_node, std::make_shared(ir_node)); + source_pattern_match_ctx->BindIrOperation(drr_node, ir_node); // binding input_tensor of current_op for (size_t i = 0; i < drr_input_tensors.size(); ++i) { - source_pattern_match_ctx->BindIrValue( - drr_input_tensors[i]->name(), - std::make_shared(ir_node->operand(i).source())); + source_pattern_match_ctx->BindIrValue(drr_input_tensors[i]->name(), + ir_node->operand(i).source()); if (ir_node->operand_source(i).isa()) { matched = false; VLOG(8) << drr_node->name() @@ -312,9 +339,8 @@ bool DrrRewritePattern::MatchFromOutputToInput( // binding output tensor of current_op auto drr_op_output_tensor = drr_node->outputs(); for (size_t j = 0; j < drr_op_output_tensor.size(); j++) { - source_pattern_match_ctx->BindIrValue( - drr_op_output_tensor[j]->name(), - std::make_shared(ir_node->result(j))); + source_pattern_match_ctx->BindIrValue(drr_op_output_tensor[j]->name(), + ir_node->result(j)); } ++step; } @@ -379,9 +405,7 @@ MatchContextImpl DrrRewritePattern::CreateOperations( "pattern graph to be obtained.", in_tensor)); if (!result_pattern_graph.id2owend_tensor().at(in_tensor)->is_none()) { - res_match_ctx.BindIrValue( - in_tensor, - std::make_shared(src_match_ctx.GetIrValue(in_tensor))); + res_match_ctx.BindIrValue(in_tensor, src_match_ctx.GetIrValue(in_tensor)); } } @@ -431,9 +455,8 @@ MatchContextImpl DrrRewritePattern::CreateOperations( } if (max_input_op_index == 0UL) { VLOG(6) << "Not found producer op for (" << op_call.name() << ")"; - pir::Operation* source_patter_first_op = - src_match_ctx.Operation(source_pattern_graph.owned_op_call()[0].get()) - .get(); + pir::Operation* source_patter_first_op = src_match_ctx.IrOperation( + source_pattern_graph.owned_op_call()[0].get()); max_input_op_index = op_2_temp_program_index[source_patter_first_op]; rewriter.set_insertion_point(source_patter_first_op); } else { @@ -459,9 +482,8 @@ void DrrRewritePattern::RebindIrTensorForAssignTensor( for (const auto& kv : tensor_assign_map) { const auto& src_tensor_name = kv.first; const auto& dst_tensor_name = kv.second; - res_match_ctx->BindIrValue( - src_tensor_name, - std::make_shared(res_match_ctx->GetIrValue(dst_tensor_name))); + res_match_ctx->BindIrValue(src_tensor_name, + res_match_ctx->GetIrValue(dst_tensor_name)); } } @@ -473,7 +495,7 @@ void DrrRewritePattern::ReplaceOutputTensor( if (source_pattern_graph_->id2owend_tensor().count(output_name)) { const auto& src_ir_tensor = src_match_ctx.GetIrValue(output_name); const auto& res_ir_tensor = res_match_ctx.GetIrValue(output_name); - rewriter.ReplaceAllUsesWith(src_ir_tensor.get(), res_ir_tensor.get()); + rewriter.ReplaceAllUsesWith(src_ir_tensor, res_ir_tensor); } else { LOG(WARNING) << "The output tensor (" << output_name << ") in the result_pattern_graph is not the tensor" @@ -491,7 +513,7 @@ void 
DrrRewritePattern::DeleteSourcePatternOp( std::unordered_set<pir::Operation*> delete_ops_set; GraphTopo graph_topo_visit(&source_pattern_graph); graph_topo_visit.WalkGraphNodesTopoOrder([&](const OpCall& op_call) { - pir::Operation* op = src_match_ctx.Operation(&op_call).get(); + pir::Operation* op = src_match_ctx.IrOperation(&op_call); VLOG(5) << "DRR delete op: " << op->name() << " pointer: " << op; if (delete_ops_set.count(op) == 0 && op->use_empty()) { delete_ops_que.push(op); @@ -516,5 +538,13 @@ void DrrRewritePattern::DeleteSourcePatternOp( } } +std::unique_ptr<DrrRewritePattern> DrrPatternBase::Build( + pir::IrContext* ir_context) const { + DrrPatternContext drr_context; + this->operator()(&drr_context); + return std::make_unique<DrrRewritePattern>( + name(), drr_context, ir_context, benefit()); +} + } // namespace drr } // namespace paddle diff --git a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc index ab19247de4b26a..616ff6f607c588 100644 --- a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc @@ -14,15 +14,14 @@ #include "paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" +#include "paddle/fluid/pir/transforms/transform_general_functions.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace { -class MultiHeadMatmulFusePattern - : public paddle::drr::DrrPatternBase<MultiHeadMatmulFusePattern> { +class MultiHeadMatmulFusePattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // @@ -159,7 +158,7 @@ class MultiHeadMatmulFusePattern res.Tensor("concat_1_out") = concat_1(res.Tensor("combine_1_out")); const auto &reshape_5_shape = res.Attr( [](const paddle::drr::MatchContext &match_ctx) -> std::vector { - auto matmul_1_in_2 = match_ctx.Tensor("matmul_1_in_2").Shape(); + auto matmul_1_in_2 = + pir::GetShapeFromValue(match_ctx.Tensor("matmul_1_in_2")); return {-1, 3, matmul_1_in_2.at(1)}; }); const auto &reshape_5 = @@ -216,6 +216,8 @@ class MultiHeadMatmulFusePattern &res.Tensor("add_4_in_2")}, {&res.Tensor("reshape_4_out")}); } + + std::string name() const override { return "MultiHeadMatmulFusePattern"; } }; class AttentionFusePass : public pir::PatternRewritePass { diff --git a/paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.cc index a3c71cc90e60d8..8ef58da2b4badf 100644 --- a/paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.cc @@ -11,17 +11,19 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" -#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.h" -#include "paddle/common/ddim.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/transforms/transform_general_functions.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_registry.h" + +#include "paddle/common/ddim.h" + namespace { class Conv2dAddActFusePattern diff --git a/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc index e86dc04037fa01..fbfa9c6891a55a 100644 --- a/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc @@ -12,24 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" +#include "paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" -#include "paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.h" -#include "paddle/fluid/pir/transforms/transform_general_functions.h" - -#include "paddle/common/ddim.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" #include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_registry.h" namespace { -class Conv2dAddFusePattern - : public paddle::drr::DrrPatternBase { +class Conv2dAddFusePattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -83,6 +76,8 @@ class Conv2dAddFusePattern &res.NoneTensor()}, {&res.Tensor("add_out")}); } + + std::string name() const override { return "Conv2dAddFusePattern"; } }; class Conv2dAddFusePass : public pir::PatternRewritePass { diff --git a/paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.cc index 42129852bc8bc3..eefed9493d58e3 100644 --- a/paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.cc @@ -11,17 +11,16 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" -#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.h" -#include "paddle/common/ddim.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/fluid/pir/transforms/transform_general_functions.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_registry.h" + namespace { class Conv2dBnFusePattern diff --git a/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc index 7e5c4bbe8ea187..e57e9b1bef7278 100644 --- a/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc @@ -13,16 +13,17 @@ // limitations under the License. #include "paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.h" + #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" +#include "paddle/fluid/pir/transforms/transform_general_functions.h" + #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace { -class FcElementwiseLayerNormFusePattern - : public paddle::drr::DrrPatternBase { +class FcElementwiseLayerNormFusePattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -49,12 +50,14 @@ class FcElementwiseLayerNormFusePattern // Constrains the activation is none pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { int64_t layer_norm_x = 1; + auto fc_out_dims = pir::GetShapeFromValue(match_ctx.Tensor("fc_out")); + auto w_dims = pir::GetShapeFromValue(match_ctx.Tensor("w")); for (int i = match_ctx.Attr("begin_norm_axis"); - i < match_ctx.Tensor("fc_out").Shape().size(); + i < fc_out_dims.size(); i++) { - layer_norm_x *= match_ctx.Tensor("fc_out").Shape().at(i); + layer_norm_x *= fc_out_dims.at(i); } - if (layer_norm_x == match_ctx.Tensor("w").Shape().at(1)) { + if (layer_norm_x == w_dims.at(1)) { return true; } return false; @@ -89,10 +92,13 @@ class FcElementwiseLayerNormFusePattern &res.Tensor("layernorm_mean"), &res.Tensor("layernorm_variance")}); } + + std::string name() const override { + return "FcElementwiseLayerNormFusePattern"; + } }; -class FcElementwiseLayerNormFuse2Pattern - : public paddle::drr::DrrPatternBase { +class FcElementwiseLayerNormFuse2Pattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -119,12 +125,14 @@ class FcElementwiseLayerNormFuse2Pattern // Constrains the activation is none pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { int64_t layer_norm_x = 1; + auto fc_out_dims = pir::GetShapeFromValue(match_ctx.Tensor("fc_out")); + auto w_dims = pir::GetShapeFromValue(match_ctx.Tensor("w")); for (int i = match_ctx.Attr("begin_norm_axis"); - i < match_ctx.Tensor("fc_out").Shape().size(); + i < fc_out_dims.size(); 
i++) { - layer_norm_x *= match_ctx.Tensor("fc_out").Shape().at(i); + layer_norm_x *= fc_out_dims.at(i); } - if (layer_norm_x == match_ctx.Tensor("w").Shape().at(1)) { + if (layer_norm_x == w_dims.at(1)) { return true; } return false; @@ -150,6 +158,10 @@ class FcElementwiseLayerNormFuse2Pattern &res.Tensor("layernorm_mean"), &res.Tensor("layernorm_variance")}); } + + std::string name() const override { + return "FcElementwiseLayerNormFuse2Pattern"; + } }; class FcElementwiseLayerNormFusePass : public pir::PatternRewritePass { diff --git a/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc index b49ab9ff4ac77b..18200f2e6b4e2a 100644 --- a/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc @@ -13,15 +13,17 @@ // limitations under the License. #include "paddle/fluid/pir/transforms/fusion/fc_fuse_pass.h" + #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" +#include "paddle/fluid/pir/transforms/transform_general_functions.h" + #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace { -class MatmulAddPattern : public paddle::drr::DrrPatternBase<MatmulAddPattern> { +class MatmulAddPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -33,25 +35,22 @@ class MatmulAddPattern : public paddle::drr::DrrPatternBase<MatmulAddPattern> { pat.Tensor("add_out") = add(pat.Tensor("matmul_out"), pat.Tensor("y")); pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { - if (match_ctx.Tensor("w").Shape().size() != 2 || - match_ctx.Tensor("x").Shape().size() < 2) { + auto w_dims = pir::GetShapeFromValue(match_ctx.Tensor("w")); + auto x_dims = pir::GetShapeFromValue(match_ctx.Tensor("x")); + auto y_dims = pir::GetShapeFromValue(match_ctx.Tensor("y")); + if (w_dims.size() != 2 || x_dims.size() < 2) { return false; } - if (match_ctx.Tensor("x").Shape().at( - match_ctx.Tensor("x").Shape().size() - 1) != - match_ctx.Tensor("w").Shape().at(0) || + if (x_dims.at(x_dims.size() - 1) != w_dims.at(0) || match_ctx.Attr<bool>("transpose_x") == true || match_ctx.Attr<bool>("transpose_y") == true) { return false; } - if (match_ctx.Tensor("y").Shape().size() == 1) { - return match_ctx.Tensor("y").Shape().at(0) == - match_ctx.Tensor("w").Shape().at(1); + if (y_dims.size() == 1) { + return y_dims.at(0) == w_dims.at(1); } - if (match_ctx.Tensor("y").Shape().size() == 2) { - return match_ctx.Tensor("y").Shape().at(0) == 1 && - match_ctx.Tensor("y").Shape().at(1) == - match_ctx.Tensor("w").Shape().at(1); + if (y_dims.size() == 2) { + return y_dims.at(0) == 1 && y_dims.at(1) == w_dims.at(1); } return false; }); @@ -60,7 +59,8 @@ class MatmulAddPattern : public paddle::drr::DrrPatternBase<MatmulAddPattern> { const auto &in_num_col_dims_attr = res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { - return match_ctx.Tensor("x").Shape().size() - 1; + auto x_dims = pir::GetShapeFromValue(match_ctx.Tensor("x")); + return x_dims.size() - 1; }); const auto &false_attr = res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { @@ -79,10 +79,11 @@ class MatmulAddPattern : public paddle::drr::DrrPatternBase<MatmulAddPattern> { fc({&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("y")}, {&res.Tensor("add_out")}); } + + std::string name() 
const override { return "MatmulAddPattern"; } }; -class FcWithReluPattern - : public paddle::drr::DrrPatternBase { +class FcWithReluPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -117,6 +118,8 @@ class FcWithReluPattern fc_with_relu({&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("y")}, {&res.Tensor("relu_out")}); } + + std::string name() const override { return "FcWithReluPattern"; } }; class FcFusePass : public pir::PatternRewritePass { diff --git a/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc index 9b2e7f2f3f2e74..0b5737ecf69d6d 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc @@ -13,16 +13,16 @@ // limitations under the License. #include "paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.h" + #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" + #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace { -class FusedDotProductAttentionPattern - : public paddle::drr::DrrPatternBase { +class FusedDotProductAttentionPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern src = ctx->SourcePattern(); @@ -137,10 +137,13 @@ class FusedDotProductAttentionPattern &res.Tensor("softmax_aux"), &res.Tensor("rng_state")}); } + + std::string name() const override { + return "FusedDotProductAttentionPattern"; + } }; -class FusedDotProductAttentionGradPattern - : public paddle::drr::DrrPatternBase { +class FusedDotProductAttentionGradPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern src = ctx->SourcePattern(); @@ -314,11 +317,14 @@ class FusedDotProductAttentionGradPattern &res.Tensor("out_grad")}, {&res.Tensor("q_grad"), &res.Tensor("k_grad"), &res.Tensor("v_grad")}); } + + std::string name() const override { + return "FusedDotProductAttentionGradPattern"; + } }; class FusedDotProductAttentionWithDropoutPattern - : public paddle::drr::DrrPatternBase< - FusedDotProductAttentionWithDropoutPattern> { + : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern src = ctx->SourcePattern(); @@ -441,11 +447,14 @@ class FusedDotProductAttentionWithDropoutPattern &res.Tensor("softmax_aux"), &res.Tensor("rng_state")}); } + + std::string name() const override { + return "FusedDotProductAttentionWithDropoutPattern"; + } }; class FusedDotProductAttentionGradWithDropoutPattern - : public paddle::drr::DrrPatternBase< - FusedDotProductAttentionGradWithDropoutPattern> { + : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern src = ctx->SourcePattern(); @@ -629,12 +638,16 @@ class FusedDotProductAttentionGradWithDropoutPattern &res.Tensor("out_grad")}, {&res.Tensor("q_grad"), &res.Tensor("k_grad"), &res.Tensor("v_grad")}); } + + std::string name() const override { + return 
"FusedDotProductAttentionGradWithDropoutPattern"; + } }; class FusedDotProductAttentionPass : public pir::PatternRewritePass { public: FusedDotProductAttentionPass() - : pir::PatternRewritePass("fused_dot_product_attention_pass", 1) {} + : pir::PatternRewritePass("fused_dot_product_attention_pass", 2) {} pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { pir::RewritePatternSet ps(context); diff --git a/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc index df8b39cfc8676d..0041c70488ffa0 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc @@ -13,16 +13,16 @@ // limitations under the License. #include "paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h" + #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" + #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace { -class FusedDropoutAddPattern - : public paddle::drr::DrrPatternBase { +class FusedDropoutAddPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -50,10 +50,11 @@ class FusedDropoutAddPattern {&res.Tensor("x"), &res.Tensor("y"), &res.Tensor("seed_tensor")}, {&res.Tensor("add_out"), &res.Tensor("mask")}); } + + std::string name() const override { return "FusedDropoutAddPattern"; } }; -class FusedDropoutGradAddGradPattern - : public paddle::drr::DrrPatternBase { +class FusedDropoutGradAddGradPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -103,12 +104,14 @@ class FusedDropoutGradAddGradPattern fused_dropout_add_grad({&res.Tensor("mask"), &res.Tensor("add_out_grad")}, {&res.Tensor("x_grad"), &res.Tensor("y_grad")}); } + + std::string name() const override { return "FusedDropoutGradAddGradPattern"; } }; class FusedDropoutAddPass : public pir::PatternRewritePass { public: FusedDropoutAddPass() - : pir::PatternRewritePass("fused_dropout_add_pass", 1) {} + : pir::PatternRewritePass("fused_dropout_add_pass", 2) {} pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { pir::RewritePatternSet ps(context); diff --git a/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc index 02a6b4744cdcb8..6a39c015893e32 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc @@ -13,16 +13,17 @@ // limitations under the License. 
#include "paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.h" + #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" +#include "paddle/fluid/pir/transforms/transform_general_functions.h" + #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace { -class FusedLinearPattern - : public paddle::drr::DrrPatternBase { +class FusedLinearPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -35,9 +36,11 @@ class FusedLinearPattern pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias")); pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { - return (match_ctx.Tensor("w").Shape().size() == 2 && - match_ctx.Tensor("x").Shape().size() >= 2 && - match_ctx.Tensor("bias").Shape().size() == 1); + auto w_dims = pir::GetShapeFromValue(match_ctx.Tensor("w")); + auto x_dims = pir::GetShapeFromValue(match_ctx.Tensor("x")); + auto bias_dims = pir::GetShapeFromValue(match_ctx.Tensor("bias")); + return (w_dims.size() == 2 && x_dims.size() >= 2 && + bias_dims.size() == 1); }); paddle::drr::ResultPattern res = pat.ResultPattern(); @@ -54,10 +57,11 @@ class FusedLinearPattern {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, {&res.Tensor("out")}); } + + std::string name() const override { return "FusedLinearPattern"; } }; -class FusedLinearGradPattern - : public paddle::drr::DrrPatternBase { +class FusedLinearGradPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -78,9 +82,11 @@ class FusedLinearGradPattern {&pat.Tensor("x_grad"), &pat.Tensor("w_grad")}); pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { - return (match_ctx.Tensor("w").Shape().size() == 2 && - match_ctx.Tensor("x").Shape().size() >= 2 && - match_ctx.Tensor("bias").Shape().size() == 1); + auto w_dims = pir::GetShapeFromValue(match_ctx.Tensor("w")); + auto x_dims = pir::GetShapeFromValue(match_ctx.Tensor("x")); + auto bias_dims = pir::GetShapeFromValue(match_ctx.Tensor("bias")); + return (w_dims.size() == 2 && x_dims.size() >= 2 && + bias_dims.size() == 1); }); paddle::drr::ResultPattern res = pat.ResultPattern(); @@ -109,10 +115,11 @@ class FusedLinearGradPattern &res.Tensor("w_grad"), &res.Tensor("bias_grad")}); } + + std::string name() const override { return "FusedLinearGradPattern"; } }; -class FusedLinearGeluPattern - : public paddle::drr::DrrPatternBase { +class FusedLinearGeluPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -148,9 +155,11 @@ class FusedLinearGeluPattern {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, {&res.Tensor("out"), &res.Tensor("reserve_space")}); } + + std::string name() const override { return "FusedLinearGeluPattern"; } }; -class FusedLinearReluPattern - : public paddle::drr::DrrPatternBase { + +class FusedLinearReluPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -186,10 +195,11 @@ class FusedLinearReluPattern 
{&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, {&res.Tensor("out"), &res.Tensor("reserve_space")}); } + + std::string name() const override { return "FusedLinearReluPattern"; } }; -class FusedLinearGeluGradPattern - : public paddle::drr::DrrPatternBase { +class FusedLinearGeluGradPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -254,10 +264,11 @@ class FusedLinearGeluGradPattern &res.Tensor("w1_grad"), &res.Tensor("bias1_grad")}); } + + std::string name() const override { return "FusedLinearGeluGradPattern"; } }; -class FusedLinearReluGradPattern - : public paddle::drr::DrrPatternBase { +class FusedLinearReluGradPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -322,6 +333,8 @@ class FusedLinearReluGradPattern &res.Tensor("w1_grad"), &res.Tensor("bias1_grad")}); } + + std::string name() const override { return "FusedLinearReluGradPattern"; } }; class FusedGemmEpiloguePass : public pir::PatternRewritePass { diff --git a/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc index 8c93ff98226754..1453426cc8df6f 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc @@ -13,16 +13,18 @@ // limitations under the License. #include "paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h" + #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" +#include "paddle/fluid/pir/transforms/transform_general_functions.h" + #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + namespace { // add_grad + matmul_grad + add_ -> matmul + fused_liner_param_gard_add -class FusedMatmulAddGradAddPattern - : public paddle::drr::DrrPatternBase { +class FusedMatmulAddGradAddPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -51,18 +53,22 @@ class FusedMatmulAddGradAddPattern pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { const auto &x_trans = match_ctx.Attr("trans_x"); const auto &y_trans = match_ctx.Attr("trans_y"); - return (match_ctx.Tensor("weight_grad").Shape() == - match_ctx.Tensor("dweight").Shape() && - match_ctx.Tensor("out").Shape() == - match_ctx.Tensor("fwd_add_out_grad").Shape() && - x_trans == false && y_trans == false); + auto weight_grad_dims = + pir::GetShapeFromValue(match_ctx.Tensor("weight_grad")); + auto dweight_dims = pir::GetShapeFromValue(match_ctx.Tensor("dweight")); + auto out_dims = pir::GetShapeFromValue(match_ctx.Tensor("out")); + auto fwd_add_out_grad_dims = + pir::GetShapeFromValue(match_ctx.Tensor("fwd_add_out_grad")); + return (weight_grad_dims == dweight_dims && + out_dims == fwd_add_out_grad_dims && x_trans == false && + y_trans == false); }); paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return !(match_ctx.Tensor("dweight").Dtype() == - 
match_ctx.Tensor("weight_grad").Dtype()); + return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == + pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad"))); }); const auto &true_attr = @@ -89,11 +95,12 @@ class FusedMatmulAddGradAddPattern &res.NoneTensor()}, {&res.Tensor("add_out"), &res.Tensor("dbias")}); } + + std::string name() const override { return "FusedMatmulAddGradAddPattern"; } }; // matmul_grad + add_ -> matmul + fused_liner_param_gard_add -class FusedMatmulGradAddPattern - : public paddle::drr::DrrPatternBase { +class FusedMatmulGradAddPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -111,17 +118,19 @@ class FusedMatmulGradAddPattern pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { const auto &x_trans = match_ctx.Attr("trans_x"); const auto &y_trans = match_ctx.Attr("trans_y"); - return (match_ctx.Tensor("weight_grad").Shape() == - match_ctx.Tensor("dweight").Shape() && - x_trans == false && y_trans == false); + auto weight_grad_dims = + pir::GetShapeFromValue(match_ctx.Tensor("weight_grad")); + auto dweight_dims = pir::GetShapeFromValue(match_ctx.Tensor("dweight")); + return (weight_grad_dims == dweight_dims && x_trans == false && + y_trans == false); }); paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return !(match_ctx.Tensor("dweight").Dtype() == - match_ctx.Tensor("weight_grad").Dtype()); + return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == + pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad"))); }); const auto &true_attr = @@ -149,11 +158,12 @@ class FusedMatmulGradAddPattern &res.NoneTensor()}, {&res.Tensor("add_out"), &res.Tensor("dbias_out")}); } + + std::string name() const override { return "FusedMatmulGradAddPattern"; } }; // matmul + 0 = add_(0,1) -> fused_liner_param_gard_add -class FusedMatmulAddaPattern - : public paddle::drr::DrrPatternBase { +class FusedMatmulAddaPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -168,15 +178,17 @@ class FusedMatmulAddaPattern add_(pat.Tensor("dweight"), pat.Tensor("weight_grad")); pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { - return (match_ctx.Tensor("weight_grad").Shape() == - match_ctx.Tensor("dweight").Shape()); + auto weight_grad_dims = + pir::GetShapeFromValue(match_ctx.Tensor("weight_grad")); + auto dweight_dims = pir::GetShapeFromValue(match_ctx.Tensor("dweight")); + return (weight_grad_dims == dweight_dims); }); paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return !(match_ctx.Tensor("dweight").Dtype() == - match_ctx.Tensor("weight_grad").Dtype()); + return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == + pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad"))); }); const auto &true_attr = @@ -198,11 +210,12 @@ class FusedMatmulAddaPattern &res.NoneTensor()}, {&res.Tensor("add_out"), &res.Tensor("dbias_out")}); } + + std::string name() const override { return "FusedMatmulAddaPattern"; } }; // matmul + 1 = add_(1,0) -> fused_liner_param_gard_add -class FusedMatmulAddbPattern - : public paddle::drr::DrrPatternBase { +class 
FusedMatmulAddbPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -217,15 +230,17 @@ class FusedMatmulAddbPattern add_(pat.Tensor("weight_grad"), pat.Tensor("dweight")); pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { - return (match_ctx.Tensor("weight_grad").Shape() == - match_ctx.Tensor("dweight").Shape()); + auto weight_grad_dims = + pir::GetShapeFromValue(match_ctx.Tensor("weight_grad")); + auto dweight_dims = pir::GetShapeFromValue(match_ctx.Tensor("dweight")); + return (weight_grad_dims == dweight_dims); }); paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return !(match_ctx.Tensor("dweight").Dtype() == - match_ctx.Tensor("weight_grad").Dtype()); + return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == + pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad"))); }); const auto &true_attr = @@ -247,11 +262,12 @@ class FusedMatmulAddbPattern &res.NoneTensor()}, {&res.Tensor("add_out"), &res.Tensor("dbias_out")}); } + + std::string name() const override { return "FusedMatmulAddbPattern"; } }; // add_grad + matmul + 0 = add_(0,1) -> fused_liner_param_gard_add -class FusedMatmulAddGradAddaPattern - : public paddle::drr::DrrPatternBase { +class FusedMatmulAddGradAddaPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -278,17 +294,19 @@ class FusedMatmulAddGradAddaPattern add_(pat.Tensor("dweight"), pat.Tensor("weight_grad")); pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { - return (match_ctx.Tensor("weight_grad").Shape() == - match_ctx.Tensor("dweight").Shape() && - match_ctx.Tensor("out").Shape() == - match_ctx.Tensor("dadd_out").Shape()); + auto weight_grad_dims = + pir::GetShapeFromValue(match_ctx.Tensor("weight_grad")); + auto dweight_dims = pir::GetShapeFromValue(match_ctx.Tensor("dweight")); + auto out_dims = pir::GetShapeFromValue(match_ctx.Tensor("out")); + auto dadd_out_dims = pir::GetShapeFromValue(match_ctx.Tensor("dadd_out")); + return (weight_grad_dims == dweight_dims && out_dims == dadd_out_dims); }); paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return !(match_ctx.Tensor("dweight").Dtype() == - match_ctx.Tensor("weight_grad").Dtype()); + return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == + pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad"))); }); const auto &true_attr = res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { @@ -304,11 +322,12 @@ class FusedMatmulAddGradAddaPattern &res.NoneTensor()}, {&res.Tensor("dweight_out"), &res.Tensor("dbias")}); } + + std::string name() const override { return "FusedMatmulAddGradAddaPattern"; } }; // add_grad + matmul + 1 = add_(1,0) -> fused_liner_param_gard_add -class FusedMatmulAddGradAddbPattern - : public paddle::drr::DrrPatternBase { +class FusedMatmulAddGradAddbPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -335,17 +354,19 @@ class FusedMatmulAddGradAddbPattern add_(pat.Tensor("weight_grad"), pat.Tensor("dweight")); 
pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { - return (match_ctx.Tensor("weight_grad").Shape() == - match_ctx.Tensor("dweight").Shape() && - match_ctx.Tensor("out").Shape() == - match_ctx.Tensor("dadd_out").Shape()); + auto weight_grad_dims = + pir::GetShapeFromValue(match_ctx.Tensor("weight_grad")); + auto dweight_dims = pir::GetShapeFromValue(match_ctx.Tensor("dweight")); + auto out_dims = pir::GetShapeFromValue(match_ctx.Tensor("out")); + auto dadd_out_dims = pir::GetShapeFromValue(match_ctx.Tensor("dadd_out")); + return (weight_grad_dims == dweight_dims && out_dims == dadd_out_dims); }); paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return !(match_ctx.Tensor("dweight").Dtype() == - match_ctx.Tensor("weight_grad").Dtype()); + return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == + pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad"))); }); const auto &true_attr = res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { @@ -361,12 +382,14 @@ class FusedMatmulAddGradAddbPattern &res.NoneTensor()}, {&res.Tensor("dweight_out"), &res.Tensor("dbias")}); } + + std::string name() const override { return "FusedMatmulAddGradAddbPattern"; } }; class FusedLinearParamGradAddPass : public pir::PatternRewritePass { public: FusedLinearParamGradAddPass() - : pir::PatternRewritePass("fused_linear_param_grad_add_pass", 1) {} + : pir::PatternRewritePass("fused_linear_param_grad_add_pass", 2) {} pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { pir::RewritePatternSet ps(context); diff --git a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc index 82864f3d80e88f..df61b1eb25ba27 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc @@ -13,13 +13,15 @@ // limitations under the License. 
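Beyond the base-class change, the constraint lambdas in fused_gemm_epilogue_pass.cc and fused_linear_param_grad_add_pass.cc above stop reading shapes and dtypes off the MatchContext tensor handle and instead call the free helpers declared in transform_general_functions.h. A condensed sketch of the post-patch idiom, with placeholder tensor names and helper signatures as implied by these hunks:

#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
#include "paddle/fluid/pir/transforms/transform_general_functions.h"

// Sketch: a match-time guard plus a computed result attribute, written the
// way the rewritten patterns above do it.
void AddExampleConstraint(paddle::drr::SourcePattern &pat,    // NOLINT
                          paddle::drr::ResultPattern &res) {  // NOLINT
  // Rank checks go through pir::GetShapeFromValue instead of
  // match_ctx.Tensor(...).Shape().
  pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) {
    auto w_dims = pir::GetShapeFromValue(match_ctx.Tensor("w"));
    auto x_dims = pir::GetShapeFromValue(match_ctx.Tensor("x"));
    return w_dims.size() == 2 && x_dims.size() >= 2;
  });

  // Dtype comparisons likewise go through pir::GetDataTypeFromValue; this
  // mirrors the multi-precision flag the patterns above compute.
  const auto &multi_precision_attr =
      res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool {
        return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) ==
                 pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad")));
      });
  (void)multi_precision_attr;  // would be passed into a res.Op(...) call
}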
#include "paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h" + #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" +#include "paddle/fluid/pir/transforms/transform_general_functions.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/place.h" + #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace { @@ -35,8 +37,7 @@ int getSMVersion() { return sm_version; } -class FusedWeightOnlyLinearPattern - : public paddle::drr::DrrPatternBase { +class FusedWeightOnlyLinearPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // @@ -63,21 +64,21 @@ class FusedWeightOnlyLinearPattern bool matmul_trans_y = match_ctx.Attr("matmul_transpose_y"); if (matmul_trans_x || matmul_trans_y) return false; - if (!(match_ctx.Tensor("w").Shape().size() == 2 && - match_ctx.Tensor("x").Shape().size() >= 2 && - match_ctx.Tensor("bias").Shape().size() == 1)) { + auto w_dims = pir::GetShapeFromValue(match_ctx.Tensor("w")); + auto x_dims = pir::GetShapeFromValue(match_ctx.Tensor("x")); + auto bias_dims = pir::GetShapeFromValue(match_ctx.Tensor("bias")); + if (!(w_dims.size() == 2 && x_dims.size() >= 2 && + bias_dims.size() == 1)) { return false; } - auto w_dims = match_ctx.Tensor("w").Shape(); if (w_dims.at(0) % 64 != 0 || w_dims.at(1) % 16 != 0) return false; - auto w_dtype = match_ctx.Tensor("w").Dtype().get(); + auto w_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("w")); if (!w_dtype.isa() && !w_dtype.isa()) return false; - auto x_dims = match_ctx.Tensor("x").Shape(); if (x_dims.at(x_dims.size() - 1) != w_dims.at(1)) return false; return true; @@ -126,6 +127,8 @@ class FusedWeightOnlyLinearPattern &res.Tensor("weight_scale_tensor")}, {&res.Tensor("add_out")}); } + + std::string name() const override { return "FusedWeightOnlyLinearPattern"; } }; class FusedWeightOnlyLinearPass : public pir::PatternRewritePass { diff --git a/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc index 0bced0b8ec823f..cabd7a7274cb70 100644 --- a/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc @@ -13,22 +13,17 @@ // limitations under the License. 
#include "paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" + #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/fluid/pir/transforms/transform_general_functions.h" -#include "paddle/common/ddim.h" - #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace { -class MatmulScaleFusePattern - : public paddle::drr::DrrPatternBase { +class MatmulScaleFusePattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -76,6 +71,8 @@ class MatmulScaleFusePattern matmul_op_res({&res.Tensor("x"), &res.Tensor("scale_res_out")}, {&res.Tensor("scale_out")}); } + + std::string name() const override { return "MatmulScaleFusePattern"; } }; class MatmulScaleFusePass : public pir::PatternRewritePass { diff --git a/paddle/fluid/pir/transforms/identity_op_clean_pass.cc b/paddle/fluid/pir/transforms/identity_op_clean_pass.cc index ac49d494d1c731..53210443eda4e1 100644 --- a/paddle/fluid/pir/transforms/identity_op_clean_pass.cc +++ b/paddle/fluid/pir/transforms/identity_op_clean_pass.cc @@ -13,26 +13,17 @@ // limitations under the License. #include "paddle/fluid/pir/transforms/identity_op_clean_pass.h" -#include -#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" -#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" -#include "paddle/fluid/pir/drr/ir_value.h" -#include "paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.h" -#include "paddle/fluid/pir/transforms/transform_general_functions.h" -#include "paddle/common/ddim.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace { -class RemoveUselessScalePattern - : public paddle::drr::DrrPatternBase { +class RemoveUselessScalePattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -55,10 +46,11 @@ class RemoveUselessScalePattern paddle::drr::ResultPattern res = pat.ResultPattern(); res.Tensor("scale_out").Assign(res.Tensor("x")); } + + std::string name() const override { return "RemoveUselessScalePattern"; } }; -class RemoveRedundentScalePattern - : public paddle::drr::DrrPatternBase { +class RemoveRedundentScalePattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -126,10 +118,11 @@ class RemoveRedundentScalePattern scale_op_res({&res.Tensor("x"), &full_op_res()}, {&res.Tensor("scale_2_out")}); } + + std::string name() const override { return "RemoveRedundentScalePattern"; } }; -class RemoveUselessCastPattern - : public paddle::drr::DrrPatternBase { +class RemoveUselessCastPattern : public paddle::drr::DrrPatternBase { public: void 
operator()(paddle::drr::DrrPatternContext *ctx) const override { auto pat = ctx->SourcePattern(); @@ -138,10 +131,11 @@ class RemoveUselessCastPattern auto res = pat.ResultPattern(); res.Tensor("ret").Assign(res.Tensor("arg0")); } + + std::string name() const override { return "RemoveUselessCastPattern"; } }; -class RemoveUselessConcatPattern - : public paddle::drr::DrrPatternBase { +class RemoveUselessConcatPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { auto pat = ctx->SourcePattern(); @@ -150,18 +144,18 @@ class RemoveUselessConcatPattern pat.Tensor("out") = pat.Op(paddle::dialect::ConcatOp::name())( pat.Tensor("combine_out"), pat.Tensor("axis")); pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { - auto combine_out = dynamic_cast( - match_ctx.Tensor("combine_out")); - return combine_out.type_isa() && - combine_out.type_dyn_cast().size() == 1; + auto combine_out = match_ctx.Tensor("combine_out"); + return combine_out.type().isa() && + combine_out.type().dyn_cast().size() == 1; }); auto res = pat.ResultPattern(); res.Tensor("out").Assign(res.Tensor("x")); } + + std::string name() const override { return "RemoveUselessConcatPattern"; } }; -class RemoveRedundentCastPattern - : public paddle::drr::DrrPatternBase { +class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase { void operator()(paddle::drr::DrrPatternContext *ctx) const override { auto pat = ctx->SourcePattern(); pat.Tensor("tmp") = pat.Op( @@ -172,10 +166,11 @@ class RemoveRedundentCastPattern res.Tensor("ret") = res.Op( "pd_op.cast", {{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0")); } + + std::string name() const override { return "RemoveRedundentCastPattern"; } }; -class RemoveRedundentTransposePattern - : public paddle::drr::DrrPatternBase { +class RemoveRedundentTransposePattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -202,6 +197,10 @@ class RemoveRedundentTransposePattern res.Tensor("ret") = tranpose_continuous(res.Tensor("arg_transpose")); } + + std::string name() const override { + return "RemoveRedundentTransposePattern"; + } }; class IdentityOpCleanPass : public pir::PatternRewritePass { diff --git a/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.cc b/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.cc index 5e499436ec7f6b..8029cfc9ddbf5e 100644 --- a/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.cc +++ b/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.cc @@ -18,9 +18,6 @@ #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" -#include "paddle/pir/pattern_rewrite/pattern_match.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace { @@ -39,24 +36,15 @@ class ReplaceFetchWithShadowOutputPattern } }; -class ReplaceFetchWithShadowOutputPass : public pir::Pass { +class ReplaceFetchWithShadowOutputPass : public pir::PatternRewritePass { public: ReplaceFetchWithShadowOutputPass() - : pir::Pass("replace_fetch_with_shadow_output_pass", 0) {} + : pir::PatternRewritePass("replace_fetch_with_shadow_output_pass", 0) {} - bool Initialize(pir::IrContext* context) override { + pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override { 
pir::RewritePatternSet ps(context); ps.Add(context); - patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); - return true; - } - - void Run(pir::Operation* op) override { - pir::GreedyRewriteConfig cfg; - cfg.use_top_down_traversal = true; - cfg.max_iterations = 10; - auto [_, num_rewrites] = pir::ApplyPatternsGreedily(op, patterns_, cfg); - AddStatistics(num_rewrites); + return ps; } bool CanApplyOn(pir::Operation* op) const override { diff --git a/paddle/pir/core/visitors.h b/paddle/pir/core/visitors.h index 3fdcb71bff9b93..7d9e9eacf4394a 100644 --- a/paddle/pir/core/visitors.h +++ b/paddle/pir/core/visitors.h @@ -41,8 +41,8 @@ void Walk(Operation *op, template void Walk(Operation *op, FuncTy &&callback) { - return detail::Walk(op, callback, Order); + return Walk(op, callback, Order); } - } // namespace detail + } // namespace pir diff --git a/paddle/pir/pass/pass.cc b/paddle/pir/pass/pass.cc index 2f9cb896215ddc..c04669317ef169 100644 --- a/paddle/pir/pass/pass.cc +++ b/paddle/pir/pass/pass.cc @@ -23,6 +23,8 @@ #include "paddle/pir/pass/pass_adaptor.h" #include "paddle/pir/pass/pass_instrumentation.h" #include "paddle/pir/pass/pass_manager.h" +#include "paddle/pir/pattern_rewrite/pattern_match.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace pir { diff --git a/paddle/pir/pass/pass.h b/paddle/pir/pass/pass.h index 6c2c565322bf87..f85c7519cbe197 100644 --- a/paddle/pir/pass/pass.h +++ b/paddle/pir/pass/pass.h @@ -21,9 +21,8 @@ #include #include "paddle/common/enforce.h" -#include "paddle/pir/core/builtin_op.h" #include "paddle/pir/pass/analysis_manager.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" +#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" namespace pir { @@ -197,7 +196,7 @@ class IR_API Pass { std::unordered_map> attr_dels_; }; -class PatternRewritePass : public Pass { +class IR_API PatternRewritePass : public Pass { public: PatternRewritePass(const std::string& name, uint8_t opt_level, diff --git a/paddle/pir/pattern_rewrite/pattern_applicator.cc b/paddle/pir/pattern_rewrite/pattern_applicator.cc index 6e45768542061f..f67e41255a33ec 100644 --- a/paddle/pir/pattern_rewrite/pattern_applicator.cc +++ b/paddle/pir/pattern_rewrite/pattern_applicator.cc @@ -14,8 +14,8 @@ #include +#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" #include "paddle/pir/pattern_rewrite/pattern_applicator.h" - #include "paddle/pir/pattern_rewrite/pattern_match.h" namespace pir { diff --git a/paddle/pir/pattern_rewrite/pattern_applicator.h b/paddle/pir/pattern_rewrite/pattern_applicator.h index a0fdf58fd57e0c..37c0a42cbf974d 100644 --- a/paddle/pir/pattern_rewrite/pattern_applicator.h +++ b/paddle/pir/pattern_rewrite/pattern_applicator.h @@ -21,11 +21,14 @@ #include "paddle/pir/core/op_info.h" #include "paddle/pir/core/operation.h" -#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" #include "paddle/pir/pattern_rewrite/pattern_match.h" namespace pir { +class FrozenRewritePatternSet; +class RewritePattern; +class Pattern; + class PatternApplicator { public: using CostModel = std::function; diff --git a/paddle/pir/pattern_rewrite/pattern_match.h b/paddle/pir/pattern_rewrite/pattern_match.h index a0c34d8f58f073..475779f99cb287 100644 --- a/paddle/pir/pattern_rewrite/pattern_match.h +++ b/paddle/pir/pattern_rewrite/pattern_match.h @@ -37,7 +37,7 @@ namespace pir { // This class reprensents the benefit of a pattern. The most common -// unit to use is the `numver of operations` in the pattern. 
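The conversion above is the payoff of exporting pir::PatternRewritePass as public API (the pass.h hunk earlier in this line of the patch adds IR_API to it): the hand-rolled Initialize/Run pair, with its FrozenRewritePatternSet member and greedy-rewrite loop, collapses into a single InitializePatterns override. Read together, the surviving pieces of the converted pass reduce to roughly this sketch, with the pattern type spelled out in ps.Add and CanApplyOn elided:

#include "paddle/pir/pass/pass.h"

// Condensed view of the converted pass (a sketch; the pattern class and
// CanApplyOn are as defined in the file above).
class ReplaceFetchWithShadowOutputPass : public pir::PatternRewritePass {
 public:
  ReplaceFetchWithShadowOutputPass()
      : pir::PatternRewritePass("replace_fetch_with_shadow_output_pass", 0) {}

  pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
    pir::RewritePatternSet ps(context);
    ps.Add<ReplaceFetchWithShadowOutputPattern>(context);
    return ps;
  }
  // No Initialize()/Run(): the base class freezes the set and drives
  // pir::ApplyPatternsGreedily itself, recording rewrite statistics.
};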
+// unit to use is the `number of operations` in the pattern. class IR_API PatternBenefit { public: PatternBenefit() = default; diff --git a/paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc b/paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc index 3788c63273ffa9..c138785038d5a2 100644 --- a/paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc +++ b/paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc @@ -227,4 +227,20 @@ std::pair ApplyPatternsGreedily( return std::make_pair(converged, num_rewrites); } +IR_API std::pair ApplyPatternsGreedily( + Operation* op, + const FrozenRewritePatternSet& patterns, + GreedyRewriteConfig config) { + bool sum_converged = true; + int64_t sum_num_rewrites = 0; + for (uint32_t i = 0; i < op->num_regions(); ++i) { + Region& region = op->region(i); + auto [converged, num_rewrites] = + ApplyPatternsGreedily(region, patterns, config); + sum_converged &= converged; + sum_num_rewrites += num_rewrites; + } + return std::make_pair(sum_converged, sum_num_rewrites); +} + } // namespace pir diff --git a/paddle/pir/pattern_rewrite/pattern_rewrite_driver.h b/paddle/pir/pattern_rewrite/pattern_rewrite_driver.h index 8186e3cadb1958..8ed55843e2adb9 100644 --- a/paddle/pir/pattern_rewrite/pattern_rewrite_driver.h +++ b/paddle/pir/pattern_rewrite/pattern_rewrite_driver.h @@ -16,11 +16,11 @@ #include "paddle/pir/core/dll_decl.h" #include "paddle/pir/core/region.h" -#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" -#include "paddle/pir/pattern_rewrite/pattern_match.h" namespace pir { +class FrozenRewritePatternSet; + /// This enum will control which ops will be added to the worklist during the /// match rewrite process enum class IR_API GreedyRewriteStrictness { @@ -73,20 +73,9 @@ ApplyPatternsGreedily(Region& region, // NOLINT GreedyRewriteConfig config = GreedyRewriteConfig()); /// Perform a match and rewrite process for all regions of a given op. 
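The driver hunks here replace the header-inline definition of the op-level ApplyPatternsGreedily with a declaration behind a forward-declared FrozenRewritePatternSet, keeping the region-summing implementation shown in the .cc hunk above. Nothing changes for callers; a usage sketch, reusing the configuration that the removed pass Run() body earlier in this patch passed to the driver:

#include <utility>

#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/pir/pattern_rewrite/pattern_match.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

// Sketch: freeze a pattern set and run the greedy driver over every region
// of `op`, as the out-of-line overload above does internally.
std::pair<bool, int64_t> RewriteGreedily(pir::Operation *op,
                                         pir::RewritePatternSet ps) {
  pir::FrozenRewritePatternSet patterns(std::move(ps));
  pir::GreedyRewriteConfig cfg;
  cfg.use_top_down_traversal = true;  // same settings the removed Run() used
  cfg.max_iterations = 10;
  return pir::ApplyPatternsGreedily(op, patterns, cfg);
}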
-inline IR_API std::pair ApplyPatternsGreedily( +IR_API std::pair ApplyPatternsGreedily( Operation* op, const FrozenRewritePatternSet& patterns, - GreedyRewriteConfig config = GreedyRewriteConfig()) { - bool sum_converged = true; - int64_t sum_num_rewrites = 0; - for (uint32_t i = 0; i < op->num_regions(); ++i) { - Region& region = op->region(i); - auto [converged, num_rewrites] = - ApplyPatternsGreedily(region, patterns, config); - sum_converged &= converged; - sum_num_rewrites += num_rewrites; - } - return std::make_pair(sum_converged, sum_num_rewrites); -} + GreedyRewriteConfig config = GreedyRewriteConfig()); } // namespace pir diff --git a/test/cpp/pir/cinn/dialect_convert_test.cc b/test/cpp/pir/cinn/dialect_convert_test.cc index 398c0892688300..f67e55cade1f3b 100644 --- a/test/cpp/pir/cinn/dialect_convert_test.cc +++ b/test/cpp/pir/cinn/dialect_convert_test.cc @@ -21,7 +21,7 @@ #include "paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_manager.h" diff --git a/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc b/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc index 1a938e7f600b78..342311cf76b773 100644 --- a/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc +++ b/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc @@ -18,7 +18,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/pass/pass.h" @@ -49,11 +49,10 @@ output0 output1 output2 output3 output4 output5 output6 */ -class SameTypeBindingTestPattern - // This class is for test cases of the same type of OP. - // (without considering the computational logic between OPs, - // only focusing on the process of matching and replacing) - : public paddle::drr::DrrPatternBase { +// This class is for test cases of the same type of OP. 
+// (without considering the computational logic between OPs, +// only focusing on the process of matching and replacing) +class SameTypeBindingTestPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern src = ctx->SourcePattern(); @@ -179,6 +178,8 @@ class SameTypeBindingTestPattern res.Tensor("output5") = full_5(); res.Tensor("output6") = full_6(); } + + std::string name() const override { return "SameTypeBindingTestPattern"; } }; void BuildProgram(pir::Builder &builder) { // NOLINT diff --git a/test/cpp/pir/pattern_rewrite/drr_test.cc b/test/cpp/pir/pattern_rewrite/drr_test.cc index 54b5ff2025e49d..6efe87d8ca70c4 100644 --- a/test/cpp/pir/pattern_rewrite/drr_test.cc +++ b/test/cpp/pir/pattern_rewrite/drr_test.cc @@ -18,13 +18,12 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/pass/pass_manager.h" -class RemoveRedundentReshapePattern - : public paddle::drr::DrrPatternBase { +class RemoveRedundentReshapePattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source patterns @@ -42,10 +41,11 @@ class RemoveRedundentReshapePattern res.Op("pd_op.reshape")({&res.Tensor("arg0"), &res.Tensor("shape1")}, {&res.Tensor("ret"), &res.Tensor("xshape_1")}); } + + std::string name() const override { return "RemoveRedundentReshapePattern"; } }; -class FoldExpandToConstantPattern - : public paddle::drr::DrrPatternBase { +class FoldExpandToConstantPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source Pattern @@ -79,10 +79,11 @@ class FoldExpandToConstantPattern {"place", pat.Attr("place_1")}}); res.Tensor("ret") = full2(); } + + std::string name() const override { return "FoldExpandToConstantPattern"; } }; -class RemoveRedundentTransposePattern - : public paddle::drr::DrrPatternBase { +class RemoveRedundentTransposePattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -109,10 +110,13 @@ class RemoveRedundentTransposePattern res.Tensor("ret") = tranpose_continuous(res.Tensor("arg_transpose")); } + + std::string name() const override { + return "RemoveRedundentTransposePattern"; + } }; -class RemoveRedundentCastPattern - : public paddle::drr::DrrPatternBase { +class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase { void operator()(paddle::drr::DrrPatternContext *ctx) const override { auto pat = ctx->SourcePattern(); pat.Tensor("tmp") = pat.Op( @@ -123,10 +127,11 @@ class RemoveRedundentCastPattern res.Tensor("ret") = res.Op( "pd_op.cast", {{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0")); } + + std::string name() const override { return "RemoveRedundentCastPattern"; } }; -class RemoveUselessCastPattern - : public paddle::drr::DrrPatternBase { +class RemoveUselessCastPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { auto pat = ctx->SourcePattern(); @@ -135,6 +140,8 @@ class RemoveUselessCastPattern auto res = pat.ResultPattern(); 
res.Tensor("ret").Assign(res.Tensor("arg0")); } + + std::string name() const override { return "RemoveUselessCastPattern"; } }; void BuildProgram(pir::Builder &builder) { // NOLINT