diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index 7ebd1754e06887..e4e4f9746e08c7 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -173,12 +173,14 @@ void IfOp::Print(pir::IrPrinter &printer) { printer.AddIndentation(); for (auto &item : true_block()) { printer.PrintOperation(&item); + os << "\n"; } printer.DecreaseIndentation(); os << printer.indentation() << "} else {\n"; printer.AddIndentation(); for (auto &item : false_block()) { printer.PrintOperation(&item); + os << "\n"; } printer.DecreaseIndentation(); os << printer.indentation() << "}"; @@ -371,6 +373,7 @@ void WhileOp::Print(pir::IrPrinter &printer) { printer.AddIndentation(); for (auto &item : body()) { printer.PrintOperation(&item); + os << "\n"; } printer.DecreaseIndentation(); os << printer.indentation() << "}"; diff --git a/paddle/fluid/pir/transforms/shape_optimization_pass.cc b/paddle/fluid/pir/transforms/shape_optimization_pass.cc index cb2d0467e4e088..a239a0a78ace4a 100644 --- a/paddle/fluid/pir/transforms/shape_optimization_pass.cc +++ b/paddle/fluid/pir/transforms/shape_optimization_pass.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/pir/transforms/shape_optimization_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/pir/core/dialect.h" #include "paddle/pir/dialect/shape/ir/shape_attribute.h" #include "paddle/pir/pass/pass_manager.h" #include "paddle/pir/pass/pass_registry.h" @@ -25,12 +26,10 @@ using PassPipelineRunner = std::function; void PrintProgram(pir::ModuleOp m, std::string mgs) { - std::ostringstream print_stream; - print_stream << "\n\n"; - m.program()->Print(print_stream); - print_stream << "\n\n"; + ShapeConstraintIRAnalysis& shape_analysis = + ShapeAnalysisManager::Instance().Get(m.program()); VLOG(3) << "===================== " << mgs << " =====================\n" - << print_stream.str(); + << pir::CustomPrintHelper(*m.program(), shape_analysis.PrintHook()); } void DebugPrintOpInfo( diff --git a/paddle/pir/core/ir_printer.cc b/paddle/pir/core/ir_printer.cc index fc7905d79c9c86..354e40cd671bea 100644 --- a/paddle/pir/core/ir_printer.cc +++ b/paddle/pir/core/ir_printer.cc @@ -168,14 +168,11 @@ void IrPrinter::PrintOperation(Operation* op) { if (auto* dialect = op->dialect()) { if (auto print_fn = dialect->PrintOperation(op)) { print_fn(op, *this); - os << newline; return; } } PrintGeneralOperation(op); - - os << newline; } void IrPrinter::PrintOperationWithNoRegion(Operation* op) { @@ -221,6 +218,7 @@ void IrPrinter::PrintBlock(const Block& block) { AddIndentation(); for (auto& item : block) { PrintOperation(&item); + os << "\n"; } DecreaseIndentation(); os << indentation() << "}\n"; diff --git a/paddle/pir/dialect/shape/ir/shape_dialect.cc b/paddle/pir/dialect/shape/ir/shape_dialect.cc index 9de1a6b2054016..a8c1cef5490d05 100644 --- a/paddle/pir/dialect/shape/ir/shape_dialect.cc +++ b/paddle/pir/dialect/shape/ir/shape_dialect.cc @@ -29,33 +29,7 @@ void ShapeDialect::initialize() { } void ShapeDialect::PrintAttribute(pir::Attribute attr, std::ostream &os) const { - if (attr.isa()) { - SymbolAttribute symbol_attr = attr.dyn_cast(); - if (symbol_attr.data().isa()) return; - os << "(shape_data)"; - os << "["; - for (size_t i = 0; i < symbol_attr.data().shape().size(); ++i) { - if (i != symbol_attr.data().shape().size() - 1) { - os << symbol::ToString(symbol_attr.data().shape()[i]) << ","; - } else { - os << symbol::ToString(symbol_attr.data().shape()[i]); - } - } - os << "]_["; - if (symbol_attr.data().data().has_value()) { - for (size_t i = 0; i < symbol_attr.data().data().value().size(); ++i) { - if (i != symbol_attr.data().data().value().size() - 1) { - os << symbol::ToString(symbol_attr.data().data().value()[i]) << ","; - } else { - os << symbol::ToString(symbol_attr.data().data().value()[i]); - } - } - } else { - os << "nullopt"; - } - - os << "]"; - } + return; } } // namespace pir::shape diff --git a/paddle/pir/dialect/shape/utils/shape_analysis.cc b/paddle/pir/dialect/shape/utils/shape_analysis.cc index 43b629df57b968..6c407c7ac35eb0 100644 --- a/paddle/pir/dialect/shape/utils/shape_analysis.cc +++ b/paddle/pir/dialect/shape/utils/shape_analysis.cc @@ -28,14 +28,6 @@ static std::string GetValueId(Value val) { ShapeConstraintIRAnalysis::ShapeConstraintIRAnalysis(ModuleOp m) : m_(m) {} -ShapeConstraintIRAnalysis::ShapeConstraintIRAnalysis( - std::shared_ptr&& program) - : ShapeConstraintIRAnalysis(program->module_op()) { - program_ = std::move(program); -} -ShapeConstraintIRAnalysis::ShapeConstraintIRAnalysis(pir::IrContext* ctx) - : ShapeConstraintIRAnalysis(std::make_shared(ctx)) {} - void ShapeConstraintIRAnalysis::Init() { value_to_shape_or_data_.clear(); next_sym_idx_ = 0; @@ -216,6 +208,26 @@ bool ShapeConstraintIRAnalysis::IsSameNumel(Value lhs, Value rhs) const { static_cast(rhs_type.GetRank())); } +pir::PrintHooks ShapeConstraintIRAnalysis::PrintHook() const { + pir::PrintHooks print_hook; + print_hook.op_print_hook = [&](Operation* op, IrPrinter& printer) { + printer.IrPrinter::PrintOperation(op); + printer.os << " { "; + for (uint32_t i = 0; i < op->num_results(); ++i) { + if (this->HasShapeOrDataForValue(op->result(i))) { + printer.os << "(" << this->GetShapeOrDataForValue(op->result(i)) << ")"; + } else { + printer.os << "()"; + } + if (i < op->num_results() - 1) { + printer.os << ", "; + } + } + printer.os << " }"; + }; + return print_hook; +} + ShapeAnalysisManager& ShapeAnalysisManager::Instance() { static ShapeAnalysisManager instance; return instance; diff --git a/paddle/pir/dialect/shape/utils/shape_analysis.h b/paddle/pir/dialect/shape/utils/shape_analysis.h index 7d999164732331..d1858d7787de89 100644 --- a/paddle/pir/dialect/shape/utils/shape_analysis.h +++ b/paddle/pir/dialect/shape/utils/shape_analysis.h @@ -30,10 +30,6 @@ class IR_API ShapeConstraintIRAnalysis { public: explicit ShapeConstraintIRAnalysis(ModuleOp m); - explicit ShapeConstraintIRAnalysis(std::shared_ptr&& program); - - explicit ShapeConstraintIRAnalysis(pir::IrContext* ctx); - void Init(); const std::string GetNextSymName(); @@ -77,9 +73,10 @@ class IR_API ShapeConstraintIRAnalysis { // Returns true if the two value have the same number elements. bool IsSameNumel(Value lhs, Value rhs) const; + pir::PrintHooks PrintHook() const; + private: ModuleOp m_; - std::shared_ptr program_; int64_t next_sym_idx_ = 0; diff --git a/test/cpp/pir/cinn/adt/map_expr_test.cc b/test/cpp/pir/cinn/adt/map_expr_test.cc index a21d18ebeb0e3f..527255ef6edf7c 100644 --- a/test/cpp/pir/cinn/adt/map_expr_test.cc +++ b/test/cpp/pir/cinn/adt/map_expr_test.cc @@ -73,7 +73,6 @@ TEST(MapExpr, ElementWise_Fusion_0) { value1, builder.Build(value2).result(0)); ::pir::PassManager pass_manager(ctx); - auto shape_analysis = std::make_shared(ctx); // TODO(@jiahy0825): use CreateShapeOptimizationPass() instead of // CreateInferSymbolicShapePass() which is a fake pass diff --git a/test/cpp/pir/core/add_dialect_parser_test.cc b/test/cpp/pir/core/add_dialect_parser_test.cc index e09338ddb0389d..60efe22fe7ba4b 100644 --- a/test/cpp/pir/core/add_dialect_parser_test.cc +++ b/test/cpp/pir/core/add_dialect_parser_test.cc @@ -102,7 +102,7 @@ TEST(IrParserTest, AddAttribute) { std::string op_str = "(%0) = \"builtin.parameter\" () " "{parameter_name:\"conv2d_0.w_0\",test:(tp.char)a} : () -> " - "pd_op.tensor<64x3x7x7xf32>\n"; + "pd_op.tensor<64x3x7x7xf32>"; std::stringstream ss; ss << op_str; pir::IrParser* parser = new pir::IrParser(ctx, ss); diff --git a/test/cpp/pir/core/ir_op_test.cc b/test/cpp/pir/core/ir_op_test.cc index 98f80c583e10d8..712ffbb8a4415b 100644 --- a/test/cpp/pir/core/ir_op_test.cc +++ b/test/cpp/pir/core/ir_op_test.cc @@ -64,7 +64,7 @@ TEST(op_test, region_test) { // (3) Test custom operation printer std::stringstream ss; op1->Print(ss); - EXPECT_EQ(ss.str(), "(%0) = \"test.operation1\" ()\n"); + EXPECT_EQ(ss.str(), "(%0) = \"test.operation1\" ()"); region.push_back(new pir::Block()); region.push_front(new pir::Block()); diff --git a/test/cpp/pir/core/ir_printer_test.cc b/test/cpp/pir/core/ir_printer_test.cc index f2bc072a6ee2a4..ba84ba71279808 100644 --- a/test/cpp/pir/core/ir_printer_test.cc +++ b/test/cpp/pir/core/ir_printer_test.cc @@ -61,7 +61,6 @@ TEST(printer_test, custom_hooks) { printer.PrintAttributeMap(op); printer.os << " :"; printer.PrintOpReturnType(op); - printer.os << "\n"; }; hooks.attribute_print_hook = [](pir::Attribute attr,