Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions test/cpp/pir/shape_dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@ paddle_test(
pir
gtest)

cc_test(
symbol_dim_expr_test
SRCS symbol_dim_expr_test.cc
DEPS op_dialect_vjp pir gtest)
paddle_test(symbol_dim_expr_test SRCS symbol_dim_expr_test.cc DEPS pir gtest)

if(WITH_CINN)
paddle_test(
Expand Down
37 changes: 4 additions & 33 deletions test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,47 +56,18 @@ TEST(DimExpr, constraint) {
out = pd.reshape(y, extend_x)
*/
TEST(DimExpr, data_shape_expr) {
// 1. Init pir::program and pir::builder
::pir::IrContext* ctx = ::pir::IrContext::Instance();
::pir::Program program(ctx);
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
::pir::Builder builder = ::pir::Builder(ctx, program.block());

// 2. Show fake network, assume calling x.shape correspond to ShapeOp
const std::vector<int64_t> x_shape = {-1, 2};
const std::vector<int64_t> y_shape = {1, -1, 2};
auto x = builder
.Build<paddle::dialect::DataOp>(
"input_x", x_shape, phi::DataType::FLOAT32, phi::GPUPlace())
.result(0);
auto y = builder
.Build<paddle::dialect::DataOp>(
"input_y", y_shape, phi::DataType::FLOAT32, phi::GPUPlace())
.result(0);

auto shape_op = builder.Build<paddle::dialect::ShapeOp>(x);
::pir::Value extend_x = shape_op.out();
paddle::dialect::ReshapeOp reshape_op =
builder.Build<paddle::dialect::ReshapeOp>(y, extend_x);
::pir::Value out = reshape_op.out();
Comment on lines -59 to -81
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

暂时删掉这部分代码,这部分代码会引入编译符号问题。剩余 code 已经可以展示 ShapeOrData 的用法了


// 3. Show ideal ShapeOrDataDimExprs of each pir::Value
std::unordered_map<pir::Value, ShapeOrDataDimExprs> value2shape{};
// Show ideal ShapeOrDataDimExprs of each pir::Value
std::vector<DimExpr> x_shapes{DimExpr("S0"), DimExpr(2)};
std::vector<DimExpr> y_shapes{DimExpr(1), DimExpr("S1"), DimExpr(2)};
// x => {shape: [S0, 2], data: nullopt}
ShapeOrDataDimExprs x_value_shape{x_shapes};
value2shape.emplace(x, x_value_shape);
ShapeOrDataDimExprs x_data_shape{x_shapes};
// y => {shape: [1, S1, 2], data: nullopt}
ShapeOrDataDimExprs y_value_shape{y_shapes};
value2shape.emplace(y, y_value_shape);
ShapeOrDataDimExprs y_data_shape{y_shapes};
// extend_x => {shape: [2], data: [S0, 2]}
ShapeOrDataDimExprs extend_x_value_shape =
ShapeOrDataDimExprs extend_x_data_shape =
ShapeOrDataDimExprs::MakeConsistentShapeOrData(x_shapes);
value2shape.emplace(extend_x, extend_x_value_shape);
// out => {shape: [S0, 2], data: nullopt}
ShapeOrDataDimExprs out_value_shape{x_shapes};
value2shape.emplace(out, out_value_shape);
}

TEST(Simplify, NumberArithmetic) {
Expand Down