Skip to content

Commit ce85415

Browse files
fty1777 authored and zhangboSJTU committed
Symbolic shape inference support for pd_op.split and builtin.split (PaddlePaddle#62394)
* WIP: builtin.split op infer sym shape
* bug fix
* Update paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc
  Co-authored-by: Bo Zhang <105368690+zhangbopd@users.noreply.github.com>
* Update paddle/fluid/pir/dialect/operator/ir/op_dialect.cc
  Co-authored-by: Bo Zhang <105368690+zhangbopd@users.noreply.github.com>
* Update paddle/fluid/pir/dialect/operator/ir/op_dialect.cc
  Co-authored-by: Bo Zhang <105368690+zhangbopd@users.noreply.github.com>
* pd_op.split followed by builtin.split
* pd_op.split infer sym shape bugfix and unittest; fix op infer sym error outputs
* recover SplitWithNumOpInferSymbolicShape Unimplemented exception raising
* code refinement
* Rewrite PADDLE_ENFORCE
* remove incorrect comments
* Rewrite PADDLE_ENFORCE
* Rewrite PADDLE_ENFORCE

---------

Co-authored-by: Bo Zhang <105368690+zhangbopd@users.noreply.github.com>
1 parent 5089edc commit ce85415

5 files changed

Lines changed: 202 additions & 7 deletions

File tree

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc

Lines changed: 92 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -958,8 +958,98 @@ bool ExpandAsOpInferSymbolicShape(
958958

959959
bool SplitOpInferSymbolicShape(pir::Operation *op,
960960
pir::ShapeConstraintIRAnalysis *shape_analysis) {
961-
PADDLE_THROW(phi::errors::Unimplemented(
962-
op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
961+
// input
962+
const auto &x_shape_or_data =
963+
shape_analysis->GetShapeOrDataForValue(op->operand_source(0));
964+
PADDLE_ENFORCE_EQ(x_shape_or_data.data().has_value(),
965+
false,
966+
phi::errors::InvalidArgument(
967+
"InferSymbolicShape of SplitOp only support input with "
968+
"value now."));
969+
const auto &x_dims_sym = x_shape_or_data.shape();
970+
971+
// axis
972+
CHECK(op->operand_source(2).defining_op()->isa<paddle::dialect::FullOp>());
973+
974+
int64_t axis = op->operand_source(2)
975+
.defining_op<paddle::dialect::FullOp>()
976+
.attributes()
977+
.at("value")
978+
.dyn_cast<paddle::dialect::ScalarAttribute>()
979+
.data()
980+
.to<int64_t>();
981+
982+
// sections
983+
const std::vector<symbol::DimExpr> &sections_sym = [&] {
984+
const auto &sections_shape_or_data =
985+
shape_analysis->GetShapeOrDataForValue(op->operand_source(1));
986+
std::vector<symbol::DimExpr> sections_sym;
987+
if (sections_shape_or_data.data().has_value()) {
988+
sections_sym = sections_shape_or_data.data().value();
989+
} else {
990+
sections_sym = sections_shape_or_data.shape();
991+
}
992+
return sections_sym;
993+
}();
994+
995+
// output
996+
const symbol::TensorListShapeOrDataDimExprs &output_shape_data_list = [&] {
997+
const auto &GetSum = [&](const auto &dim_exprs, const auto &Filter) {
998+
symbol::DimExpr sum{0};
999+
for (const auto &dim_expr : dim_exprs) {
1000+
if (Filter(dim_expr)) {
1001+
sum = sum + dim_expr;
1002+
}
1003+
}
1004+
return sum;
1005+
};
1006+
const auto &All = [&](const auto &dim_exprs, const auto &Cond) {
1007+
for (const auto &dim_expr : dim_exprs) {
1008+
if (!Cond(dim_expr)) {
1009+
return false;
1010+
}
1011+
}
1012+
return true;
1013+
};
1014+
const auto &IsNotMinusOne = [&](const symbol::DimExpr &dim_expr) {
1015+
if (dim_expr.isa<int64_t>()) {
1016+
return dim_expr.dyn_cast<int64_t>() != static_cast<int64_t>(-1);
1017+
}
1018+
return true;
1019+
};
1020+
const auto &sum_exclude_minus_one = GetSum(sections_sym, IsNotMinusOne);
1021+
1022+
const bool &all_sections_sym_not_minus_one =
1023+
All(sections_sym, IsNotMinusOne);
1024+
if (all_sections_sym_not_minus_one) {
1025+
shape_analysis->CreateDimExprBuilder().CstrEq(x_dims_sym[axis],
1026+
sum_exclude_minus_one);
1027+
}
1028+
1029+
symbol::TensorListShapeOrDataDimExprs shape_data_list;
1030+
std::vector<symbol::DimExpr> output_dims_sym = x_dims_sym;
1031+
if (!all_sections_sym_not_minus_one && sections_sym.size() == 1) {
1032+
VLOG(3) << "[SplitOp]-1 is the only split section. The output shape is "
1033+
"identical to the input shape.";
1034+
shape_data_list.push_back(
1035+
symbol::TensorShapeOrDataDimExprs(output_dims_sym));
1036+
return shape_data_list;
1037+
}
1038+
for (uint32_t idx = 0; idx < sections_sym.size(); idx++) {
1039+
const auto &section_sym = sections_sym[idx];
1040+
output_dims_sym[axis] = IsNotMinusOne(section_sym)
1041+
? section_sym
1042+
: x_dims_sym[axis] - sum_exclude_minus_one;
1043+
1044+
shape_data_list.push_back(
1045+
symbol::TensorShapeOrDataDimExprs(output_dims_sym));
1046+
}
1047+
return shape_data_list;
1048+
}();
1049+
1050+
shape_analysis->SetShapeOrDataForValue(
1051+
op->result(0), symbol::ShapeOrDataDimExprs{output_shape_data_list});
1052+
9631053
return true;
9641054
}
9651055

paddle/fluid/pir/dialect/operator/ir/op_dialect.cc

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -159,6 +159,32 @@ struct ShadowOutputOpInferSymbolicShapeInterfaceModel
159159
: InferSymbolicShapeInterface::Concept(InferSymbolicShape) {}
160160
};
161161

162+
struct SplitOpInferSymbolicShapeInterfaceModel
163+
: public InferSymbolicShapeInterface::Concept {
164+
static inline bool InferSymbolicShape(
165+
pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) {
166+
const auto& shape_data_list =
167+
shape_analysis->GetShapeOrDataForValue(op->operand_source(0))
168+
.dyn_cast<symbol::TensorListShapeOrDataDimExprs>();
169+
170+
for (uint32_t rst_idx = 0; rst_idx < op->num_results(); rst_idx++) {
171+
PADDLE_ENFORCE_EQ(
172+
shape_data_list[rst_idx].data().has_value(),
173+
false,
174+
paddle::platform::errors::InvalidArgument(
175+
"Currently InferSymbolicShape of SplitOp only support "
176+
"input without value."));
177+
shape_analysis->SetShapeOrDataForValue(
178+
op->result(rst_idx),
179+
symbol::ShapeOrDataDimExprs{shape_data_list[rst_idx]});
180+
}
181+
return true;
182+
}
183+
184+
SplitOpInferSymbolicShapeInterfaceModel()
185+
: InferSymbolicShapeInterface::Concept(InferSymbolicShape) {}
186+
};
187+
162188
struct YieldOpInferSymbolicShapeInterfaceModel
163189
: public InferSymbolicShapeInterface::Concept {
164190
static inline bool InferSymbolicShape(
@@ -196,6 +222,11 @@ OperatorDialect::OperatorDialect(pir::IrContext* ctx)
196222
InferSymbolicShapeInterface,
197223
ShadowOutputOpInferSymbolicShapeInterfaceModel>()));
198224

225+
info = ctx->GetRegisteredOpInfo(pir::SplitOp::name());
226+
info.AttachInterface(std::move(
227+
pir::InterfaceValue::Get<InferSymbolicShapeInterface,
228+
SplitOpInferSymbolicShapeInterfaceModel>()));
229+
199230
info = ctx->GetRegisteredOpInfo(pir::YieldOp::name());
200231
info.AttachInterface(std::move(
201232
pir::InterfaceValue::Get<InferSymbolicShapeInterface,

paddle/phi/api/yaml/legacy_ops.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1099,6 +1099,7 @@
10991099
kernel :
11001100
func : split
11011101
backward : split_grad
1102+
interfaces : paddle::dialect::InferSymbolicShapeInterface
11021103

11031104
- op : split_with_num
11041105
args : (Tensor x, int num, Scalar(int) axis)

test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py

Lines changed: 77 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -351,7 +351,7 @@ def test_eval_symbolic(self):
351351
np.testing.assert_equal(
352352
sym_shape_str_list[j].find(self.expected[i][j]),
353353
0,
354-
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}',
354+
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}',
355355
)
356356

357357
return True
@@ -403,7 +403,7 @@ def test_eval_symbolic(self):
403403
np.testing.assert_equal(
404404
sym_shape_str_list[j].find(self.expected[i][j]),
405405
0,
406-
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}',
406+
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}',
407407
)
408408

409409
return True
@@ -453,7 +453,7 @@ def test_eval_symbolic(self):
453453
np.testing.assert_equal(
454454
sym_shape_str_list[j].find(self.expected[i][j]),
455455
0,
456-
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}',
456+
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}',
457457
)
458458

