diff --git a/paddle/pir/dialect/shape/utils/dim_expr_builder.cc b/paddle/pir/dialect/shape/utils/dim_expr_builder.cc index ac1b1c651626ea..19aad0807fc67c 100644 --- a/paddle/pir/dialect/shape/utils/dim_expr_builder.cc +++ b/paddle/pir/dialect/shape/utils/dim_expr_builder.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/pir/dialect/shape/utils/dim_expr_builder.h" +#include "paddle/common/enforce.h" namespace symbol { @@ -20,10 +21,10 @@ using BroadcastDimExpr = Broadcast; using MinDimExpr = Min; using MaxDimExpr = Max; -DimExpr DimExprBuilder::ConstSize(std::int64_t dim) { SYMBOL_NOT_IMPLEMENTED; } +DimExpr DimExprBuilder::ConstSize(std::int64_t dim) { return DimExpr(dim); } DimExpr DimExprBuilder::Symbol(const std::string& symbol_name) { - SYMBOL_NOT_IMPLEMENTED; + return DimExpr(symbol_name); } DimExpr DimExprBuilder::Add(const DimExpr& lhs, const DimExpr& rhs) { @@ -56,7 +57,12 @@ DimExpr DimExprBuilder::Broadcast(const DimExpr& lhs, const DimExpr& rhs) { std::vector DimExprBuilder::ConstShape( const std::vector& dims) { - SYMBOL_NOT_IMPLEMENTED; + std::vector ret{}; + ret.reserve(dims.size()); + for (std::int64_t dim : dims) { + ret.emplace_back(dim); + } + return ret; } void DimExprBuilder::CstrBroadcastable(const DimExpr& lhs, const DimExpr& rhs) { @@ -74,21 +80,43 @@ void DimExprBuilder::CstrEq(const DimExpr& lhs, const DimExpr& rhs) { void DimExprBuilder::CstrEq(const std::vector& lhs, const std::vector& rhs) { - SYMBOL_NOT_IMPLEMENTED; + IR_ENFORCE(lhs.size() == rhs.size(), + "Please make sure input sizes are equal, " + "lhs.size() = %d, rhs.size() = %d.", + lhs.size(), + rhs.size()); + for (std::size_t i = 0; i < lhs.size(); ++i) { + CstrEq(lhs.at(i), rhs.at(i)); + } } std::vector DimExprBuilder::Concat(const std::vector& lhs, const std::vector& rhs) { - SYMBOL_NOT_IMPLEMENTED; + std::vector ret{}; + const auto& EmplaceDimExpr = [&](const std::vector& exprs) { + for (const auto& expr : exprs) { + ret.emplace_back(expr); + } + }; + EmplaceDimExpr(lhs); + EmplaceDimExpr(rhs); + return ret; } std::pair, std::vector> DimExprBuilder::SplitAt( - const std::vector, int index) { - SYMBOL_NOT_IMPLEMENTED; + const std::vector dim_exprs, int index) { + IR_ENFORCE(index > 0 && index < static_cast(dim_exprs.size()), + "Index invalid, index = %d, dim_exprs.size() = %d. Please check " + "your inputs.", + index, + dim_exprs.size()); + std::vector lhs(dim_exprs.begin(), dim_exprs.begin() + index); + std::vector rhs(dim_exprs.begin() + index, dim_exprs.end()); + return std::make_pair(lhs, rhs); } const std::vector& DimExprBuilder::constraints() const { - SYMBOL_NOT_IMPLEMENTED; + return *constraints_; } } // namespace symbol diff --git a/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc b/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc index 0026354648bf93..6454cecb3192c4 100644 --- a/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc +++ b/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc @@ -47,6 +47,16 @@ TEST(DimExpr, Constraint) { DimExpr sym1 = DimExpr("S1"); builder.CstrEq(sym0, sym1); ASSERT_EQ(static_cast(constraints.size()), 1); + std::vector lhs = builder.ConstShape({1, 2, 3}); + std::vector rhs = builder.ConstShape({1, 2, 3}); + std::pair, std::vector> expr_pair = + builder.SplitAt(rhs, 1); + ASSERT_EQ(static_cast(expr_pair.first.size()), 1); + ASSERT_EQ(static_cast(expr_pair.second.size()), 2); + std::vector merged = + builder.Concat(expr_pair.first, expr_pair.second); + builder.CstrEq(lhs, merged); + ASSERT_EQ(static_cast(constraints.size()), 4); } /*