1818
1919#include " paddle/cinn/common/dim_expr_util.h"
2020#include " paddle/cinn/common/union_find.h"
21+ #include " paddle/pir/include/dialect/shape/ir/shape_attribute.h"
2122
2223namespace cinn {
2324namespace dialect {
@@ -26,11 +27,14 @@ namespace ir {
2627namespace {
2728
2829template <typename DoEachT>
29- void VisitEachOp (pir::ModuleOp module_op, const DoEachT& DoEach) {
30- for (uint32_t i = 0 ; i < module_op->num_regions (); i++) {
31- for (pir::Block& block : module_op->region (i)) {
32- for (pir::Operation& op : block) {
33- DoEach (op);
30+ void VisitEachOp (pir::Operation* op, const DoEachT& DoEach) {
31+ for (uint32_t i = 0 ; i < op->num_regions (); i++) {
32+ for (pir::Block& block : op->region (i)) {
33+ for (pir::Operation& sub_op : block) {
34+ DoEach (sub_op);
35+ if (sub_op.num_regions () > 0 ) {
36+ VisitEachOp (&sub_op, DoEach);
37+ }
3438 }
3539 }
3640 }
@@ -133,25 +137,39 @@ std::unordered_map<symbol::DimExpr, symbol::DimExpr> GetDimExprSubstitution(
133137 return substitution_pattern;
134138}
135139
136- void SubstituteDimExprBasedOnConstraints (pir::ModuleOp module_op) {
140+ void SubstituteDimExprBasedOnConstraints (pir::Operation* module_op) {
137141 VLOG (4 ) << " SubstituteDimExprBasedOnConstraints start" ;
138- pir::ShapeConstraintIRAnalysis shape_analysis =
139- pir::ShapeAnalysisManager::Instance ().Get (module_op.program ());
142+ pir::ShapeConstraintIRAnalysis* shape_analysis =
143+ &pir::ShapeAnalysisManager::Instance ().Get (
144+ module_op->dyn_cast <pir::ModuleOp>().program ());
140145 const std::unordered_map<symbol::DimExpr, symbol::DimExpr>&
141- substitution_pattern = GetDimExprSubstitution (&shape_analysis);
146+ substitution_pattern = GetDimExprSubstitution (shape_analysis);
147+
142148 VisitEachOp (module_op, [&](pir::Operation& op) {
143149 VisitEachValue (op, [&](pir::Value value) {
144- if (!shape_analysis. HasShapeOrDataForValue (value)) {
150+ if (!shape_analysis-> HasShapeOrDataForValue (value)) {
145151 VLOG (4 ) << " Can not find ShapeOrData for value of op(" << op.name ()
146152 << " ) in shape_analysis" ;
147153 } else {
148154 const symbol::ShapeOrDataDimExprs& origin_shape_or_data =
149- shape_analysis.GetShapeOrDataForValue (value);
155+ shape_analysis->GetShapeOrDataForValue (value);
156+ VLOG (8 ) << op.name ()
157+ << " origin_shape_or_data: " << origin_shape_or_data;
150158 const symbol::ShapeOrDataDimExprs& substituted_shape_or_data =
151159 SubstituteShapeOrData (origin_shape_or_data, substitution_pattern);
152- shape_analysis.SetShapeOrDataForValue (value, substituted_shape_or_data);
160+ VLOG (8 ) << op.name ()
161+ << " substituted_shape_or_data: " << substituted_shape_or_data;
162+ shape_analysis->SetShapeOrDataForValue (value,
163+ substituted_shape_or_data);
153164 }
154165 });
166+ if (op.num_results () > 0 ) {
167+ pir::shape::SetShapeAttrForOp (
168+ &op, shape_analysis->GetShapeOrDataForValue (op.result (0 )));
169+ } else {
170+ pir::shape::SetShapeAttrForOp (
171+ &op, shape_analysis->GetShapeOrDataForValue (op.operand_source (0 )));
172+ }
155173 // TODO(JiaWenxuan): substitute the attribute "sym_shape_str" of the op
156174 });
157175 VLOG (4 ) << " SubstituteDimExprBasedOnConstraints end" ;
@@ -163,8 +181,7 @@ class SubstituteDimExprBasedOnConstraintsPass : public pir::Pass {
163181 : pir::Pass(" substitute_dim_expr_based_on_constraints_pass" , 1 ) {}
164182
165183 void Run (pir::Operation* op) override {
166- pir::ModuleOp module_op = op->dyn_cast <pir::ModuleOp>();
167- SubstituteDimExprBasedOnConstraints (module_op);
184+ SubstituteDimExprBasedOnConstraints (op);
168185 }
169186
170187 bool CanApplyOn (pir::Operation* op) const override {
0 commit comments