Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1167,8 +1167,98 @@ bool ExpandAsOpInferSymbolicShape(

bool SplitOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
PADDLE_THROW(phi::errors::Unimplemented(
op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
// input
const auto &x_shape_or_data =
shape_analysis->GetShapeOrDataForValue(op->operand_source(0));
PADDLE_ENFORCE_EQ(x_shape_or_data.data().has_value(),
false,
phi::errors::InvalidArgument(
"InferSymbolicShape of SplitOp only support input with "
"value now."));
const auto &x_dims_sym = x_shape_or_data.shape();

// axis
CHECK(op->operand_source(2).defining_op()->isa<paddle::dialect::FullOp>());

int64_t axis = op->operand_source(2)
.defining_op<paddle::dialect::FullOp>()
.attributes()
.at("value")
.dyn_cast<paddle::dialect::ScalarAttribute>()
.data()
.to<int64_t>();

// sections
const std::vector<symbol::DimExpr> &sections_sym = [&] {
const auto &sections_shape_or_data =
shape_analysis->GetShapeOrDataForValue(op->operand_source(1));
std::vector<symbol::DimExpr> sections_sym;
if (sections_shape_or_data.data().has_value()) {
sections_sym = sections_shape_or_data.data().value();
} else {
sections_sym = sections_shape_or_data.shape();
}
return sections_sym;
}();

// output
const symbol::TensorListShapeOrDataDimExprs &output_shape_data_list = [&] {
const auto &GetSum = [&](const auto &dim_exprs, const auto &Filter) {
symbol::DimExpr sum{0};
for (const auto &dim_expr : dim_exprs) {
if (Filter(dim_expr)) {
sum = sum + dim_expr;
}
}
return sum;
};
const auto &All = [&](const auto &dim_exprs, const auto &Cond) {
for (const auto &dim_expr : dim_exprs) {
if (!Cond(dim_expr)) {
return false;
}
}
return true;
};
const auto &IsNotMinusOne = [&](const symbol::DimExpr &dim_expr) {
if (dim_expr.isa<int64_t>()) {
return dim_expr.dyn_cast<int64_t>() != static_cast<int64_t>(-1);
}
return true;
};
const auto &sum_exclude_minus_one = GetSum(sections_sym, IsNotMinusOne);

const bool &all_sections_sym_not_minus_one =
All(sections_sym, IsNotMinusOne);
if (all_sections_sym_not_minus_one) {
shape_analysis->CreateDimExprBuilder().CstrEq(x_dims_sym[axis],
sum_exclude_minus_one);
}

symbol::TensorListShapeOrDataDimExprs shape_data_list;
std::vector<symbol::DimExpr> output_dims_sym = x_dims_sym;
if (!all_sections_sym_not_minus_one && sections_sym.size() == 1) {
VLOG(3) << "[SplitOp]-1 is the only split section. The output shape is "
"identical to the input shape.";
shape_data_list.push_back(
symbol::TensorShapeOrDataDimExprs(output_dims_sym));
return shape_data_list;
}
for (uint32_t idx = 0; idx < sections_sym.size(); idx++) {
const auto &section_sym = sections_sym[idx];
output_dims_sym[axis] = IsNotMinusOne(section_sym)
? section_sym
: x_dims_sym[axis] - sum_exclude_minus_one;

shape_data_list.push_back(
symbol::TensorShapeOrDataDimExprs(output_dims_sym));
}
return shape_data_list;
}();

shape_analysis->SetShapeOrDataForValue(
op->result(0), symbol::ShapeOrDataDimExprs{output_shape_data_list});

return true;
}

Expand Down
31 changes: 31 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/op_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,32 @@ struct ShadowOutputOpInferSymbolicShapeInterfaceModel
: InferSymbolicShapeInterface::Concept(InferSymbolicShape) {}
};

struct SplitOpInferSymbolicShapeInterfaceModel
: public InferSymbolicShapeInterface::Concept {
static inline bool InferSymbolicShape(
pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) {
const auto& shape_data_list =
shape_analysis->GetShapeOrDataForValue(op->operand_source(0))
.dyn_cast<symbol::TensorListShapeOrDataDimExprs>();

for (uint32_t rst_idx = 0; rst_idx < op->num_results(); rst_idx++) {
PADDLE_ENFORCE_EQ(
shape_data_list[rst_idx].data().has_value(),
false,
paddle::platform::errors::InvalidArgument(
"Currently InferSymbolicShape of SplitOp only support "
"input without value."));
shape_analysis->SetShapeOrDataForValue(
op->result(rst_idx),
symbol::ShapeOrDataDimExprs{shape_data_list[rst_idx]});
}
return true;
}

SplitOpInferSymbolicShapeInterfaceModel()
: InferSymbolicShapeInterface::Concept(InferSymbolicShape) {}
};

struct YieldOpInferSymbolicShapeInterfaceModel
: public InferSymbolicShapeInterface::Concept {
static inline bool InferSymbolicShape(
Expand Down Expand Up @@ -196,6 +222,11 @@ OperatorDialect::OperatorDialect(pir::IrContext* ctx)
InferSymbolicShapeInterface,
ShadowOutputOpInferSymbolicShapeInterfaceModel>()));

info = ctx->GetRegisteredOpInfo(pir::SplitOp::name());
info.AttachInterface(std::move(
pir::InterfaceValue::Get<InferSymbolicShapeInterface,
SplitOpInferSymbolicShapeInterfaceModel>()));

info = ctx->GetRegisteredOpInfo(pir::YieldOp::name());
info.AttachInterface(std::move(
pir::InterfaceValue::Get<InferSymbolicShapeInterface,
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/legacy_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1078,6 +1078,7 @@
kernel :
func : split
backward : split_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : split_with_num
args : (Tensor x, int num, Scalar(int) axis)
Expand Down
81 changes: 77 additions & 4 deletions test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def test_eval_symbolic(self):
np.testing.assert_equal(
sym_shape_str_list[j].find(self.expected[i][j]),
0,
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}',
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}',
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在写unittest的时候发现一些failure log打印的东西不太对,修了一下。

)

return True
Expand Down Expand Up @@ -403,7 +403,7 @@ def test_eval_symbolic(self):
np.testing.assert_equal(
sym_shape_str_list[j].find(self.expected[i][j]),
0,
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}',
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}',
)

