From 566dff648a14b86f71cf2b51db84558b7c6d2d9f Mon Sep 17 00:00:00 2001 From: lanxianghit Date: Sun, 4 Feb 2024 08:35:04 +0000 Subject: [PATCH 1/4] Add InferSymbolicShape for matmul, max --- .../interface/infer_symbolic_shape.cc | 112 +++++++++++++++++- .../cinn/symbolic/test_op_infer_sym_shape.py | 46 +++++++ 2 files changed, 153 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc index 97532183f87a3d..83739164a464d7 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc @@ -20,6 +20,11 @@ #include "paddle/pir/core/builtin_type_interfaces.h" #include "paddle/pir/dialect/shape/ir/shape_attribute.h" +// to make codes shorter +using ShapeOrData = symbol::ShapeOrDataDimExprs; +using TensorExprs = symbol::TensorShapeOrDataDimExprs; +using TensorListExprs = symbol::TensorListShapeOrDataDimExprs; + template struct AttributeTrait; @@ -1108,16 +1113,113 @@ bool ExpandOpInferSymbolicShape( bool MatmulOpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " DOES NOT have InferSymbolicShapeInterface!")); + // x_dims can't be const or ref here, in case to be broadcasted + std::vector x_dims = [&] { + std::vector dims; + const auto &x_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + if (x_shape_or_data.data().has_value()) { + dims = x_shape_or_data.data().value(); + } else { + dims = x_shape_or_data.shape(); + } + return dims; + }(); + + // y_dims can't be const or ref here, in case to be broadcasted + std::vector y_dims = [&] { + std::vector dims; + const auto y_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + if (y_shape_or_data.data().has_value()) { + dims = y_shape_or_data.data().value(); + } else { + dims = y_shape_or_data.shape(); + } + return dims; + }(); + + size_t ndims_x = x_dims.size(); + size_t ndims_y = y_dims.size(); + + const bool x_broadcasted = [&] { + bool broadcasted = false; + if (ndims_x == 1) { + x_dims.insert(x_dims.begin(), 1); + ndims_x = 2; + broadcasted = true; + } + return broadcasted; + }(); + + const bool y_broadcasted = [&] { + bool broadcasted = false; + if (ndims_y == 1) { + y_dims.emplace_back(1); + ndims_x = 2; + broadcasted = true; + } + return broadcasted; + }(); + + std::vector out_dims; + if (ndims_x > ndims_y) { + out_dims.assign(x_dims.begin(), x_dims.end() - 2); + } else if (ndims_x < ndims_y) { + out_dims.assign(y_dims.begin(), y_dims.end() - 2); + } else { + symbol::DimExprBuilder builder{nullptr}; + for (size_t i = 0; i < ndims_x - 2; ++i) { + out_dims.emplace_back(builder.Broadcast(x_dims[i], y_dims[i])); + } + } + + symbol::DimExpr out_M = + op->attributes().at("transpose_x").dyn_cast().data() + ? x_dims[ndims_x - 1] + : x_dims[ndims_x - 2]; + symbol::DimExpr out_N = + op->attributes().at("transpose_y").dyn_cast().data() + ? y_dims[ndims_y - 2] + : y_dims[ndims_y - 1]; + if (!x_broadcasted) { + out_dims.emplace_back(out_M); + } + if (!y_broadcasted) { + out_dims.emplace_back(out_N); + } + + shape_analysis->SetShapeOrDataForValue(op->result(0), + ShapeOrData{TensorExprs(out_dims)}); + return true; } bool MaxOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " DOES NOT have InferSymbolicShapeInterface!")); - return true; + bool keepdim = + op->attributes().at("keepdim").dyn_cast().data(); + + const std::vector axis = [&] { + pir::Operation *axis_gen_op = op->operand_source(1).defining_op(); + std::vector axis_vec; + if (axis_gen_op->isa()) { + axis_vec = GetVectorAttr( + axis_gen_op->dyn_cast(), "value"); + } else { + // TODO(lanxianghit): there's other source: pir::VectorType, + // paddle::dialect::DenseTensorType, but after PRIM, maybe always + // FullIntArrayOp, to be confirmed + PADDLE_THROW( + phi::errors::Unimplemented("MaxOpInferSymbolicShape: 'axis' only " + "support FullIntArrayOp's result now.")); + } + return axis_vec; + }(); + + bool reduce_all = axis.size() == 0 ? true : false; + + return ReduceInferDim(op, shape_analysis, axis, keepdim, reduce_all); } bool TrilOpInferSymbolicShape(pir::Operation *op, diff --git a/test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py b/test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py index d09ba04ff65768..44129d6ec7696b 100644 --- a/test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py +++ b/test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py @@ -202,5 +202,51 @@ def test_eval_symbolic(self): return out +class MatmulNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, x, y): + out = paddle.matmul(x, y) + + return out + + +class TestMatmulOpInferSymbolicShape(TestBase): + def prepare_data(self): + self.x = paddle.rand([1, 3], 'float32') + self.y = paddle.rand([3, 2], 'float32') + + self.expected_sym_shapes = [ + 'shape[S3, S2], data[NULL]', + ] + + def test_eval_symbolic(self): + net = MatmulNet() + + input_spec = [ + InputSpec(shape=[None, None], dtype='float32'), + InputSpec(shape=[None, None], dtype='float32'), + ] + net = apply_to_static(net, False, input_spec) + net.eval() + + # check the infer result + sym_shape_str_list = get_sym_shape_str_for_op( + net, input_spec, 'pd_op.matmul' + ) + np.testing.assert_equal( + len(sym_shape_str_list), len(self.expected_sym_shapes) + ) + for i in range(len(self.expected_sym_shapes)): + np.testing.assert_string_equal( + sym_shape_str_list[i], + self.expected_sym_shapes[i], + 'output shape is not expected!', + ) + out = net(self.x, self.y) + return out + + if __name__ == '__main__': unittest.main() From 28166c997224b5f8e0c92fbf5bb1ac56f68faad6 Mon Sep 17 00:00:00 2001 From: lanxianghit Date: Sun, 4 Feb 2024 08:39:52 +0000 Subject: [PATCH 2/4] rm unused codes --- .../interface/infer_symbolic_shape.cc | 68 ++----------------- 1 file changed, 4 insertions(+), 64 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc index 83739164a464d7..9c1a9a748bc555 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc @@ -83,9 +83,6 @@ bool SameOperandsAndResultShape( symbol::ShapeOrDataDimExprs operand_shape_or_data = shape_analysis->GetShapeOrDataForValue(operand_source); - op->set_attribute("symbolic_shape", - pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), - operand_shape_or_data)); pir::Value res = op->result(0); shape_analysis->SetShapeOrDataForValue(res, operand_shape_or_data); return true; @@ -148,9 +145,7 @@ bool InferSymbolicShapeElementWiseBinary( symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(shapes)}; shape_analysis->SetShapeOrDataForValue(res, shape_data); - op->set_attribute( - "symbolic_shape", - pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); + return true; } @@ -189,9 +184,6 @@ bool DataOpInferSymbolicShape(pir::Operation *op, symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(sym_dims)}; - op->set_attribute( - "symbolic_shape", - pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); pir::Value res = op->result(0); shape_analysis->SetShapeOrDataForValue(res, shape_data); @@ -268,9 +260,7 @@ bool ShapeOpInferSymbolicShape(pir::Operation *op, sym_shape, operand_shape_or_data.shape())}; shape_analysis->SetShapeOrDataForValue(res, shape_or_data); - op->set_attribute("symbolic_shape", - pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), - shape_or_data)); + return true; } @@ -310,9 +300,6 @@ bool StackOpInferSymbolicShape(pir::Operation *op, symbol::ShapeOrDataDimExprs shape_data( symbol::TensorShapeOrDataDimExprs(out_dims, out_dims_data)); - op->set_attribute( - "symbolic_shape", - pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); pir::Value res = op->result(0); shape_analysis->SetShapeOrDataForValue(res, shape_data); return true; @@ -373,9 +360,7 @@ bool ReduceInferDim(pir::Operation *op, pir::Value res = op->result(0); symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(shapes)}; - op->set_attribute( - "symbolic_shape", - pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); + shape_analysis->SetShapeOrDataForValue(res, shape_data); return true; } @@ -446,9 +431,6 @@ bool ReshapeOpInferSymbolicShape( symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(out_dims)}; - op->set_attribute( - "symbolic_shape", - pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); pir::Value res0 = op->result(0); pir::Value res1 = op->result(1); @@ -481,10 +463,6 @@ bool FullIntArrayOpInferSymbolicShape( symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(shape, data)}; - op->set_attribute( - "symbolic_shape", - pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); - pir::Value res = op->result(0); shape_analysis->SetShapeOrDataForValue(res, shape_data); return true; @@ -542,10 +520,6 @@ bool SliceOpInferSymbolicShape(pir::Operation *op, symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(sym_shape, out_data)}; - op->set_attribute( - "symbolic_shape", - pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); - shape_analysis->SetShapeOrDataForValue(res, shape_data); return true; } @@ -585,10 +559,6 @@ bool FullOpInferSymbolicShape(pir::Operation *op, symbol::TensorShapeOrDataDimExprs(sym_shape)}; shape_data.SetData(sym_data); - op->set_attribute( - "symbolic_shape", - pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); - pir::Value res = op->result(0); shape_analysis->SetShapeOrDataForValue(res, shape_data); return true; @@ -634,10 +604,6 @@ bool ConcatOpInferSymbolicShape( symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(out_dims)}; - op->set_attribute( - "symbolic_shape", - pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); - pir::Value res = op->result(0); shape_analysis->SetShapeOrDataForValue(res, shape_data); @@ -690,10 +656,6 @@ bool GatherNdOpInferSymbolicShape( symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(result_sym_dims)}; - op->set_attribute( - "symbolic_shape", - pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); - pir::Value res = op->result(0); shape_analysis->SetShapeOrDataForValue(res, shape_data); @@ -817,10 +779,6 @@ bool SqueezeOpInferSymbolicShape( symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(output_shape_sym)}; - op->set_attribute( - "symbolic_shape", - pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); - pir::Value res = op->result(0); shape_analysis->SetShapeOrDataForValue(res, shape_data); @@ -896,10 +854,6 @@ bool UnsqueezeOpInferSymbolicShape( symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(result_sym_dims)}; - op->set_attribute( - "symbolic_shape", - pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); - pir::Value res = op->result(0); shape_analysis->SetShapeOrDataForValue(res, shape_data); @@ -954,10 +908,6 @@ bool TileOpInferSymbolicShape(pir::Operation *op, symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(out_shape)}; - op->set_attribute( - "symbolic_shape", - pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); - pir::Value res = op->result(0); shape_analysis->SetShapeOrDataForValue(res, shape_data); @@ -1291,10 +1241,6 @@ bool SliceOpInferSymbolicShape(pir::Operation *op, }; symbol::ShapeOrDataDimExprs shape_data{GetOutDimExprs()}; - op->set_attribute( - "symbolic_shape", - pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); return true; } @@ -1341,10 +1287,6 @@ bool ConcatOpInferSymbolicShape( symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(GetOutDimExprs())}; - op->set_attribute( - "symbolic_shape", - pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); return true; } @@ -1394,9 +1336,7 @@ bool ReshapeOpInferSymbolicShape( symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(out_dims)}; shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); - op->set_attribute( - "symbolic_shape", - pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); + return true; } From 75436a10abd5d65e16380fb8073f8bcc219c8899 Mon Sep 17 00:00:00 2001 From: lanxianghit Date: Sun, 4 Feb 2024 08:42:57 +0000 Subject: [PATCH 3/4] make err msg more clear --- .../operator/interface/infer_symbolic_shape.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc index 9c1a9a748bc555..b981b2c69d4355 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc @@ -669,7 +669,7 @@ bool PowOpInferSymbolicShape(pir::Operation *op, bool Pow_OpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " DOES NOT have InferSymbolicShapeInterface!")); + op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); return PowOpInferSymbolicShape(op, shape_analysis); } @@ -917,7 +917,7 @@ bool TileOpInferSymbolicShape(pir::Operation *op, bool TransposeOpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " DOES NOT have InferSymbolicShapeInterface!")); + op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); return true; } bool Transpose_OpInferSymbolicShape( @@ -1050,14 +1050,14 @@ bool EmbeddingOpInferSymbolicShape( bool SparseWeightEmbeddingOpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " DOES NOT have InferSymbolicShapeInterface!")); + op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); return true; } bool ExpandOpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " DOES NOT have InferSymbolicShapeInterface!")); + op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); return true; } @@ -1175,7 +1175,7 @@ bool MaxOpInferSymbolicShape(pir::Operation *op, bool TrilOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " DOES NOT have InferSymbolicShapeInterface!")); + op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); return true; } @@ -1187,7 +1187,7 @@ bool Tril_OpInferSymbolicShape(pir::Operation *op, bool WhereOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " DOES NOT have InferSymbolicShapeInterface!")); + op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); return true; } From a7f55c34deded3c468956e37344cd1614d355c89 Mon Sep 17 00:00:00 2001 From: lanxianghit Date: Sun, 4 Feb 2024 12:19:01 +0000 Subject: [PATCH 4/4] bug fix --- .../interface/infer_symbolic_shape.cc | 2 +- .../cinn/symbolic/test_op_infer_sym_shape.py | 150 +++++++++++++++--- 2 files changed, 126 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc index b981b2c69d4355..469db025ce6505 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc @@ -1106,7 +1106,7 @@ bool MatmulOpInferSymbolicShape( bool broadcasted = false; if (ndims_y == 1) { y_dims.emplace_back(1); - ndims_x = 2; + ndims_y = 2; broadcasted = true; } return broadcasted; diff --git a/test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py b/test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py index 44129d6ec7696b..e20f64b5ee5083 100644 --- a/test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py +++ b/test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py @@ -206,47 +206,147 @@ class MatmulNet(paddle.nn.Layer): def __init__(self): super().__init__() - def forward(self, x, y): - out = paddle.matmul(x, y) + def forward(self, x, y, trans_x, trans_y): + out = paddle.matmul(x, y, trans_x, trans_y) return out class TestMatmulOpInferSymbolicShape(TestBase): def prepare_data(self): - self.x = paddle.rand([1, 3], 'float32') - self.y = paddle.rand([3, 2], 'float32') + self.cases = [ + # [x, y, trans_x, trans_y] + [np.random.rand(1, 3), np.random.rand(3, 2), False, False], + # with broadcast + [np.random.rand(10), np.random.rand(10), False, False], # [] + [np.random.rand(10, 5), np.random.rand(5), False, False], # [10] + [ + np.random.rand(10, 5, 2), + np.random.rand(2), + False, + False, + ], # [10, 5] + [ + np.random.rand(10, 5, 2), + np.random.rand(10, 2, 5), + False, + False, + ], # [10, 5, 5] + [ + np.random.rand(10, 1, 5, 2), + np.random.rand(1, 3, 2, 5), + False, + False, + ], # [10, 3, 5, 5] + # with transpose + [np.random.rand(3, 5), np.random.rand(3, 2), True, False], # [5, 2] + [np.random.rand(3, 5), np.random.rand(4, 5), False, True], # [3, 4] + ] - self.expected_sym_shapes = [ - 'shape[S3, S2], data[NULL]', + self.expected = [ + 'shape[S0, S3], data[NULL]', + # with broadcast + 'shape[], data[NULL]', + 'shape[S0], data[NULL]', + 'shape[S0, S1], data[NULL]', + 'shape[Broadcast(S0, S3), S1, S5], data[NULL]', + 'shape[Broadcast(S0, S4), Broadcast(S1, S5), S2, S7], data[NULL]', + # with transpose + 'shape[S1, S3], data[NULL]', + 'shape[S0, S2], data[NULL]', ] def test_eval_symbolic(self): net = MatmulNet() - input_spec = [ - InputSpec(shape=[None, None], dtype='float32'), - InputSpec(shape=[None, None], dtype='float32'), - ] - net = apply_to_static(net, False, input_spec) - net.eval() + for i in range(len(self.cases)): + x, y, trans_x, trans_y = self.cases[i] + x_spec = InputSpec( + shape=[None for index in range(len(x.shape))], dtype='float32' + ) + y_spec = InputSpec( + shape=[None for index in range(len(y.shape))], dtype='float32' + ) - # check the infer result - sym_shape_str_list = get_sym_shape_str_for_op( - net, input_spec, 'pd_op.matmul' - ) - np.testing.assert_equal( - len(sym_shape_str_list), len(self.expected_sym_shapes) - ) - for i in range(len(self.expected_sym_shapes)): - np.testing.assert_string_equal( - sym_shape_str_list[i], - self.expected_sym_shapes[i], - 'output shape is not expected!', + input_spec = [x_spec, y_spec, trans_x, trans_y] + net = apply_to_static(net, False, input_spec) + net.eval() + + # check the infer result + sym_shape_str_list = get_sym_shape_str_for_op( + net, input_spec, 'pd_op.matmul' ) - out = net(self.x, self.y) + np.testing.assert_equal(len(sym_shape_str_list), 1) + np.testing.assert_equal( + sym_shape_str_list[0].find(self.expected[i]), + 0, + f'in case i = {i}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i])}', + ) + + return True + + +class MaxNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, x): + out = paddle.max(x) + out = paddle.max(x, 0) + out = paddle.max(x, 1) + out = paddle.max(x, -1) + out = paddle.max(x, -2) + + # keepdim=True + # out = paddle.max(x, 0, True) + return out +class TestMaxOpInferSymbolicShape(TestBase): + def prepare_data(self): + self.cases = [np.random.rand(2, 4)] + + self.expected = [ + [ + 'shape[], data[NULL]', + 'shape[S1], data[NULL]', + 'shape[S0], data[NULL]', + 'shape[S0], data[NULL]', + 'shape[S1], data[NULL]', + # 'shape[1, S1], data[NULL]', + ] + ] + + def test_eval_symbolic(self): + net = MaxNet() + + for i in range(len(self.cases)): + x = self.cases[i] + x_spec = InputSpec( + shape=[None for index in range(len(x.shape))], dtype='float32' + ) + + input_spec = [x_spec] + net = apply_to_static(net, False, input_spec) + net.eval() + + # check the infer result + sym_shape_str_list = get_sym_shape_str_for_op( + net, input_spec, 'pd_op.max' + ) + np.testing.assert_equal( + len(sym_shape_str_list), len(self.expected[i]) + ) + for j in range(len(sym_shape_str_list)): + np.testing.assert_equal( + sym_shape_str_list[j].find(self.expected[i][j]), + 0, + f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i][j])}', + ) + + return True + + if __name__ == '__main__': unittest.main()