diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc index 7e53d9eda32efc..c8be16a19240cf 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc @@ -32,6 +32,8 @@ class AddYieldStoreInFusionOpPattern bool MatchAndRewrite(::pir::YieldOp op, pir::PatternRewriter& rewriter) const override { + auto& shape_analysis = + pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram()); for (auto i = 0; i < op->num_operands(); ++i) { if (auto reshape_op = op->operand_source(i) .defining_op() @@ -44,11 +46,17 @@ class AddYieldStoreInFusionOpPattern if ((pre_name != "cinn_op.reduce_sum") && (pre_name != "cinn_op.reduce_max")) { - auto new_full = rewriter.Build( + auto store_op = rewriter.Build( op->operand_source(i).defining_op()->operand_source(0), op->operand_source(i).type()); - op->operand(i).set_source(new_full.result(0)); + if (shape_analysis.HasShapeOrDataForValue(reshape_op->result(0))) { + shape_analysis.SetShapeOrDataForValue( + store_op.result(0), + shape_analysis.GetShapeOrDataForValue(reshape_op->result(0))); + } + + op->operand(i).set_source(store_op.result(0)); if (reshape_op->result(0).use_count() == 0) { rewriter.EraseOp(reshape_op); } @@ -60,10 +68,16 @@ class AddYieldStoreInFusionOpPattern continue; } - auto new_full = rewriter.Build( + auto store_op = rewriter.Build( op->operand_source(i), op->operand_source(i).type()); + auto orignal_base = op->operand_source(i); + op->operand(i).set_source(store_op.result(0)); - op->operand(i).set_source(new_full.result(0)); + if (shape_analysis.HasShapeOrDataForValue(orignal_base)) { + shape_analysis.SetShapeOrDataForValue( + store_op.result(0), + shape_analysis.GetShapeOrDataForValue(orignal_base)); + } } return true;