1313// limitations under the License.
1414
1515#include " paddle/cinn/common/integer_set.h"
16+
17+ #include " paddle/cinn/common/arithmatic.h"
1618#include " paddle/cinn/ir/ir_mutator.h"
1719#include " paddle/cinn/ir/op/ir_operators.h"
1820#include " paddle/cinn/ir/utils/ir_copy.h"
@@ -164,11 +166,115 @@ std::optional<bool> SymbolicExprAnalyzer::ProveLT(const ir::Expr& lhs,
164166 return ProveGT (rhs, lhs);
165167}
166168
169+ // Tell whether lhs can be divisible by rhs, lhs must be a pure math expression
170+ // and rhs must be a var
171+ std::optional<bool > SymbolicExprAnalyzer::ProveDivisible (
172+ const ir::Expr& lhs, const ir::Expr& rhs) const {
173+ CHECK (rhs.is_var ()) << " Rhs in ProveDivisible must be a var temporarily!\n " ;
174+ CHECK (lhs.defined ());
175+ CHECK (rhs.defined ());
176+ CHECK (cinn::common::IsPureMath (lhs));
177+
178+ ir::Expr lhs_copy = ir::ir_utils::IRCopy (lhs);
179+ if (cinn::common::is_zero (lhs_copy)) return true ;
180+
181+ auto OptionalAnd = [](const std::optional<bool >& lhs,
182+ const std::optional<bool >& rhs) -> std::optional<bool > {
183+ if (lhs.has_value () && rhs.has_value ()) {
184+ return lhs.value () && rhs.value ();
185+ } else {
186+ return std::nullopt ;
187+ }
188+ };
189+ auto OptionalOr = [](const std::optional<bool >& lhs,
190+ const std::optional<bool >& rhs) -> std::optional<bool > {
191+ if (lhs.has_value () && rhs.has_value ()) {
192+ return lhs.value () || rhs.value ();
193+ } else if ((!lhs.has_value ()) && (!rhs.has_value ())) {
194+ return std::nullopt ;
195+ } else if (lhs.has_value () && (!rhs.has_value ())) {
196+ return lhs.value () ? std::optional<bool >(lhs.value ())
197+ : std::optional<bool >(std::nullopt );
198+ } else {
199+ return rhs.value () ? std::optional<bool >(rhs.value ())
200+ : std::optional<bool >(std::nullopt );
201+ }
202+ };
203+
204+ std::vector<ir::Expr> ops{};
205+ std::optional<bool > res = std::nullopt ;
206+ ir::Expr zero (0 );
207+ ir::Expr tmp_expr;
208+
209+ auto is_ge = ProveGE (lhs, rhs);
210+
211+ switch (lhs.node_type ()) {
212+ case cinn::ir::IrNodeTy::_Var_:
213+ return ProveEQ (lhs, rhs);
214+ case cinn::ir::IrNodeTy::IntImm:
215+ return false ;
216+ case cinn::ir::IrNodeTy::Sum:
217+ res = true ;
218+ ops = lhs.As <ir::Sum>()->operands ();
219+ CHECK (!ops.empty ());
220+ std::for_each (ops.begin (), ops.end (), [&](const ir::Expr& expr) {
221+ res = OptionalAnd (res, this ->ProveDivisible (expr, rhs));
222+ });
223+ res = OptionalAnd (res, is_ge);
224+ return res;
225+ case cinn::ir::IrNodeTy::Product:
226+ res = false ;
227+ ops = lhs.As <ir::Product>()->operands ();
228+ CHECK (!ops.empty ());
229+ std::for_each (ops.begin (), ops.end (), [&](const ir::Expr& expr) {
230+ res = OptionalOr (res, this ->ProveDivisible (expr, rhs));
231+ if (res.has_value () && res.value ()) return ;
232+ });
233+ res = OptionalAnd (res, is_ge);
234+ return res;
235+ case cinn::ir::IrNodeTy::FracOp:
236+ tmp_expr = cinn::common::AutoSimplify (lhs);
237+ if (tmp_expr.node_type () == cinn::ir::IrNodeTy::FracOp)
238+ return std::nullopt ;
239+ return OptionalAnd (ProveDivisible (tmp_expr, rhs), is_ge);
240+ case cinn::ir::IrNodeTy::FloatImm:
241+ return false ;
242+ case cinn::ir::IrNodeTy::Add:
243+ return OptionalAnd (
244+ OptionalAnd (ProveDivisible (lhs.As <ir::Add>()->a (), rhs),
245+ ProveDivisible (lhs.As <ir::Add>()->b (), rhs)),
246+ is_ge);
247+ case cinn::ir::IrNodeTy::Sub:
248+ return OptionalAnd (
249+ OptionalAnd (ProveDivisible (lhs.As <ir::Sub>()->a (), rhs),
250+ ProveDivisible (lhs.As <ir::Sub>()->b (), rhs)),
251+ is_ge);
252+ case cinn::ir::IrNodeTy::Div:
253+ tmp_expr = cinn::common::AutoSimplify (lhs);
254+ if (tmp_expr.node_type () == cinn::ir::IrNodeTy::Div) return std::nullopt ;
255+ return OptionalAnd (ProveDivisible (tmp_expr, rhs), is_ge);
256+ case cinn::ir::IrNodeTy::Mul:
257+ return OptionalAnd (
258+ OptionalOr (ProveDivisible (lhs.As <ir::Mul>()->a (), rhs),
259+ ProveDivisible (lhs.As <ir::Mul>()->b (), rhs)),
260+ is_ge);
261+ case cinn::ir::IrNodeTy::Mod:
262+ return false ;
263+ case cinn::ir::IrNodeTy::Minus:
264+ return ProveDivisible (lhs.As <ir::Minus>()->v (), rhs);
265+ default :
266+ LOG (FATAL) << " Not supported yet!" ;
267+ break ;
268+ }
269+ }
270+
167271class BoundReplacer : public ir ::IRMutator<> {
168272 public:
169273 explicit BoundReplacer (const cas_intervals_t & var_intervals,
170274 bool is_lower_bound)
171- : var_intervals_(var_intervals), sign_(is_lower_bound) {}
275+ : var_intervals_(var_intervals),
276+ sign_(is_lower_bound),
277+ var_visited_({}) {}
172278
173279 void operator ()(ir::Expr* expr) { IRMutator::Visit (expr, expr); }
174280
@@ -183,10 +289,16 @@ class BoundReplacer : public ir::IRMutator<> {
183289 upper_bound =
184290 interval.e_r .defined () ? interval.e_r : ir::Expr (interval.r );
185291 }
186- if (sign_) {
187- *op = ir::ir_utils::IRCopy (lower_bound);
292+ if (!var_visited_.count (var->name )) {
293+ if (sign_) {
294+ *op = ir::ir_utils::IRCopy (lower_bound);
295+ var_visited_.insert ({var->name , lower_bound});
296+ } else {
297+ *op = ir::ir_utils::IRCopy (upper_bound);
298+ var_visited_.insert ({var->name , upper_bound});
299+ }
188300 } else {
189- *op = ir::ir_utils::IRCopy (upper_bound );
301+ *op = ir::ir_utils::IRCopy (var_visited_. at (var-> name ) );
190302 }
191303 }
192304
@@ -248,6 +360,7 @@ class BoundReplacer : public ir::IRMutator<> {
248360
249361 private:
250362 const cas_intervals_t & var_intervals_;
363+ std::unordered_map<std::string, ir::Expr> var_visited_;
251364 // Determine replacing with upper or lower bound,
252365 // True means lower bound and False means upper bound.
253366 bool sign_;
0 commit comments