Skip to content

Commit 7129945

Browse files
authored
Fix ShapeOrDataDimExpr simplify unwork (#62376)
* update test case * fix * fix concat op infer symbolic * fix some bugs * fix some bugs * fix some bugs * fix some bugs * fix some bugs
1 parent 660276a commit 7129945

3 files changed

Lines changed: 61 additions & 32 deletions

File tree

paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,6 @@ void ApplyCinnPreprocessPass(
9494
if (has_dynamic_shape) {
9595
pass_manager->AddPass(cinn::dialect::ir::CreateConvert0DTo1DPass());
9696
pass_manager->AddPass(pir::CreateShapeOptimizationPass());
97-
pass_manager->AddPass(cinn::dialect::ir::CreateSimplifyDimExprPass());
98-
pass_manager->AddPass(
99-
cinn::dialect::ir::CreateSubstituteDimExprBasedOnConstraintsPass());
10097
pass_manager->AddPass(cinn::dialect::ir::CreateConvert0DTo1DPass());
10198
pass_manager->AddPass(
10299
cinn::dialect::ir::CreateFuseShapeOpsIntoGenerateShapeOpPass());
@@ -130,6 +127,9 @@ void ApplyGroupOpPass(::pir::Program* program,
130127
cinn::dialect::ir::CreateFuseShapeOpsIntoGenerateShapeOpPass());
131128
pass_manager->AddPass(
132129
cinn::dialect::ir::CreateMoveGenerateShapeOpsToProloguePass());
130+
pass_manager->AddPass(
131+
cinn::dialect::ir::CreateSubstituteDimExprBasedOnConstraintsPass());
132+
pass_manager->AddPass(cinn::dialect::ir::CreateSimplifyDimExprPass());
133133
}
134134

135135
pass_manager->AddPass(cinn::dialect::ir::CreateDynamicReshapeOpPass());

paddle/cinn/hlir/dialect/operator/transforms/group_merge/simplify_dim_expr_pass.cc

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,14 @@ namespace ir {
2828
namespace {
2929

3030
template <typename DoEachT>
31-
void VisitEachOp(pir::ModuleOp module_op, const DoEachT& DoEach) {
32-
for (uint32_t i = 0; i < module_op->num_regions(); i++) {
33-
for (pir::Block& block : module_op->region(i)) {
34-
for (pir::Operation& op : block) {
35-
DoEach(op);
31+
void VisitEachOp(pir::Operation* op, const DoEachT& DoEach) {
32+
for (uint32_t i = 0; i < op->num_regions(); i++) {
33+
for (pir::Block& block : op->region(i)) {
34+
for (pir::Operation& sub_op : block) {
35+
DoEach(sub_op);
36+
if (sub_op.num_regions() > 0) {
37+
VisitEachOp(&sub_op, DoEach);
38+
}
3639
}
3740
}
3841
}
@@ -90,24 +93,36 @@ symbol::ShapeOrDataDimExprs SimplifyShapeOrData(
9093
return std::visit(lambdas, shape_or_data.variant());
9194
}
9295

93-
void SimplifyDimExpr(pir::ModuleOp module_op) {
96+
void SimplifyDimExpr(pir::Operation* module_op) {
9497
VLOG(4) << "SimplifyDimExpr start";
95-
pir::ShapeConstraintIRAnalysis shape_analysis =
96-
pir::ShapeAnalysisManager::Instance().Get(module_op.program());
98+
pir::ShapeConstraintIRAnalysis* shape_analysis =
99+
&pir::ShapeAnalysisManager::Instance().Get(
100+
module_op->dyn_cast<pir::ModuleOp>().program());
101+
97102
VisitEachOp(module_op, [&](pir::Operation& op) {
98103
VisitEachValue(op, [&](pir::Value value) {
99-
if (!shape_analysis.HasShapeOrDataForValue(value)) {
104+
if (!shape_analysis->HasShapeOrDataForValue(value)) {
100105
VLOG(4) << "SimplifyDimExpr: shape_analysis can't find ShapeOrData for "
101106
"value of the op:"
102107
<< op.name();
103108
} else {
104109
const symbol::ShapeOrDataDimExprs& shape_or_data =
105-
shape_analysis.GetShapeOrDataForValue(value);
110+
shape_analysis->GetShapeOrDataForValue(value);
111+
VLOG(8) << op.name() << " origin_shape_or_data: " << shape_or_data;
106112
symbol::ShapeOrDataDimExprs simplified_shape_or_data =
107113
SimplifyShapeOrData(shape_or_data);
108-
shape_analysis.SetShapeOrDataForValue(value, simplified_shape_or_data);
114+
VLOG(8) << op.name()
115+
<< " simplified_shape_or_data: " << simplified_shape_or_data;
116+
shape_analysis->SetShapeOrDataForValue(value, simplified_shape_or_data);
109117
}
110118
});
119+
if (op.num_results() > 0) {
120+
pir::shape::SetShapeAttrForOp(
121+
&op, shape_analysis->GetShapeOrDataForValue(op.result(0)));
122+
} else {
123+
pir::shape::SetShapeAttrForOp(
124+
&op, shape_analysis->GetShapeOrDataForValue(op.operand_source(0)));
125+
}
111126
// TODO(JiaWenxuan): simplify the attribute "sym_shape_str" of the op
112127
});
113128
VLOG(4) << "SimplifyDimExpr end";
@@ -117,10 +132,7 @@ class SimplifyDimExprPass : public pir::Pass {
117132
public:
118133
SimplifyDimExprPass() : pir::Pass("simplify_dim_expr_pass", 1) {}
119134

120-
void Run(pir::Operation* op) override {
121-
pir::ModuleOp module_op = op->dyn_cast<pir::ModuleOp>();
122-
SimplifyDimExpr(module_op);
123-
}
135+
void Run(pir::Operation* op) override { SimplifyDimExpr(op); }
124136

125137
bool CanApplyOn(pir::Operation* op) const override {
126138
return op->isa<pir::ModuleOp>() && op->num_regions() > 0;

paddle/cinn/hlir/dialect/operator/transforms/group_merge/substitute_dim_expr_based_on_constraints_pass.cc

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include "paddle/cinn/common/dim_expr_util.h"
2020
#include "paddle/cinn/common/union_find.h"
21+
#include "paddle/pir/include/dialect/shape/ir/shape_attribute.h"
2122

2223
namespace cinn {
2324
namespace dialect {
@@ -26,11 +27,14 @@ namespace ir {
2627
namespace {
2728

2829
template <typename DoEachT>
29-
void VisitEachOp(pir::ModuleOp module_op, const DoEachT& DoEach) {
30-
for (uint32_t i = 0; i < module_op->num_regions(); i++) {
31-
for (pir::Block& block : module_op->region(i)) {
32-
for (pir::Operation& op : block) {
33-
DoEach(op);
30+
void VisitEachOp(pir::Operation* op, const DoEachT& DoEach) {
31+
for (uint32_t i = 0; i < op->num_regions(); i++) {
32+
for (pir::Block& block : op->region(i)) {
33+
for (pir::Operation& sub_op : block) {
34+
DoEach(sub_op);
35+
if (sub_op.num_regions() > 0) {
36+
VisitEachOp(&sub_op, DoEach);
37+
}
3438
}
3539
}
3640
}
@@ -133,25 +137,39 @@ std::unordered_map<symbol::DimExpr, symbol::DimExpr> GetDimExprSubstitution(
133137
return substitution_pattern;
134138
}
135139

136-
void SubstituteDimExprBasedOnConstraints(pir::ModuleOp module_op) {
140+
void SubstituteDimExprBasedOnConstraints(pir::Operation* module_op) {
137141
VLOG(4) << "SubstituteDimExprBasedOnConstraints start";
138-
pir::ShapeConstraintIRAnalysis shape_analysis =
139-
pir::ShapeAnalysisManager::Instance().Get(module_op.program());
142+
pir::ShapeConstraintIRAnalysis* shape_analysis =
143+
&pir::ShapeAnalysisManager::Instance().Get(
144+
module_op->dyn_cast<pir::ModuleOp>().program());
140145
const std::unordered_map<symbol::DimExpr, symbol::DimExpr>&
141-
substitution_pattern = GetDimExprSubstitution(&shape_analysis);
146+
substitution_pattern = GetDimExprSubstitution(shape_analysis);
147+
142148
VisitEachOp(module_op, [&](pir::Operation& op) {
143149
VisitEachValue(op, [&](pir::Value value) {
144-
if (!shape_analysis.HasShapeOrDataForValue(value)) {
150+
if (!shape_analysis->HasShapeOrDataForValue(value)) {
145151
VLOG(4) << "Can not find ShapeOrData for value of op(" << op.name()
146152
<< ") in shape_analysis";
147153
} else {
148154
const symbol::ShapeOrDataDimExprs& origin_shape_or_data =
149-
shape_analysis.GetShapeOrDataForValue(value);
155+
shape_analysis->GetShapeOrDataForValue(value);
156+
VLOG(8) << op.name()
157+
<< " origin_shape_or_data: " << origin_shape_or_data;
150158
const symbol::ShapeOrDataDimExprs& substituted_shape_or_data =
151159
SubstituteShapeOrData(origin_shape_or_data, substitution_pattern);
152-
shape_analysis.SetShapeOrDataForValue(value, substituted_shape_or_data);
160+
VLOG(8) << op.name()
161+
<< " substituted_shape_or_data: " << substituted_shape_or_data;
162+
shape_analysis->SetShapeOrDataForValue(value,
163+
substituted_shape_or_data);
153164
}
154165
});
166+
if (op.num_results() > 0) {
167+
pir::shape::SetShapeAttrForOp(
168+
&op, shape_analysis->GetShapeOrDataForValue(op.result(0)));
169+
} else {
170+
pir::shape::SetShapeAttrForOp(
171+
&op, shape_analysis->GetShapeOrDataForValue(op.operand_source(0)));
172+
}
155173
// TODO(JiaWenxuan): substitute the attribute "sym_shape_str" of the op
156174
});
157175
VLOG(4) << "SubstituteDimExprBasedOnConstraints end";
@@ -163,8 +181,7 @@ class SubstituteDimExprBasedOnConstraintsPass : public pir::Pass {
163181
: pir::Pass("substitute_dim_expr_based_on_constraints_pass", 1) {}
164182

165183
void Run(pir::Operation* op) override {
166-
pir::ModuleOp module_op = op->dyn_cast<pir::ModuleOp>();
167-
SubstituteDimExprBasedOnConstraints(module_op);
184+
SubstituteDimExprBasedOnConstraints(op);
168185
}
169186

170187
bool CanApplyOn(pir::Operation* op) const override {

0 commit comments

Comments
 (0)