Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,14 @@ void IfOp::Print(pir::IrPrinter &printer) {
printer.AddIndentation();
for (auto &item : true_block()) {
printer.PrintOperation(&item);
os << "\n";
}
printer.DecreaseIndentation();
os << printer.indentation() << "} else {\n";
printer.AddIndentation();
for (auto &item : false_block()) {
printer.PrintOperation(&item);
os << "\n";
}
printer.DecreaseIndentation();
os << printer.indentation() << "}";
Expand Down Expand Up @@ -371,6 +373,7 @@ void WhileOp::Print(pir::IrPrinter &printer) {
printer.AddIndentation();
for (auto &item : body()) {
printer.PrintOperation(&item);
os << "\n";
}
printer.DecreaseIndentation();
os << printer.indentation() << "}";
Expand Down
9 changes: 4 additions & 5 deletions paddle/fluid/pir/transforms/shape_optimization_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "paddle/fluid/pir/transforms/shape_optimization_pass.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/pir/core/dialect.h"
#include "paddle/pir/dialect/shape/ir/shape_attribute.h"
#include "paddle/pir/pass/pass_manager.h"
#include "paddle/pir/pass/pass_registry.h"
Expand All @@ -25,12 +26,10 @@ using PassPipelineRunner =
std::function<bool(pir::PassManager&, pir::ModuleOp)>;

void PrintProgram(pir::ModuleOp m, std::string mgs) {
std::ostringstream print_stream;
print_stream << "\n\n";
m.program()->Print(print_stream);
print_stream << "\n\n";
ShapeConstraintIRAnalysis& shape_analysis =
ShapeAnalysisManager::Instance().Get(m.program());
VLOG(3) << "===================== " << mgs << " =====================\n"
<< print_stream.str();
<< pir::CustomPrintHelper(*m.program(), shape_analysis.PrintHook());
}

void DebugPrintOpInfo(
Expand Down
4 changes: 1 addition & 3 deletions paddle/pir/core/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,11 @@ void IrPrinter::PrintOperation(Operation* op) {
if (auto* dialect = op->dialect()) {
if (auto print_fn = dialect->PrintOperation(op)) {
print_fn(op, *this);
os << newline;
return;
}
}

PrintGeneralOperation(op);

os << newline;
}

void IrPrinter::PrintOperationWithNoRegion(Operation* op) {
Expand Down Expand Up @@ -221,6 +218,7 @@ void IrPrinter::PrintBlock(const Block& block) {
AddIndentation();
for (auto& item : block) {
PrintOperation(&item);
os << "\n";
}
DecreaseIndentation();
os << indentation() << "}\n";
Expand Down
28 changes: 1 addition & 27 deletions paddle/pir/dialect/shape/ir/shape_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,33 +29,7 @@ void ShapeDialect::initialize() {
}

void ShapeDialect::PrintAttribute(pir::Attribute attr, std::ostream &os) const {
if (attr.isa<SymbolAttribute>()) {
SymbolAttribute symbol_attr = attr.dyn_cast<SymbolAttribute>();
if (symbol_attr.data().isa<symbol::TensorListShapeOrDataDimExprs>()) return;
os << "(shape_data)";
os << "[";
for (size_t i = 0; i < symbol_attr.data().shape().size(); ++i) {
if (i != symbol_attr.data().shape().size() - 1) {
os << symbol::ToString(symbol_attr.data().shape()[i]) << ",";
} else {
os << symbol::ToString(symbol_attr.data().shape()[i]);
}
}
os << "]_[";
if (symbol_attr.data().data().has_value()) {
for (size_t i = 0; i < symbol_attr.data().data().value().size(); ++i) {
if (i != symbol_attr.data().data().value().size() - 1) {
os << symbol::ToString(symbol_attr.data().data().value()[i]) << ",";
} else {
os << symbol::ToString(symbol_attr.data().data().value()[i]);
}
}
} else {
os << "nullopt";
}

os << "]";
}
return;
}

} // namespace pir::shape
Expand Down
28 changes: 20 additions & 8 deletions paddle/pir/dialect/shape/utils/shape_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,6 @@ static std::string GetValueId(Value val) {

ShapeConstraintIRAnalysis::ShapeConstraintIRAnalysis(ModuleOp m) : m_(m) {}

ShapeConstraintIRAnalysis::ShapeConstraintIRAnalysis(
std::shared_ptr<pir::Program>&& program)
: ShapeConstraintIRAnalysis(program->module_op()) {
program_ = std::move(program);
}
ShapeConstraintIRAnalysis::ShapeConstraintIRAnalysis(pir::IrContext* ctx)
: ShapeConstraintIRAnalysis(std::make_shared<pir::Program>(ctx)) {}

void ShapeConstraintIRAnalysis::Init() {
value_to_shape_or_data_.clear();
next_sym_idx_ = 0;
Expand Down Expand Up @@ -216,6 +208,26 @@ bool ShapeConstraintIRAnalysis::IsSameNumel(Value lhs, Value rhs) const {
static_cast<int>(rhs_type.GetRank()));
}

pir::PrintHooks ShapeConstraintIRAnalysis::PrintHook() const {
pir::PrintHooks print_hook;
print_hook.op_print_hook = [&](Operation* op, IrPrinter& printer) {
printer.IrPrinter::PrintOperation(op);
printer.os << " { ";
for (uint32_t i = 0; i < op->num_results(); ++i) {
if (this->HasShapeOrDataForValue(op->result(i))) {
printer.os << "(" << this->GetShapeOrDataForValue(op->result(i)) << ")";
} else {
printer.os << "()";
}
if (i < op->num_results() - 1) {
printer.os << ", ";
}
}
printer.os << " }";
};
return print_hook;
}

ShapeAnalysisManager& ShapeAnalysisManager::Instance() {
static ShapeAnalysisManager instance;
return instance;
Expand Down
7 changes: 2 additions & 5 deletions paddle/pir/dialect/shape/utils/shape_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@ class IR_API ShapeConstraintIRAnalysis {
public:
explicit ShapeConstraintIRAnalysis(ModuleOp m);

explicit ShapeConstraintIRAnalysis(std::shared_ptr<pir::Program>&& program);

explicit ShapeConstraintIRAnalysis(pir::IrContext* ctx);

void Init();

const std::string GetNextSymName();
Expand Down Expand Up @@ -77,9 +73,10 @@ class IR_API ShapeConstraintIRAnalysis {
// Returns true if the two value have the same number elements.
bool IsSameNumel(Value lhs, Value rhs) const;

pir::PrintHooks PrintHook() const;

private:
ModuleOp m_;
std::shared_ptr<pir::Program> program_;

int64_t next_sym_idx_ = 0;

Expand Down
1 change: 0 additions & 1 deletion test/cpp/pir/cinn/adt/map_expr_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ TEST(MapExpr, ElementWise_Fusion_0) {
value1, builder.Build<paddle::dialect::ExpOp>(value2).result(0));

::pir::PassManager pass_manager(ctx);
auto shape_analysis = std::make_shared<pir::ShapeConstraintIRAnalysis>(ctx);

// TODO(@jiahy0825): use CreateShapeOptimizationPass() instead of
// CreateInferSymbolicShapePass() which is a fake pass
Expand Down