return True
Expand Down Expand Up @@ -453,7 +453,7 @@ def test_eval_symbolic(self):
np.testing.assert_equal(
sym_shape_str_list[j].find(self.expected[i][j]),
0,
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}',
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}',
)

return True
Expand Down Expand Up @@ -511,11 +511,84 @@ def test_eval_symbolic(self):
np.testing.assert_equal(
sym_shape_str_list[j].find(self.expected[i][j]),
0,
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}',
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}',
)

return True


class SplitNet(paddle.nn.Layer):
def __init__(self):
super().__init__()

def forward(self, x):
out = paddle.split(x, [-1], axis=1)
out = paddle.split(x, [1, 2, -1], axis=1)
out = paddle.split(x, [1, -1], axis=1)
out = paddle.split(x, [1, 2, 3], axis=1)
out = paddle.split(x, [1, 2, x.shape[1]], axis=1)

out = x.split([-1], axis=1)
out = x.split([1, 2, -1], axis=1)
out = x.split([1, -1], axis=1)
out = x.split([1, 2, 3], axis=1)
out = x.split([1, 2, x.shape[1]], axis=1)

return out


class TestSplitOpInferSymbolicShape(TestBase):
def prepare_data(self):
self.cases = [np.random.rand(4, 6, 5)]

self.expected = [
[
'shape[S0, S1, S2], data[NULL]',
'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, Add(S1, -3), S2], data[NULL]',
'shape[S0, 1, S2], data[NULL], shape[S0, Add(S1, -1), S2], data[NULL]',
'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, 3, S2], data[NULL]',
'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, S1, S2], data[NULL]',
'shape[S0, S1, S2], data[NULL]',
'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, Add(S1, -3), S2], data[NULL]',
'shape[S0, 1, S2], data[NULL], shape[S0, Add(S1, -1), S2], data[NULL]',
'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, 3, S2], data[NULL]',
'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, S1, S2], data[NULL]',
]
]

def test_eval_symbolic(self):
net = SplitNet()

for i in range(len(self.cases)):
x = self.cases[i]
x_spec = InputSpec(
shape=[None for index in range(len(x.shape))], dtype='float32'
)

input_spec = [x_spec]
net = apply_to_static(net, False, input_spec)
net.eval()

# check the infer result
sym_shape_str_list = get_sym_shape_str_for_op(
net, input_spec, 'pd_op.split'
)
np.testing.assert_equal(
len(sym_shape_str_list), len(self.expected[i])
)
for j in range(len(sym_shape_str_list)):
np.testing.assert_equal(
sym_shape_str_list[j].find(self.expected[i][j]),
0,
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}',
)

# TODO(fty1777): Add builtin.split op infer symbolic shape test
# Not added because attribute `sym_shape_str` does not support multi-output op now.
# See also: paddle/fluid/pir/transforms/shape_optimization_pass.cc:144.

return True


if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion test/ir/pir/cinn/symbolic/test_unary_op_infer_sym_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_eval_symbolic(self):
np.testing.assert_equal(
sym_shape_str_list[j].find(self.expected[i][j]),
0,
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}',
f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}',
)

return True
Expand Down