459459
return True
@@ -512,11 +512,84 @@ def test_eval_symbolic(self):
512512
np.testing.assert_equal(
513513
sym_shape_str_list[j].find(self.expected[i][j]),
514514
0,
515-
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}',
515+
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}',
516516
)
517517

518518
return True
519519

520520

521+
class SplitNet(paddle.nn.Layer):
    """Layer that issues ten split calls along axis 1 of its input.

    The same five section lists are used first through ``paddle.split`` and
    then through the ``Tensor.split`` method. Only the result of the last
    call is returned; the earlier calls exist so that each one shows up as
    its own split op.
    """

    def forward(self, x):
        # Functional API.
        pieces = paddle.split(x, [-1], axis=1)
        pieces = paddle.split(x, [1, 2, -1], axis=1)
        pieces = paddle.split(x, [1, -1], axis=1)
        pieces = paddle.split(x, [1, 2, 3], axis=1)
        pieces = paddle.split(x, [1, 2, x.shape[1]], axis=1)

        # Tensor method, same section lists in the same order.
        pieces = x.split([-1], axis=1)
        pieces = x.split([1, 2, -1], axis=1)
        pieces = x.split([1, -1], axis=1)
        pieces = x.split([1, 2, 3], axis=1)
        pieces = x.split([1, 2, x.shape[1]], axis=1)

        return pieces
