diff --git a/paddle/cinn/common/integer_set.cc b/paddle/cinn/common/integer_set.cc index 762c273caef7c5..1887238c2eb4a8 100644 --- a/paddle/cinn/common/integer_set.cc +++ b/paddle/cinn/common/integer_set.cc @@ -13,6 +13,8 @@ // limitations under the License. #include "paddle/cinn/common/integer_set.h" + +#include "paddle/cinn/common/arithmatic.h" #include "paddle/cinn/ir/ir_mutator.h" #include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/ir/utils/ir_copy.h" @@ -164,11 +166,115 @@ std::optional SymbolicExprAnalyzer::ProveLT(const ir::Expr& lhs, return ProveGT(rhs, lhs); } +// Tell whether lhs can be divisible by rhs, lhs must be a pure math expression +// and rhs must be a var +std::optional SymbolicExprAnalyzer::ProveDivisible( + const ir::Expr& lhs, const ir::Expr& rhs) const { + CHECK(rhs.is_var()) << "Rhs in ProveDivisible must be a var temporarily!\n"; + CHECK(lhs.defined()); + CHECK(rhs.defined()); + CHECK(cinn::common::IsPureMath(lhs)); + + ir::Expr lhs_copy = ir::ir_utils::IRCopy(lhs); + if (cinn::common::is_zero(lhs_copy)) return true; + + auto OptionalAnd = [](const std::optional& lhs, + const std::optional& rhs) -> std::optional { + if (lhs.has_value() && rhs.has_value()) { + return lhs.value() && rhs.value(); + } else { + return std::nullopt; + } + }; + auto OptionalOr = [](const std::optional& lhs, + const std::optional& rhs) -> std::optional { + if (lhs.has_value() && rhs.has_value()) { + return lhs.value() || rhs.value(); + } else if ((!lhs.has_value()) && (!rhs.has_value())) { + return std::nullopt; + } else if (lhs.has_value() && (!rhs.has_value())) { + return lhs.value() ? std::optional(lhs.value()) + : std::optional(std::nullopt); + } else { + return rhs.value() ? std::optional(rhs.value()) + : std::optional(std::nullopt); + } + }; + + std::vector ops{}; + std::optional res = std::nullopt; + ir::Expr zero(0); + ir::Expr tmp_expr; + + auto is_ge = ProveGE(lhs, rhs); + + switch (lhs.node_type()) { + case cinn::ir::IrNodeTy::_Var_: + return ProveEQ(lhs, rhs); + case cinn::ir::IrNodeTy::IntImm: + return false; + case cinn::ir::IrNodeTy::Sum: + res = true; + ops = lhs.As()->operands(); + CHECK(!ops.empty()); + std::for_each(ops.begin(), ops.end(), [&](const ir::Expr& expr) { + res = OptionalAnd(res, this->ProveDivisible(expr, rhs)); + }); + res = OptionalAnd(res, is_ge); + return res; + case cinn::ir::IrNodeTy::Product: + res = false; + ops = lhs.As()->operands(); + CHECK(!ops.empty()); + std::for_each(ops.begin(), ops.end(), [&](const ir::Expr& expr) { + res = OptionalOr(res, this->ProveDivisible(expr, rhs)); + if (res.has_value() && res.value()) return; + }); + res = OptionalAnd(res, is_ge); + return res; + case cinn::ir::IrNodeTy::FracOp: + tmp_expr = cinn::common::AutoSimplify(lhs); + if (tmp_expr.node_type() == cinn::ir::IrNodeTy::FracOp) + return std::nullopt; + return OptionalAnd(ProveDivisible(tmp_expr, rhs), is_ge); + case cinn::ir::IrNodeTy::FloatImm: + return false; + case cinn::ir::IrNodeTy::Add: + return OptionalAnd( + OptionalAnd(ProveDivisible(lhs.As()->a(), rhs), + ProveDivisible(lhs.As()->b(), rhs)), + is_ge); + case cinn::ir::IrNodeTy::Sub: + return OptionalAnd( + OptionalAnd(ProveDivisible(lhs.As()->a(), rhs), + ProveDivisible(lhs.As()->b(), rhs)), + is_ge); + case cinn::ir::IrNodeTy::Div: + tmp_expr = cinn::common::AutoSimplify(lhs); + if (tmp_expr.node_type() == cinn::ir::IrNodeTy::Div) return std::nullopt; + return OptionalAnd(ProveDivisible(tmp_expr, rhs), is_ge); + case cinn::ir::IrNodeTy::Mul: + return OptionalAnd( + OptionalOr(ProveDivisible(lhs.As()->a(), rhs), + ProveDivisible(lhs.As()->b(), rhs)), + is_ge); + case cinn::ir::IrNodeTy::Mod: + return false; + case cinn::ir::IrNodeTy::Minus: + return ProveDivisible(lhs.As()->v(), rhs); + default: + LOG(FATAL) << "Not supported yet!"; + break; + } +} + class BoundReplacer : public ir::IRMutator<> { public: explicit BoundReplacer(const cas_intervals_t& var_intervals, bool is_lower_bound) - : var_intervals_(var_intervals), sign_(is_lower_bound) {} + : var_intervals_(var_intervals), + sign_(is_lower_bound), + var_visited_({}) {} void operator()(ir::Expr* expr) { IRMutator::Visit(expr, expr); } @@ -183,10 +289,16 @@ class BoundReplacer : public ir::IRMutator<> { upper_bound = interval.e_r.defined() ? interval.e_r : ir::Expr(interval.r); } - if (sign_) { - *op = ir::ir_utils::IRCopy(lower_bound); + if (!var_visited_.count(var->name)) { + if (sign_) { + *op = ir::ir_utils::IRCopy(lower_bound); + var_visited_.insert({var->name, lower_bound}); + } else { + *op = ir::ir_utils::IRCopy(upper_bound); + var_visited_.insert({var->name, upper_bound}); + } } else { - *op = ir::ir_utils::IRCopy(upper_bound); + *op = ir::ir_utils::IRCopy(var_visited_.at(var->name)); } } @@ -248,6 +360,7 @@ class BoundReplacer : public ir::IRMutator<> { private: const cas_intervals_t& var_intervals_; + std::unordered_map var_visited_; // Determine replacing with upper or lower bound, // True means lower bound and False means upper bound. bool sign_; diff --git a/paddle/cinn/common/integer_set.h b/paddle/cinn/common/integer_set.h index e0f23da2e744f8..6d095b12083f11 100644 --- a/paddle/cinn/common/integer_set.h +++ b/paddle/cinn/common/integer_set.h @@ -41,6 +41,8 @@ class SymbolicExprAnalyzer { std::optional ProveLE(const ir::Expr& lhs, const ir::Expr& rhs) const; std::optional ProveGT(const ir::Expr& lhs, const ir::Expr& rhs) const; std::optional ProveLT(const ir::Expr& lhs, const ir::Expr& rhs) const; + std::optional ProveDivisible(const ir::Expr& lhs, + const ir::Expr& rhs) const; ir::Expr LowerBound(const ir::Expr& expr) const; ir::Expr UpperBound(const ir::Expr& expr) const; diff --git a/paddle/cinn/common/integer_set_test.cc b/paddle/cinn/common/integer_set_test.cc index 23406ec2f770ea..6d57f2dd0ed257 100644 --- a/paddle/cinn/common/integer_set_test.cc +++ b/paddle/cinn/common/integer_set_test.cc @@ -136,6 +136,50 @@ TEST_F(TestSymbolicExprAnalyzer, compare) { analyzer.Prove(e3 < e4).value()); } +TEST_F(TestSymbolicExprAnalyzer, Divisible) { + auto x = ir::Var(ir::Expr(1), ir::Expr(7), "x"); + auto y = ir::Var(ir::Expr(1), ir::Expr(15), "y"); + auto S = ir::Var(ir::Expr(16), ir::Expr(256), "S"); + + cas_intervals_t divisible_var_intervals = { + {"x", CasInterval(x->lower_bound, x->upper_bound)}, + {"y", CasInterval(y->lower_bound, y->upper_bound)}, + {"S", CasInterval(S->lower_bound, S->upper_bound)}, + }; + SymbolicExprAnalyzer divisible_analyzer{divisible_var_intervals}; + + // case 1 + ir::Expr e1 = 4 * x + 2 * y * x; + ir::Expr e2 = x; + ir::Expr e3 = y; + + EXPECT_TRUE(divisible_analyzer.ProveDivisible(e1, e2).value_or(false)); + EXPECT_FALSE(divisible_analyzer.ProveDivisible(e1, e3).value_or(false)); + + // case 2 + ir::Expr e4 = y + y * x + 4 * y - x * y; + + EXPECT_TRUE(divisible_analyzer.ProveDivisible(e4, e3).value_or(false)); + EXPECT_FALSE(divisible_analyzer.ProveDivisible(e4, e2).value_or(false)); + + // case 3 + ir::Expr e5 = x / y + x + y; + + EXPECT_FALSE(divisible_analyzer.ProveDivisible(e5, e3).value_or(false)); + EXPECT_FALSE(divisible_analyzer.ProveDivisible(e5, e2).value_or(false)); + + // case 4 + ir::Expr e6 = S * x / 4 + x * y; + + EXPECT_FALSE(divisible_analyzer.ProveDivisible(e6, e2).value_or(false)); + EXPECT_FALSE(divisible_analyzer.ProveDivisible(e6, e3).value_or(false)); + + ir::Expr e7 = 16 * x / 4 + x * y; + + EXPECT_TRUE(divisible_analyzer.ProveDivisible(e7, e2).value_or(false)); + EXPECT_FALSE(divisible_analyzer.ProveDivisible(e7, e3).value_or(false)); +} + TEST(SingleIntervalIntSet, constant) { SingleIntervalIntSet empty_set(ir::Expr(0), ir::Expr(-1)); SingleIntervalIntSet all_set(SymbolicExprLimit::negative_inf,