Skip to content

Commit ee3d2fc

Browse files
authored
Add CanProveDivisible for symbolic calculation (#60572)
* add CanProveDivisible for symbolic calculation * delete extra cout for debug * fix according to some comments
1 parent ed6f32d commit ee3d2fc

File tree

3 files changed

+163
-4
lines changed

3 files changed

+163
-4
lines changed

paddle/cinn/common/integer_set.cc

Lines changed: 117 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
// limitations under the License.
1414

1515
#include "paddle/cinn/common/integer_set.h"
16+
17+
#include "paddle/cinn/common/arithmatic.h"
1618
#include "paddle/cinn/ir/ir_mutator.h"
1719
#include "paddle/cinn/ir/op/ir_operators.h"
1820
#include "paddle/cinn/ir/utils/ir_copy.h"
@@ -164,11 +166,115 @@ std::optional<bool> SymbolicExprAnalyzer::ProveLT(const ir::Expr& lhs,
164166
return ProveGT(rhs, lhs);
165167
}
166168

169+
// Tell whether lhs can be divisible by rhs, lhs must be a pure math expression
170+
// and rhs must be a var
171+
std::optional<bool> SymbolicExprAnalyzer::ProveDivisible(
172+
const ir::Expr& lhs, const ir::Expr& rhs) const {
173+
CHECK(rhs.is_var()) << "Rhs in ProveDivisible must be a var temporarily!\n";
174+
CHECK(lhs.defined());
175+
CHECK(rhs.defined());
176+
CHECK(cinn::common::IsPureMath(lhs));
177+
178+
ir::Expr lhs_copy = ir::ir_utils::IRCopy(lhs);
179+
if (cinn::common::is_zero(lhs_copy)) return true;
180+
181+
auto OptionalAnd = [](const std::optional<bool>& lhs,
182+
const std::optional<bool>& rhs) -> std::optional<bool> {
183+
if (lhs.has_value() && rhs.has_value()) {
184+
return lhs.value() && rhs.value();
185+
} else {
186+
return std::nullopt;
187+
}
188+
};
189+
auto OptionalOr = [](const std::optional<bool>& lhs,
190+
const std::optional<bool>& rhs) -> std::optional<bool> {
191+
if (lhs.has_value() && rhs.has_value()) {
192+
return lhs.value() || rhs.value();
193+
} else if ((!lhs.has_value()) && (!rhs.has_value())) {
194+
return std::nullopt;
195+
} else if (lhs.has_value() && (!rhs.has_value())) {
196+
return lhs.value() ? std::optional<bool>(lhs.value())
197+
: std::optional<bool>(std::nullopt);
198+
} else {
199+
return rhs.value() ? std::optional<bool>(rhs.value())
200+
: std::optional<bool>(std::nullopt);
201+
}
202+
};
203+
204+
std::vector<ir::Expr> ops{};
205+
std::optional<bool> res = std::nullopt;
206+
ir::Expr zero(0);
207+
ir::Expr tmp_expr;
208+
209+
auto is_ge = ProveGE(lhs, rhs);
210+
211+
switch (lhs.node_type()) {
212+
case cinn::ir::IrNodeTy::_Var_:
213+
return ProveEQ(lhs, rhs);
214+
case cinn::ir::IrNodeTy::IntImm:
215+
return false;
216+
case cinn::ir::IrNodeTy::Sum:
217+
res = true;
218+
ops = lhs.As<ir::Sum>()->operands();
219+
CHECK(!ops.empty());
220+
std::for_each(ops.begin(), ops.end(), [&](const ir::Expr& expr) {
221+
res = OptionalAnd(res, this->ProveDivisible(expr, rhs));
222+
});
223+
res = OptionalAnd(res, is_ge);
224+
return res;
225+
case cinn::ir::IrNodeTy::Product:
226+
res = false;
227+
ops = lhs.As<ir::Product>()->operands();
228+
CHECK(!ops.empty());
229+
std::for_each(ops.begin(), ops.end(), [&](const ir::Expr& expr) {
230+
res = OptionalOr(res, this->ProveDivisible(expr, rhs));
231+
if (res.has_value() && res.value()) return;
232+
});
233+
res = OptionalAnd(res, is_ge);
234+
return res;
235+
case cinn::ir::IrNodeTy::FracOp:
236+
tmp_expr = cinn::common::AutoSimplify(lhs);
237+
if (tmp_expr.node_type() == cinn::ir::IrNodeTy::FracOp)
238+
return std::nullopt;
239+
return OptionalAnd(ProveDivisible(tmp_expr, rhs), is_ge);
240+
case cinn::ir::IrNodeTy::FloatImm:
241+
return false;
242+
case cinn::ir::IrNodeTy::Add:
243+
return OptionalAnd(
244+
OptionalAnd(ProveDivisible(lhs.As<ir::Add>()->a(), rhs),
245+
ProveDivisible(lhs.As<ir::Add>()->b(), rhs)),
246+
is_ge);
247+
case cinn::ir::IrNodeTy::Sub:
248+
return OptionalAnd(
249+
OptionalAnd(ProveDivisible(lhs.As<ir::Sub>()->a(), rhs),
250+
ProveDivisible(lhs.As<ir::Sub>()->b(), rhs)),
251+
is_ge);
252+
case cinn::ir::IrNodeTy::Div:
253+
tmp_expr = cinn::common::AutoSimplify(lhs);
254+
if (tmp_expr.node_type() == cinn::ir::IrNodeTy::Div) return std::nullopt;
255+
return OptionalAnd(ProveDivisible(tmp_expr, rhs), is_ge);
256+
case cinn::ir::IrNodeTy::Mul:
257+
return OptionalAnd(
258+
OptionalOr(ProveDivisible(lhs.As<ir::Mul>()->a(), rhs),
259+
ProveDivisible(lhs.As<ir::Mul>()->b(), rhs)),
260+
is_ge);
261+
case cinn::ir::IrNodeTy::Mod:
262+
return false;
263+
case cinn::ir::IrNodeTy::Minus:
264+
return ProveDivisible(lhs.As<ir::Minus>()->v(), rhs);
265+
default:
266+
LOG(FATAL) << "Not supported yet!";
267+
break;
268+
}
269+
}
270+
167271
class BoundReplacer : public ir::IRMutator<> {
168272
public:
169273
explicit BoundReplacer(const cas_intervals_t& var_intervals,
170274
bool is_lower_bound)
171-
: var_intervals_(var_intervals), sign_(is_lower_bound) {}
275+
: var_intervals_(var_intervals),
276+
sign_(is_lower_bound),
277+
var_visited_({}) {}
172278

173279
void operator()(ir::Expr* expr) { IRMutator::Visit(expr, expr); }
174280

@@ -183,10 +289,16 @@ class BoundReplacer : public ir::IRMutator<> {
183289
upper_bound =
184290
interval.e_r.defined() ? interval.e_r : ir::Expr(interval.r);
185291
}
186-
if (sign_) {
187-
*op = ir::ir_utils::IRCopy(lower_bound);
292+
if (!var_visited_.count(var->name)) {
293+
if (sign_) {
294+
*op = ir::ir_utils::IRCopy(lower_bound);
295+
var_visited_.insert({var->name, lower_bound});
296+
} else {
297+
*op = ir::ir_utils::IRCopy(upper_bound);
298+
var_visited_.insert({var->name, upper_bound});
299+
}
188300
} else {
189-
*op = ir::ir_utils::IRCopy(upper_bound);
301+
*op = ir::ir_utils::IRCopy(var_visited_.at(var->name));
190302
}
191303
}
192304

@@ -248,6 +360,7 @@ class BoundReplacer : public ir::IRMutator<> {
248360

249361
private:
250362
const cas_intervals_t& var_intervals_;
363+
std::unordered_map<std::string, ir::Expr> var_visited_;
251364
// Determine replacing with upper or lower bound,
252365
// True means lower bound and False means upper bound.
253366
bool sign_;

paddle/cinn/common/integer_set.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ class SymbolicExprAnalyzer {
4141
std::optional<bool> ProveLE(const ir::Expr& lhs, const ir::Expr& rhs) const;
4242
std::optional<bool> ProveGT(const ir::Expr& lhs, const ir::Expr& rhs) const;
4343
std::optional<bool> ProveLT(const ir::Expr& lhs, const ir::Expr& rhs) const;
44+
std::optional<bool> ProveDivisible(const ir::Expr& lhs,
45+
const ir::Expr& rhs) const;
4446

4547
ir::Expr LowerBound(const ir::Expr& expr) const;
4648
ir::Expr UpperBound(const ir::Expr& expr) const;

paddle/cinn/common/integer_set_test.cc

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,50 @@ TEST_F(TestSymbolicExprAnalyzer, compare) {
136136
analyzer.Prove(e3 < e4).value());
137137
}
138138

139+
TEST_F(TestSymbolicExprAnalyzer, Divisible) {
140+
auto x = ir::Var(ir::Expr(1), ir::Expr(7), "x");
141+
auto y = ir::Var(ir::Expr(1), ir::Expr(15), "y");
142+
auto S = ir::Var(ir::Expr(16), ir::Expr(256), "S");
143+
144+
cas_intervals_t divisible_var_intervals = {
145+
{"x", CasInterval(x->lower_bound, x->upper_bound)},
146+
{"y", CasInterval(y->lower_bound, y->upper_bound)},
147+
{"S", CasInterval(S->lower_bound, S->upper_bound)},
148+
};
149+
SymbolicExprAnalyzer divisible_analyzer{divisible_var_intervals};
150+
151+
// case 1
152+
ir::Expr e1 = 4 * x + 2 * y * x;
153+
ir::Expr e2 = x;
154+
ir::Expr e3 = y;
155+
156+
EXPECT_TRUE(divisible_analyzer.ProveDivisible(e1, e2).value_or(false));
157+
EXPECT_FALSE(divisible_analyzer.ProveDivisible(e1, e3).value_or(false));
158+
159+
// case 2
160+
ir::Expr e4 = y + y * x + 4 * y - x * y;
161+
162+
EXPECT_TRUE(divisible_analyzer.ProveDivisible(e4, e3).value_or(false));
163+
EXPECT_FALSE(divisible_analyzer.ProveDivisible(e4, e2).value_or(false));
164+
165+
// case 3
166+
ir::Expr e5 = x / y + x + y;
167+
168+
EXPECT_FALSE(divisible_analyzer.ProveDivisible(e5, e3).value_or(false));
169+
EXPECT_FALSE(divisible_analyzer.ProveDivisible(e5, e2).value_or(false));
170+
171+
// case 4
172+
ir::Expr e6 = S * x / 4 + x * y;
173+
174+
EXPECT_FALSE(divisible_analyzer.ProveDivisible(e6, e2).value_or(false));
175+
EXPECT_FALSE(divisible_analyzer.ProveDivisible(e6, e3).value_or(false));
176+
177+
ir::Expr e7 = 16 * x / 4 + x * y;
178+
179+
EXPECT_TRUE(divisible_analyzer.ProveDivisible(e7, e2).value_or(false));
180+
EXPECT_FALSE(divisible_analyzer.ProveDivisible(e7, e3).value_or(false));
181+
}
182+
139183
TEST(SingleIntervalIntSet, constant) {
140184
SingleIntervalIntSet empty_set(ir::Expr(0), ir::Expr(-1));
141185
SingleIntervalIntSet all_set(SymbolicExprLimit::negative_inf,

0 commit comments

Comments
 (0)