From 0f4b509f3477f6c961be04abf61f0b17293b10c4 Mon Sep 17 00:00:00 2001 From: fty1777 Date: Mon, 4 Mar 2024 15:30:37 +0000 Subject: [PATCH 01/13] WIP: builtin.split op infer sym shape --- .../pir/dialect/operator/ir/op_dialect.cc | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 6816d64a054671..9223d59c6fc672 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -159,6 +159,103 @@ struct ShadowOutputOpInferSymbolicShapeInterfaceModel : InferSymbolicShapeInterface::Concept(InferSymbolicShape) {} }; +struct SplitOpInferSymbolicShapeInterfaceModel + : public InferSymbolicShapeInterface::Concept { + static inline bool InferSymbolicShape( + pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { + VLOG(0) << "InferSymbolicShape of SplitOp!"; + // x + IR_ENFORCE(!op->operand_source(0).data().has_value(), + "Currently InferSymbolicShape of SplitOp only support " + "input without value."); + auto x_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + std::vector x_dims_sym = x_shape_or_data.shape(); + + // axis + CHECK(op->operand_source(2).defining_op()->isa()); + + int64_t axis = op->operand_source(2) + .defining_op() + .attributes() + .at("value") + .dyn_cast() + .data() + .to(); + + // sections or num + if (op->operand_source(1).defining_op()->isa()) { + // num + VLOG(0) << "Num!"; + int64_t num = op->operand_source(1) + .defining_op() + .attributes() + .at("value") + .dyn_cast() + .data() + .to(); + x_dims_sym[axis] = x_dims_sym[axis] / num; + + // all the result have the same shape + for (uint32_t rst_idx = 0; rst_idx < op->num_results(); rst_idx++) { + shape_analysis->SetShapeOrDataForValue( + op->result(rst_idx), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(x_dims_sym)}); + } + } else { + // sections + 
VLOG(0) << "Sections!"; + auto sections_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + std::vector sections_sym; + if (sections_shape_or_data.data().has_value()) { + sections_sym = sections_shape_or_data.data().value(); + } else { + sections_sym = sections_shape_or_data.shape(); + } + + const auto& GetSum = [&](const auto& dim_exprs, const auto& Filter) { + symbol::DimExpr sum{1}; + for (const auto& dim_expr : dim_exprs) { + if (Filter(dim_expr)) { + sum = sum + dim_expr; + } + } + return product; + }; + + const auto& IsNotMinusOne = [&](const symbol::DimExpr& dim_expr) { + if (dim_expr.isa()) { + return dim_expr.dyn_cast() != static_cast(-1); + } + return true; + }; + + const auto& sum_exclude_minus_one = + GetSum(operand_shape_or_data.data().value(), IsNotMinusOne); + + symbol::DimExpr x_ori_dim_on_axis = x_dims_sym[axis]; + + for (uint32_t rst_idx = 0; rst_idx < op->num_results(); rst_idx++) { + auto section_sym = sections_sym[rst_idx]; + x_dims_sym[axis] = IsNotMinusOne(section_sym) + ? 
section_sym + : x_ori_dim_on_axis - sum_exclude_minus_one; + + shape_analysis->SetShapeOrDataForValue( + op->result(rst_idx), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(x_dims_sym)}); + } + } + return true; + } + + SplitOpInferSymbolicShapeInterfaceModel() + : InferSymbolicShapeInterface::Concept(InferSymbolicShape) {} +}; + struct YieldOpInferSymbolicShapeInterfaceModel : public InferSymbolicShapeInterface::Concept { static inline bool InferSymbolicShape( @@ -196,6 +293,11 @@ OperatorDialect::OperatorDialect(pir::IrContext* ctx) InferSymbolicShapeInterface, ShadowOutputOpInferSymbolicShapeInterfaceModel>())); + info = ctx->GetRegisteredOpInfo(pir::SplitOp::name()); + info.AttachInterface(std::move( + pir::InterfaceValue::Get())); + info = ctx->GetRegisteredOpInfo(pir::YieldOp::name()); info.AttachInterface(std::move( pir::InterfaceValue::Get Date: Tue, 5 Mar 2024 01:54:05 +0000 Subject: [PATCH 02/13] bug fix --- .../paddle_op_infer_sym.cc | 64 ++++++++++++++++- .../pir/dialect/operator/ir/op_dialect.cc | 14 ++-- paddle/phi/api/yaml/legacy_ops.yaml | 1 + .../cinn/symbolic/test_op_infer_sym_shape.py | 68 +++++++++++++++++++ 4 files changed, 139 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc index d7ee4fb6781b0f..245bbd937bbc8b 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc @@ -1167,8 +1167,68 @@ bool ExpandAsOpInferSymbolicShape( bool SplitOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); + VLOG(0) << "InferSymbolicShape of pd_op.split!"; + // x + auto x_shape_or_data 
= + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + IR_ENFORCE(!x_shape_or_data.data().has_value(), + "Currently InferSymbolicShape of SplitOp only support " + "input without value."); + std::vector x_dims_sym = x_shape_or_data.shape(); + + // axis + CHECK(op->operand_source(2).defining_op()->isa()); + + int64_t axis = op->operand_source(2) + .defining_op() + .attributes() + .at("value") + .dyn_cast() + .data() + .to(); + + // sections + auto sections_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + std::vector sections_sym; + if (sections_shape_or_data.data().has_value()) { + sections_sym = sections_shape_or_data.data().value(); + } else { + sections_sym = sections_shape_or_data.shape(); + } + + const auto &GetSum = [&](const auto &dim_exprs, const auto &Filter) { + symbol::DimExpr sum{1}; + for (const auto &dim_expr : dim_exprs) { + if (Filter(dim_expr)) { + sum = sum + dim_expr; + } + } + return sum; + }; + + const auto &IsNotMinusOne = [&](const symbol::DimExpr &dim_expr) { + if (dim_expr.isa()) { + return dim_expr.dyn_cast() != static_cast(-1); + } + return true; + }; + + const auto &sum_exclude_minus_one = + GetSum(sections_shape_or_data.data().value(), IsNotMinusOne); + + symbol::DimExpr x_ori_dim_on_axis = x_dims_sym[axis]; + for (uint32_t rst_idx = 0; rst_idx < op->num_results(); rst_idx++) { + auto section_sym = sections_sym[rst_idx]; + x_dims_sym[axis] = IsNotMinusOne(section_sym) + ? 
section_sym + : x_ori_dim_on_axis - sum_exclude_minus_one; + + shape_analysis->SetShapeOrDataForValue( + op->result(rst_idx), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(x_dims_sym)}); + } return true; } diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 9223d59c6fc672..7ad664384abf76 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -163,13 +163,15 @@ struct SplitOpInferSymbolicShapeInterfaceModel : public InferSymbolicShapeInterface::Concept { static inline bool InferSymbolicShape( pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { - VLOG(0) << "InferSymbolicShape of SplitOp!"; + VLOG(0) << "InferSymbolicShape of builtin.split!"; + VLOG(0) << "num_operands: " << op->num_operands(); + VLOG(0) << "op: " << op; // x - IR_ENFORCE(!op->operand_source(0).data().has_value(), - "Currently InferSymbolicShape of SplitOp only support " - "input without value."); auto x_shape_or_data = shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + IR_ENFORCE(!x_shape_or_data.data().has_value(), + "Currently InferSymbolicShape of SplitOp only support " + "input without value."); std::vector x_dims_sym = x_shape_or_data.shape(); // axis @@ -222,7 +224,7 @@ struct SplitOpInferSymbolicShapeInterfaceModel sum = sum + dim_expr; } } - return product; + return sum; }; const auto& IsNotMinusOne = [&](const symbol::DimExpr& dim_expr) { @@ -233,7 +235,7 @@ struct SplitOpInferSymbolicShapeInterfaceModel }; const auto& sum_exclude_minus_one = - GetSum(operand_shape_or_data.data().value(), IsNotMinusOne); + GetSum(sections_shape_or_data.data().value(), IsNotMinusOne); symbol::DimExpr x_ori_dim_on_axis = x_dims_sym[axis]; diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 9b1d862180903e..43420c66c765ad 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ 
b/paddle/phi/api/yaml/legacy_ops.yaml @@ -1078,6 +1078,7 @@ kernel : func : split backward : split_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : split_with_num args : (Tensor x, int num, Scalar(int) axis) diff --git a/test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py b/test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py index 4ab27bf657eac9..7cb45131b48b9e 100644 --- a/test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py +++ b/test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py @@ -517,5 +517,73 @@ def test_eval_symbolic(self): return True +class SplitNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, x): + # out = x.split(1) + # out = x.split(1, axis=1) + # out = x + out = paddle.split(x, [-1], axis=1) + # out = paddle.split(x, [1, -1], axis=1) + # out = paddle.split(x, [1, 2, -1], axis=1) + # if x.shape[1] == 6: + # out = paddle.split(x, [1, 2, x.shape[1]], axis=1) + # out = paddle.split(x, [1, 2, 3], axis=1) + + # out = x.split([-1], axis=1) + # out = x.split([1, -1], axis=1) + # out = x.split([1, 2, -1], axis=1) + # if x.shape[1] == 6: + # out = x.split([1, 2, x.shape[1]], axis=1) + # out = x.split([1, 2, 3], axis=1) + + return out + + +class TestSplitOpInferSymbolicShape(TestBase): + def prepare_data(self): + self.cases = [np.random.rand(4, 5, 6)] + + # FIXME: not the expected yet, just a placeholder + self.expected = [ + [ + 'shape[S0, S2], data[NULL]', + 'shape[2, 2, 2], data[NULL]', + 'shape[Add(3, -Add(-3, S0)), 2, 2]', + ] + ] + + def test_eval_symbolic(self): + net = SplitNet() + + for i in range(len(self.cases)): + x = self.cases[i] + x_spec = InputSpec( + shape=[None for index in range(len(x.shape))], dtype='float32' + ) + + input_spec = [x_spec] + net = apply_to_static(net, False, input_spec) + net.eval() + + # check the infer result + sym_shape_str_list = get_sym_shape_str_for_op( + net, input_spec, 'pd_op.slice' + ) + np.testing.assert_equal( + len(sym_shape_str_list), 
len(self.expected[i]) + ) + for j in range(len(sym_shape_str_list)): + np.testing.assert_equal( + sym_shape_str_list[j].find(self.expected[i][j]), + 0, + f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}', + ) + + return True + + if __name__ == '__main__': unittest.main() From 588067733e955c9129b3d0e5c9ad5b041f021e6e Mon Sep 17 00:00:00 2001 From: Tianyu Feng <45195157+fty1777@users.noreply.github.com> Date: Tue, 5 Mar 2024 11:55:09 +0800 Subject: [PATCH 03/13] Update paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc Co-authored-by: Bo Zhang <105368690+zhangbopd@users.noreply.github.com> --- .../interface/infer_symbolic_shape/paddle_op_infer_sym.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc index 245bbd937bbc8b..c6960fc6ef1595 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc @@ -1169,7 +1169,7 @@ bool SplitOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { VLOG(0) << "InferSymbolicShape of pd_op.split!"; // x - auto x_shape_or_data = + const auto& x_shape_or_data = shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); IR_ENFORCE(!x_shape_or_data.data().has_value(), "Currently InferSymbolicShape of SplitOp only support " From ae1282c2c37f927b6a41e854821bcfac64c36166 Mon Sep 17 00:00:00 2001 From: Tianyu Feng <45195157+fty1777@users.noreply.github.com> Date: Tue, 5 Mar 2024 21:59:35 +0800 Subject: [PATCH 04/13] Update paddle/fluid/pir/dialect/operator/ir/op_dialect.cc Co-authored-by: Bo Zhang <105368690+zhangbopd@users.noreply.github.com> --- 
paddle/fluid/pir/dialect/operator/ir/op_dialect.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 7ad664384abf76..7fe2e1e67ab8ea 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -167,7 +167,7 @@ struct SplitOpInferSymbolicShapeInterfaceModel VLOG(0) << "num_operands: " << op->num_operands(); VLOG(0) << "op: " << op; // x - auto x_shape_or_data = + const auto& x_shape_or_data = shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); IR_ENFORCE(!x_shape_or_data.data().has_value(), "Currently InferSymbolicShape of SplitOp only support " From 6a73674785350d058763ea56192d7fc94a07c1f5 Mon Sep 17 00:00:00 2001 From: Tianyu Feng <45195157+fty1777@users.noreply.github.com> Date: Tue, 5 Mar 2024 21:59:41 +0800 Subject: [PATCH 05/13] Update paddle/fluid/pir/dialect/operator/ir/op_dialect.cc Co-authored-by: Bo Zhang <105368690+zhangbopd@users.noreply.github.com> --- paddle/fluid/pir/dialect/operator/ir/op_dialect.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 7fe2e1e67ab8ea..254feb3a390898 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -172,7 +172,7 @@ struct SplitOpInferSymbolicShapeInterfaceModel IR_ENFORCE(!x_shape_or_data.data().has_value(), "Currently InferSymbolicShape of SplitOp only support " "input without value."); - std::vector x_dims_sym = x_shape_or_data.shape(); + const auto& x_dims_sym = x_shape_or_data.shape(); // axis CHECK(op->operand_source(2).defining_op()->isa()); From 9a4813c7870384bf9b031c962e9f9575020f2632 Mon Sep 17 00:00:00 2001 From: fty1777 Date: Tue, 5 Mar 2024 16:38:13 +0000 Subject: [PATCH 06/13] pd_op.split followed by builtin.split 
--- .../paddle_op_infer_sym.cc | 92 +++++++++-------- .../pir/dialect/operator/ir/op_dialect.cc | 98 +++---------------- 2 files changed, 61 insertions(+), 129 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc index c6960fc6ef1595..15bb705949ece5 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc @@ -1167,14 +1167,13 @@ bool ExpandAsOpInferSymbolicShape( bool SplitOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - VLOG(0) << "InferSymbolicShape of pd_op.split!"; - // x - const auto& x_shape_or_data = + // input + const auto &x_shape_or_data = shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); IR_ENFORCE(!x_shape_or_data.data().has_value(), "Currently InferSymbolicShape of SplitOp only support " "input without value."); - std::vector x_dims_sym = x_shape_or_data.shape(); + const auto &x_dims_sym = x_shape_or_data.shape(); // axis CHECK(op->operand_source(2).defining_op()->isa()); @@ -1188,47 +1187,58 @@ bool SplitOpInferSymbolicShape(pir::Operation *op, .to(); // sections - auto sections_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); - std::vector sections_sym; - if (sections_shape_or_data.data().has_value()) { - sections_sym = sections_shape_or_data.data().value(); - } else { - sections_sym = sections_shape_or_data.shape(); - } - - const auto &GetSum = [&](const auto &dim_exprs, const auto &Filter) { - symbol::DimExpr sum{1}; - for (const auto &dim_expr : dim_exprs) { - if (Filter(dim_expr)) { - sum = sum + dim_expr; - } + const std::vector §ions_sym = [&] { + auto sections_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + std::vector sections_sym; + if 
(sections_shape_or_data.data().has_value()) { + sections_sym = sections_shape_or_data.data().value(); + } else { + sections_sym = sections_shape_or_data.shape(); } - return sum; - }; + return sections_sym; + }(); - const auto &IsNotMinusOne = [&](const symbol::DimExpr &dim_expr) { - if (dim_expr.isa()) { - return dim_expr.dyn_cast() != static_cast(-1); + // output + const symbol::TensorListShapeOrDataDimExprs &output_shape_data_list = [&] { + const auto &GetSum = [&](const auto &dim_exprs, const auto &Filter) { + symbol::DimExpr sum{1}; + for (const auto &dim_expr : dim_exprs) { + if (Filter(dim_expr)) { + sum = sum + dim_expr; + } + } + return sum; + }; + const auto &IsNotMinusOne = [&](const symbol::DimExpr &dim_expr) { + if (dim_expr.isa()) { + return dim_expr.dyn_cast() != static_cast(-1); + } + return true; + }; + const auto &sum_exclude_minus_one = GetSum(sections_sym, IsNotMinusOne); + + symbol::TensorListShapeOrDataDimExprs shape_data_list; + std::vector output_dims_sym = x_dims_sym; + for (uint32_t idx = 0; idx < sections_sym.size(); idx++) { + const auto section_sym = sections_sym[idx]; + output_dims_sym[axis] = IsNotMinusOne(section_sym) + ? section_sym + : x_dims_sym[axis] - sum_exclude_minus_one; + + // VLOG(0) << "FTY DEBUG START"; + // for (const auto &dim_expr : output_dims_sym) { + // VLOG(0) << "FTY DEBUG " << dim_expr; + // } + // VLOG(0) << "FTY DEBUG END"; + shape_data_list.push_back(symbol::TensorShapeOrDataDimExprs(x_dims_sym)); } - return true; - }; - - const auto &sum_exclude_minus_one = - GetSum(sections_shape_or_data.data().value(), IsNotMinusOne); + return shape_data_list; + }(); - symbol::DimExpr x_ori_dim_on_axis = x_dims_sym[axis]; - for (uint32_t rst_idx = 0; rst_idx < op->num_results(); rst_idx++) { - auto section_sym = sections_sym[rst_idx]; - x_dims_sym[axis] = IsNotMinusOne(section_sym) - ? 
section_sym - : x_ori_dim_on_axis - sum_exclude_minus_one; + shape_analysis->SetShapeOrDataForValue( + op->result(0), symbol::ShapeOrDataDimExprs{output_shape_data_list}); - shape_analysis->SetShapeOrDataForValue( - op->result(rst_idx), - symbol::ShapeOrDataDimExprs{ - symbol::TensorShapeOrDataDimExprs(x_dims_sym)}); - } return true; } @@ -1525,8 +1535,6 @@ bool RepeatInterleaveOpInferSymbolicShape( } bool SplitWithNumOpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); return true; } bool TrilIndicesOpInferSymbolicShape( diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 254feb3a390898..54fb53766c36e8 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -163,93 +163,17 @@ struct SplitOpInferSymbolicShapeInterfaceModel : public InferSymbolicShapeInterface::Concept { static inline bool InferSymbolicShape( pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { - VLOG(0) << "InferSymbolicShape of builtin.split!"; - VLOG(0) << "num_operands: " << op->num_operands(); - VLOG(0) << "op: " << op; - // x - const auto& x_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); - IR_ENFORCE(!x_shape_or_data.data().has_value(), - "Currently InferSymbolicShape of SplitOp only support " - "input without value."); - const auto& x_dims_sym = x_shape_or_data.shape(); - - // axis - CHECK(op->operand_source(2).defining_op()->isa()); - - int64_t axis = op->operand_source(2) - .defining_op() - .attributes() - .at("value") - .dyn_cast() - .data() - .to(); - - // sections or num - if (op->operand_source(1).defining_op()->isa()) { - // num - VLOG(0) << "Num!"; - int64_t num = op->operand_source(1) - .defining_op() - .attributes() - .at("value") - 
.dyn_cast() - .data() - .to(); - x_dims_sym[axis] = x_dims_sym[axis] / num; - - // all the result have the same shape - for (uint32_t rst_idx = 0; rst_idx < op->num_results(); rst_idx++) { - shape_analysis->SetShapeOrDataForValue( - op->result(rst_idx), - symbol::ShapeOrDataDimExprs{ - symbol::TensorShapeOrDataDimExprs(x_dims_sym)}); - } - } else { - // sections - VLOG(0) << "Sections!"; - auto sections_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); - std::vector sections_sym; - if (sections_shape_or_data.data().has_value()) { - sections_sym = sections_shape_or_data.data().value(); - } else { - sections_sym = sections_shape_or_data.shape(); - } - - const auto& GetSum = [&](const auto& dim_exprs, const auto& Filter) { - symbol::DimExpr sum{1}; - for (const auto& dim_expr : dim_exprs) { - if (Filter(dim_expr)) { - sum = sum + dim_expr; - } - } - return sum; - }; - - const auto& IsNotMinusOne = [&](const symbol::DimExpr& dim_expr) { - if (dim_expr.isa()) { - return dim_expr.dyn_cast() != static_cast(-1); - } - return true; - }; - - const auto& sum_exclude_minus_one = - GetSum(sections_shape_or_data.data().value(), IsNotMinusOne); - - symbol::DimExpr x_ori_dim_on_axis = x_dims_sym[axis]; - - for (uint32_t rst_idx = 0; rst_idx < op->num_results(); rst_idx++) { - auto section_sym = sections_sym[rst_idx]; - x_dims_sym[axis] = IsNotMinusOne(section_sym) - ? 
section_sym - : x_ori_dim_on_axis - sum_exclude_minus_one; - - shape_analysis->SetShapeOrDataForValue( - op->result(rst_idx), - symbol::ShapeOrDataDimExprs{ - symbol::TensorShapeOrDataDimExprs(x_dims_sym)}); - } + const auto& shape_data_list = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)) + .dyn_cast(); + + for (uint32_t rst_idx = 0; rst_idx < op->num_results(); rst_idx++) { + IR_ENFORCE(!shape_data_list[rst_idx].data().has_value(), + "Currently InferSymbolicShape of SplitOp only support " + "input without value."); + shape_analysis->SetShapeOrDataForValue( + op->result(rst_idx), + symbol::ShapeOrDataDimExprs{shape_data_list[rst_idx]}); } return true; } From 82fea4bf908924a3a1911dfa25365bd4f3026147 Mon Sep 17 00:00:00 2001 From: fty1777 Date: Tue, 5 Mar 2024 18:34:42 +0000 Subject: [PATCH 07/13] pd_op.split infer sym shape bugfix and unittest; fix op infer sym error outputs --- .../paddle_op_infer_sym.cc | 32 ++++++++--- .../cinn/symbolic/test_op_infer_sym_shape.py | 56 ++++++++++--------- .../symbolic/test_unary_op_infer_sym_shape.py | 2 +- 3 files changed, 57 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc index 15bb705949ece5..31357acfde2908 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc @@ -1202,7 +1202,7 @@ bool SplitOpInferSymbolicShape(pir::Operation *op, // output const symbol::TensorListShapeOrDataDimExprs &output_shape_data_list = [&] { const auto &GetSum = [&](const auto &dim_exprs, const auto &Filter) { - symbol::DimExpr sum{1}; + symbol::DimExpr sum{0}; for (const auto &dim_expr : dim_exprs) { if (Filter(dim_expr)) { sum = sum + dim_expr; @@ -1210,6 +1210,14 @@ bool SplitOpInferSymbolicShape(pir::Operation *op, } 
return sum; }; + const auto &All = [&](const auto &dim_exprs, const auto &Cond) { + for (const auto &dim_expr : dim_exprs) { + if (!Cond(dim_expr)) { + return false; + } + } + return true; + }; const auto &IsNotMinusOne = [&](const symbol::DimExpr &dim_expr) { if (dim_expr.isa()) { return dim_expr.dyn_cast() != static_cast(-1); @@ -1218,20 +1226,30 @@ bool SplitOpInferSymbolicShape(pir::Operation *op, }; const auto &sum_exclude_minus_one = GetSum(sections_sym, IsNotMinusOne); + const bool &all_sections_sym_not_minus_one = + All(sections_sym, IsNotMinusOne); + if (all_sections_sym_not_minus_one) { + shape_analysis->CreateDimExprBuilder().CstrEq(x_dims_sym[axis], + sum_exclude_minus_one); + } + symbol::TensorListShapeOrDataDimExprs shape_data_list; std::vector output_dims_sym = x_dims_sym; + if (!all_sections_sym_not_minus_one && sections_sym.size() == 1) { + VLOG(3) << "[SplitOp]-1 is the only split section. The output shape is " + "identical to the input shape."; + shape_data_list.push_back( + symbol::TensorShapeOrDataDimExprs(output_dims_sym)); + return shape_data_list; + } for (uint32_t idx = 0; idx < sections_sym.size(); idx++) { const auto section_sym = sections_sym[idx]; output_dims_sym[axis] = IsNotMinusOne(section_sym) ? 
section_sym : x_dims_sym[axis] - sum_exclude_minus_one; - // VLOG(0) << "FTY DEBUG START"; - // for (const auto &dim_expr : output_dims_sym) { - // VLOG(0) << "FTY DEBUG " << dim_expr; - // } - // VLOG(0) << "FTY DEBUG END"; - shape_data_list.push_back(symbol::TensorShapeOrDataDimExprs(x_dims_sym)); + shape_data_list.push_back( + symbol::TensorShapeOrDataDimExprs(output_dims_sym)); } return shape_data_list; }(); diff --git a/test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py b/test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py index 7cb45131b48b9e..594d2c7dac3cc9 100644 --- a/test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py +++ b/test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py @@ -351,7 +351,7 @@ def test_eval_symbolic(self): np.testing.assert_equal( sym_shape_str_list[j].find(self.expected[i][j]), 0, - f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}', + f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}', ) return True @@ -403,7 +403,7 @@ def test_eval_symbolic(self): np.testing.assert_equal( sym_shape_str_list[j].find(self.expected[i][j]), 0, - f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}', + f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}', ) return True @@ -453,7 +453,7 @@ def test_eval_symbolic(self): np.testing.assert_equal( sym_shape_str_list[j].find(self.expected[i][j]), 0, - f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}', + f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}', ) return True @@ -511,7 +511,7 @@ def test_eval_symbolic(self): np.testing.assert_equal( sym_shape_str_list[j].find(self.expected[i][j]), 0, - f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected 
{(self.expected[i][j])}', + f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}', ) return True @@ -522,36 +522,38 @@ def __init__(self): super().__init__() def forward(self, x): - # out = x.split(1) - # out = x.split(1, axis=1) - # out = x out = paddle.split(x, [-1], axis=1) - # out = paddle.split(x, [1, -1], axis=1) - # out = paddle.split(x, [1, 2, -1], axis=1) - # if x.shape[1] == 6: - # out = paddle.split(x, [1, 2, x.shape[1]], axis=1) - # out = paddle.split(x, [1, 2, 3], axis=1) - - # out = x.split([-1], axis=1) - # out = x.split([1, -1], axis=1) - # out = x.split([1, 2, -1], axis=1) - # if x.shape[1] == 6: - # out = x.split([1, 2, x.shape[1]], axis=1) - # out = x.split([1, 2, 3], axis=1) + out = paddle.split(x, [1, 2, -1], axis=1) + out = paddle.split(x, [1, -1], axis=1) + out = paddle.split(x, [1, 2, 3], axis=1) + out = paddle.split(x, [1, 2, x.shape[1]], axis=1) + + out = x.split([-1], axis=1) + out = x.split([1, 2, -1], axis=1) + out = x.split([1, -1], axis=1) + out = x.split([1, 2, 3], axis=1) + out = x.split([1, 2, x.shape[1]], axis=1) return out class TestSplitOpInferSymbolicShape(TestBase): def prepare_data(self): - self.cases = [np.random.rand(4, 5, 6)] + self.cases = [np.random.rand(4, 6, 5)] # FIXME: not the expected yet, just a placeholder self.expected = [ [ - 'shape[S0, S2], data[NULL]', - 'shape[2, 2, 2], data[NULL]', - 'shape[Add(3, -Add(-3, S0)), 2, 2]', + 'shape[S0, S1, S2], data[NULL]', + 'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, Add(S1, -3), S2], data[NULL]', + 'shape[S0, 1, S2], data[NULL], shape[S0, Add(S1, -1), S2], data[NULL]', + 'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, 3, S2], data[NULL]', + 'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, S1, S2], data[NULL]', + 'shape[S0, S1, S2], data[NULL]', + 'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, Add(S1, -3), S2], data[NULL]', + 
'shape[S0, 1, S2], data[NULL], shape[S0, Add(S1, -1), S2], data[NULL]', + 'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, 3, S2], data[NULL]', + 'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, S1, S2], data[NULL]', ] ] @@ -570,7 +572,7 @@ def test_eval_symbolic(self): # check the infer result sym_shape_str_list = get_sym_shape_str_for_op( - net, input_spec, 'pd_op.slice' + net, input_spec, 'pd_op.split' ) np.testing.assert_equal( len(sym_shape_str_list), len(self.expected[i]) @@ -579,9 +581,13 @@ def test_eval_symbolic(self): np.testing.assert_equal( sym_shape_str_list[j].find(self.expected[i][j]), 0, - f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}', + f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}', ) + # TODO(fty1777): Add builtin.split op infer symbolic shape test + # Not added because attribute `sym_shape_str` does not support multi-output op now. + # See also: paddle/fluid/pir/transforms/shape_optimization_pass.cc:144. 
+ return True diff --git a/test/ir/pir/cinn/symbolic/test_unary_op_infer_sym_shape.py b/test/ir/pir/cinn/symbolic/test_unary_op_infer_sym_shape.py index 5260475b45f1e8..f472e9c6718685 100644 --- a/test/ir/pir/cinn/symbolic/test_unary_op_infer_sym_shape.py +++ b/test/ir/pir/cinn/symbolic/test_unary_op_infer_sym_shape.py @@ -102,7 +102,7 @@ def test_eval_symbolic(self): np.testing.assert_equal( sym_shape_str_list[j].find(self.expected[i][j]), 0, - f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}', + f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}', ) return True From 2da326a67b844444f2b18896004a9105c871b7f6 Mon Sep 17 00:00:00 2001 From: fty1777 Date: Tue, 5 Mar 2024 18:43:21 +0000 Subject: [PATCH 08/13] recover SplitWithNumOpInferSymbolicShape Unimplemented exception raising --- .../interface/infer_symbolic_shape/paddle_op_infer_sym.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc index 31357acfde2908..30ffe102d6d8a6 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc @@ -1553,6 +1553,8 @@ bool RepeatInterleaveOpInferSymbolicShape( } bool SplitWithNumOpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + PADDLE_THROW(phi::errors::Unimplemented( + op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); return true; } bool TrilIndicesOpInferSymbolicShape( From a96317c22acde5f406ec43c211c6a56bef924e1b Mon Sep 17 00:00:00 2001 From: fty1777 Date: Thu, 7 Mar 2024 05:15:22 +0000 Subject: [PATCH 09/13] code refinement --- .../interface/infer_symbolic_shape/paddle_op_infer_sym.cc | 4 
++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc index 30ffe102d6d8a6..e9cbf9fc6f0da1 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc @@ -1188,7 +1188,7 @@ bool SplitOpInferSymbolicShape(pir::Operation *op, // sections const std::vector &sections_sym = [&] { - auto sections_shape_or_data = + const auto &sections_shape_or_data = shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); std::vector sections_sym; if (sections_shape_or_data.data().has_value()) { @@ -1243,7 +1243,7 @@ bool SplitOpInferSymbolicShape(pir::Operation *op, return shape_data_list; } for (uint32_t idx = 0; idx < sections_sym.size(); idx++) { - const auto section_sym = sections_sym[idx]; + const auto &section_sym = sections_sym[idx]; output_dims_sym[axis] = IsNotMinusOne(section_sym) ? 
section_sym : x_dims_sym[axis] - sum_exclude_minus_one; From aedf013d30dc8b8b727b7bb6bfe4672948a5b405 Mon Sep 17 00:00:00 2001 From: fty1777 Date: Fri, 8 Mar 2024 08:29:26 +0000 Subject: [PATCH 10/13] Rewrite PADDLE_ENFORCE --- .../interface/infer_symbolic_shape/paddle_op_infer_sym.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc index e9cbf9fc6f0da1..087edcaa001dc9 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc @@ -1170,9 +1170,11 @@ bool SplitOpInferSymbolicShape(pir::Operation *op, // input const auto &x_shape_or_data = shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); - IR_ENFORCE(!x_shape_or_data.data().has_value(), - "Currently InferSymbolicShape of SplitOp only support " - "input without value."); + PADDLE_ENFORCE_EQ(x_shape_or_data.data().has_value(), + false, + phi::errors::InvalidArgument( + "InferSymbolicShape of SplitOp only support input with " + "value now.")); const auto &x_dims_sym = x_shape_or_data.shape(); // axis From a2daf697894eb2babfbef89a613119d464c64492 Mon Sep 17 00:00:00 2001 From: fty1777 Date: Fri, 8 Mar 2024 08:40:02 +0000 Subject: [PATCH 11/13] remove incorrect comments --- test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py b/test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py index 594d2c7dac3cc9..68214a52b9ff15 100644 --- a/test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py +++ b/test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py @@ -541,7 +541,6 @@ class TestSplitOpInferSymbolicShape(TestBase): def prepare_data(self): self.cases = [np.random.rand(4, 6, 5)] - # 
FIXME: not the expected yet, just a placeholder self.expected = [ [ 'shape[S0, S1, S2], data[NULL]', From 71c71abdab0f29d0fb48d170899b0fb7c9a8e941 Mon Sep 17 00:00:00 2001 From: fty1777 Date: Fri, 8 Mar 2024 15:34:12 +0000 Subject: [PATCH 12/13] Rewrite PADDLE_ENFORCE --- paddle/fluid/pir/dialect/operator/ir/op_dialect.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 54fb53766c36e8..6728ab03360380 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -168,9 +168,10 @@ struct SplitOpInferSymbolicShapeInterfaceModel .dyn_cast(); for (uint32_t rst_idx = 0; rst_idx < op->num_results(); rst_idx++) { - IR_ENFORCE(!shape_data_list[rst_idx].data().has_value(), - "Currently InferSymbolicShape of SplitOp only support " - "input without value."); + PADDLE_ENFORCE_EQ(shape_data_list[rst_idx].data().has_value(), + false, + "Currently InferSymbolicShape of SplitOp only support " + "input without value."); shape_analysis->SetShapeOrDataForValue( op->result(rst_idx), symbol::ShapeOrDataDimExprs{shape_data_list[rst_idx]}); From 5a5ca1bda2fba697bbb2ce8d68db2675a734cfc1 Mon Sep 17 00:00:00 2001 From: fty1777 Date: Fri, 8 Mar 2024 15:54:12 +0000 Subject: [PATCH 13/13] Rewrite PADDLE_ENFORCE --- paddle/fluid/pir/dialect/operator/ir/op_dialect.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 6728ab03360380..67fdbe5e8693d8 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -168,10 +168,12 @@ struct SplitOpInferSymbolicShapeInterfaceModel .dyn_cast(); for (uint32_t rst_idx = 0; rst_idx < op->num_results(); rst_idx++) { - 
PADDLE_ENFORCE_EQ(shape_data_list[rst_idx].data().has_value(), - false, - "Currently InferSymbolicShape of SplitOp only support " - "input without value."); + PADDLE_ENFORCE_EQ( + shape_data_list[rst_idx].data().has_value(), + false, + paddle::platform::errors::InvalidArgument( + "Currently InferSymbolicShape of SplitOp only support " + "input without value.")); shape_analysis->SetShapeOrDataForValue( op->result(rst_idx), symbol::ShapeOrDataDimExprs{shape_data_list[rst_idx]});