From 9ae058a98381c86eba49dd8539ea678f156c8dde Mon Sep 17 00:00:00 2001 From: zyfncg Date: Thu, 28 Mar 2024 16:58:47 +0000 Subject: [PATCH 1/2] fix bug of pass order --- .../operator/transforms/add_cinn_pass.cc | 46 ++++++++++++++++--- 1 file changed, 40 insertions(+), 6 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc index 25d0448848b18f..0a74b97b919b2b 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc @@ -23,6 +23,7 @@ #include "paddle/pir/include/dialect/shape/ir/shape_dialect.h" #include "paddle/pir/include/pass/pass_manager.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include "paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.h" @@ -74,6 +75,16 @@ bool HasDynamicShape(const pir::Program& program) { } } // namespace +void ApplyPdToCinnPass( + ::pir::Program* program, + const std::function()>& + CreatePassManager) { + std::shared_ptr pass_manager = CreatePassManager(); + pass_manager->AddPass(cinn::dialect::ir::CreatePdOpToCinnOpPass()); + pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); + pass_manager->Run(program); +} + void ApplyCinnPreprocessPass( ::pir::Program* program, const std::function()>& @@ -81,16 +92,10 @@ void ApplyCinnPreprocessPass( std::shared_ptr pass_manager = CreatePassManager(); bool has_dynamic_shape = HasDynamicShape(*program); - pass_manager->AddPass(cinn::dialect::ir::CreateConvert0DTo1DPass()); if (!has_dynamic_shape && FLAGS_check_infer_symbolic) { pass_manager->AddPass(pir::CreateShapeOptimizationPass()); pass_manager->AddPass(cinn::dialect::ir::CreateCheckInferSymbolicPass()); } - pass_manager->AddPass(cinn::dialect::ir::CreatePdOpToCinnOpPass()); - - pass_manager->AddPass( - cinn::dialect::ir::CreateAddBroadcastToElementwisePass()); - pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); if (has_dynamic_shape) { pass_manager->AddPass(cinn::dialect::ir::CreateConvert0DTo1DPass()); @@ -124,6 +129,8 @@ void ApplyGroupOpPass(::pir::Program* program, const std::function()>& CreatePassManager) { std::shared_ptr pass_manager = CreatePassManager(); + pass_manager->AddPass( + cinn::dialect::ir::CreateAddBroadcastToElementwisePass()); if (HasDynamicShape(*program)) { pass_manager->AddPass(::pir::CreateShapeOptimizationPass()); pass_manager->AddPass(cinn::dialect::ir::CreateInsertBroadcastPass()); @@ -188,13 +195,40 @@ void ApplyCinnLowerPass( pass_manager->Run(program); } +template +int64_t GetOpCount(const ::pir::Operation* op) { + int64_t count = 0; + for (auto& region : *op) { + for (auto& block : region) { + for (auto& sub_op : block) { + if (sub_op.isa()) { + count++; + continue; + } + if (sub_op.num_regions() > 0) { + count += GetOpCount(&sub_op); + } + } + } + } + return count; +} + +int64_t GetFusionOpCount(const ::pir::Program& program) { + return GetOpCount(program.module_op()); +} + void ApplyCinnPass(::pir::Program* program, const std::function()>& CreatePassManager) { + ApplyPdToCinnPass(program, CreatePassManager); ApplyCinnPreprocessPass(program, CreatePassManager); ApplyBuildGroupOpPass(program, CreatePassManager); ApplyGroupOpPass(program, CreatePassManager); ApplyDivideGroupOpToFusionOpPass(program, CreatePassManager); + LOG(INFO) << "FusionOp count before lowering : *****[ " + << GetOpCount(program->module_op()) + << " ]*****"; ApplyCinnLowerPass(program, CreatePassManager); } From 0594294e1aee466d5ba443840d0ba000cf1d5779 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Thu, 28 Mar 2024 16:59:42 +0000 Subject: [PATCH 2/2] polish code --- paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc index 0a74b97b919b2b..3b6b1adcdbda14 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc @@ -214,10 +214,6 @@ int64_t GetOpCount(const ::pir::Operation* op) { return count; } -int64_t GetFusionOpCount(const ::pir::Program& program) { - return GetOpCount(program.module_op()); -} - void ApplyCinnPass(::pir::Program* program, const std::function()>& CreatePassManager) {