diff --git a/test/cpp/pir/shape_dialect/CMakeLists.txt b/test/cpp/pir/shape_dialect/CMakeLists.txt index 6aecc07570061e..e6786f9e1083ce 100644 --- a/test/cpp/pir/shape_dialect/CMakeLists.txt +++ b/test/cpp/pir/shape_dialect/CMakeLists.txt @@ -16,10 +16,7 @@ paddle_test( pir gtest) -cc_test( - symbol_dim_expr_test - SRCS symbol_dim_expr_test.cc - DEPS op_dialect_vjp pir gtest) +paddle_test(symbol_dim_expr_test SRCS symbol_dim_expr_test.cc DEPS pir gtest) if(WITH_CINN) paddle_test( diff --git a/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc b/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc index 27d71345cc8f75..ef5fe03069e4ab 100644 --- a/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc +++ b/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc @@ -56,47 +56,18 @@ TEST(DimExpr, constraint) { out = pd.reshape(y, extend_x) */ TEST(DimExpr, data_shape_expr) { - // 1. Init pir::program and pir::builder - ::pir::IrContext* ctx = ::pir::IrContext::Instance(); - ::pir::Program program(ctx); - ctx->GetOrRegisterDialect(); - ::pir::Builder builder = ::pir::Builder(ctx, program.block()); - - // 2. Show fake network, assume calling x.shape correspond to ShapeOp - const std::vector x_shape = {-1, 2}; - const std::vector y_shape = {1, -1, 2}; - auto x = builder - .Build( - "input_x", x_shape, phi::DataType::FLOAT32, phi::GPUPlace()) - .result(0); - auto y = builder - .Build( - "input_y", y_shape, phi::DataType::FLOAT32, phi::GPUPlace()) - .result(0); - - auto shape_op = builder.Build(x); - ::pir::Value extend_x = shape_op.out(); - paddle::dialect::ReshapeOp reshape_op = - builder.Build(y, extend_x); - ::pir::Value out = reshape_op.out(); - - // 3. Show ideal ShapeOrDataDimExprs of each pir::Value - std::unordered_map value2shape{}; + // Show ideal ShapeOrDataDimExprs of each pir::Value std::vector x_shapes{DimExpr("S0"), DimExpr(2)}; std::vector y_shapes{DimExpr(1), DimExpr("S1"), DimExpr(2)}; // x => {shape: [S0, 2], data: nullopt} - ShapeOrDataDimExprs x_value_shape{x_shapes}; - value2shape.emplace(x, x_value_shape); + ShapeOrDataDimExprs x_data_shape{x_shapes}; // y => {shape: [1, S1, 2], data: nullopt} - ShapeOrDataDimExprs y_value_shape{y_shapes}; - value2shape.emplace(y, y_value_shape); + ShapeOrDataDimExprs y_data_shape{y_shapes}; // extend_x => {shape: [2], data: [S0, 2]} - ShapeOrDataDimExprs extend_x_value_shape = + ShapeOrDataDimExprs extend_x_data_shape = ShapeOrDataDimExprs::MakeConsistentShapeOrData(x_shapes); - value2shape.emplace(extend_x, extend_x_value_shape); // out => {shape: [S0, 2], data: nullopt} ShapeOrDataDimExprs out_value_shape{x_shapes}; - value2shape.emplace(out, out_value_shape); } TEST(Simplify, NumberArithmetic) {