From f4d18929b2fef0579a49471a572f659b4ba63215 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 2 Apr 2024 11:22:02 +0000 Subject: [PATCH] add infer_symbol_shape for yield_store and remove tricky code of yield_store pass --- .../hlir/dialect/operator/ir/manual_op.cc | 7 ++++ .../cinn/hlir/dialect/operator/ir/manual_op.h | 6 +++- .../transforms/add_store_in_fusion_op_pass.cc | 32 ------------------- 3 files changed, 12 insertions(+), 33 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc index 71f0b9f33f4ec9..2dbe30c4447b77 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc @@ -192,6 +192,13 @@ void YieldStoreOp::Build(pir::Builder& builder, void YieldStoreOp::VerifySig() {} +bool YieldStoreOp::InferSymbolicShape( + pir::ShapeConstraintIRAnalysis* shape_analysis) { + shape_analysis->SetShapeOrDataForValue( + result(0), shape_analysis->GetShapeOrDataForValue(operand_source(0))); + return true; +} + bool ConcatOp::InferSymbolicShape( pir::ShapeConstraintIRAnalysis* shape_analysis) { VLOG(4) << "Infer symbolic shape for cinn_op.concat"; diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h index d350cbb3d5208f..f27908438d3b96 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h @@ -86,7 +86,9 @@ class IR_API FusionOp : public pir::Op { // YieldStoreOp represents a store operation for // seperate local variable and ouptut -class IR_API YieldStoreOp : public pir::Op { +class IR_API YieldStoreOp + : public pir::Op { public: using Op::Op; static const char *name() { return "cinn_op.yield_store"; } @@ -98,6 +100,8 @@ class IR_API YieldStoreOp : public pir::Op { pir::Type output_type); void VerifySig(); + + bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis); }; class IR_API ConcatOp diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc index 143f72985a3bfd..e0c52169df0a6e 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc @@ -35,38 +35,6 @@ class AddYieldStoreInFusionOpPattern auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram()); for (auto i = 0; i < op->num_operands(); ++i) { - if (auto reshape_op = op->operand_source(i) - .defining_op() - ->dyn_cast()) { - if (reshape_op.operand_source(0).defining_op() == nullptr) { - continue; - } - auto pre_name = reshape_op.operand_source(0).defining_op()->name(); - - if (op->operand_source(i).use_count() > 1) { - continue; - } - - if ((pre_name != "cinn_op.reduce_sum") && - (pre_name != "cinn_op.reduce_max")) { - auto store_op = rewriter.Build( - op->operand_source(i).defining_op()->operand_source(0), - op->operand_source(i).type()); - - if (shape_analysis.HasShapeOrDataForValue(reshape_op->result(0))) { - shape_analysis.SetShapeOrDataForValue( - store_op.result(0), - shape_analysis.GetShapeOrDataForValue(reshape_op->result(0))); - } - - op->operand(i).set_source(store_op.result(0)); - if (reshape_op->result(0).use_count() == 0) { - rewriter.EraseOp(reshape_op); - } - continue; - } - } - if (op->operand_source(i).use_count() == 1) { continue; }