539+
540+
541+
class TestSplitOpInferSymbolicShape(TestBase):
    """Compares the symbolic shape strings of every pd_op.split in SplitNet
    against the expected strings."""

    def prepare_data(self):
        # A single rank-3 sample; only its rank is used to build the InputSpec.
        self.cases = [np.random.rand(4, 6, 5)]

        # Expected strings for the five section lists, in the order they are
        # used in SplitNet.forward.
        expected_per_api = [
            'shape[S0, S1, S2], data[NULL]',
            'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, Add(S1, -3), S2], data[NULL]',
            'shape[S0, 1, S2], data[NULL], shape[S0, Add(S1, -1), S2], data[NULL]',
            'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, 3, S2], data[NULL]',
            'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, S1, S2], data[NULL]',
        ]

        # SplitNet runs the five splits twice (paddle.split, then x.split),
        # so the expected sequence is the same five strings repeated.
        self.expected = [expected_per_api + expected_per_api]

    def test_eval_symbolic(self):
        net = SplitNet()

        for i, x in enumerate(self.cases):
            # Every dimension is left dynamic.
            x_spec = InputSpec(shape=[None] * len(x.shape), dtype='float32')

            input_spec = [x_spec]
            net = apply_to_static(net, False, input_spec)
            net.eval()

            # check the infer result
            sym_shape_str_list = get_sym_shape_str_for_op(
                net, input_spec, 'pd_op.split'
            )
            np.testing.assert_equal(
                len(sym_shape_str_list), len(self.expected[i])
            )
            for j, sym_shape_str in enumerate(sym_shape_str_list):
                # The expected text must appear at the very start of the
                # reported string.
                np.testing.assert_equal(
                    sym_shape_str.find(self.expected[i][j]),
                    0,
                    f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}',
                )

        # TODO(fty1777): Add builtin.split op infer symbolic shape test
        # Not added because attribute `sym_shape_str` does not support multi-output op now.
        # See also: paddle/fluid/pir/transforms/shape_optimization_pass.cc:144.

        return True
592+
593+
521594
if __name__ == '__main__':
522595
unittest.main()

test/ir/pir/cinn/symbolic/test_unary_op_infer_sym_shape.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -102,7 +102,7 @@ def test_eval_symbolic(self):
102102
np.testing.assert_equal(
103103
sym_shape_str_list[j].find(self.expected[i][j]),
104104
0,
105-
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}',
105+
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}',
106106
)
107107

108108
return True

0 commit comments

Comments
 (0)