diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc index 25d0448848b18f..3b6b1adcdbda14 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc @@ -23,6 +23,7 @@ #include "paddle/pir/include/dialect/shape/ir/shape_dialect.h" #include "paddle/pir/include/pass/pass_manager.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include "paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.h" @@ -74,6 +75,16 @@ bool HasDynamicShape(const pir::Program& program) { } } // namespace +void ApplyPdToCinnPass( + ::pir::Program* program, + const std::function()>& + CreatePassManager) { + std::shared_ptr pass_manager = CreatePassManager(); + pass_manager->AddPass(cinn::dialect::ir::CreatePdOpToCinnOpPass()); + pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); + pass_manager->Run(program); +} + void ApplyCinnPreprocessPass( ::pir::Program* program, const std::function()>& @@ -81,16 +92,10 @@ void ApplyCinnPreprocessPass( std::shared_ptr pass_manager = CreatePassManager(); bool has_dynamic_shape = HasDynamicShape(*program); - pass_manager->AddPass(cinn::dialect::ir::CreateConvert0DTo1DPass()); if (!has_dynamic_shape && FLAGS_check_infer_symbolic) { pass_manager->AddPass(pir::CreateShapeOptimizationPass()); pass_manager->AddPass(cinn::dialect::ir::CreateCheckInferSymbolicPass()); } - pass_manager->AddPass(cinn::dialect::ir::CreatePdOpToCinnOpPass()); - - pass_manager->AddPass( - cinn::dialect::ir::CreateAddBroadcastToElementwisePass()); - pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); if (has_dynamic_shape) { pass_manager->AddPass(cinn::dialect::ir::CreateConvert0DTo1DPass()); @@ -124,6 +129,8 @@ void ApplyGroupOpPass(::pir::Program* program, const std::function()>& CreatePassManager) { std::shared_ptr pass_manager = CreatePassManager(); + pass_manager->AddPass( + cinn::dialect::ir::CreateAddBroadcastToElementwisePass()); if (HasDynamicShape(*program)) { pass_manager->AddPass(::pir::CreateShapeOptimizationPass()); pass_manager->AddPass(cinn::dialect::ir::CreateInsertBroadcastPass()); @@ -188,13 +195,36 @@ void ApplyCinnLowerPass( pass_manager->Run(program); } +template +int64_t GetOpCount(const ::pir::Operation* op) { + int64_t count = 0; + for (auto& region : *op) { + for (auto& block : region) { + for (auto& sub_op : block) { + if (sub_op.isa()) { + count++; + continue; + } + if (sub_op.num_regions() > 0) { + count += GetOpCount(&sub_op); + } + } + } + } + return count; +} + void ApplyCinnPass(::pir::Program* program, const std::function()>& CreatePassManager) { + ApplyPdToCinnPass(program, CreatePassManager); ApplyCinnPreprocessPass(program, CreatePassManager); ApplyBuildGroupOpPass(program, CreatePassManager); ApplyGroupOpPass(program, CreatePassManager); ApplyDivideGroupOpToFusionOpPass(program, CreatePassManager); + LOG(INFO) << "FusionOp count before lowering : *****[ " + << GetOpCount(program->module_op()) + << " ]*****"; ApplyCinnLowerPass(program, CreatePassManager); }