|
12 | 12 | // See the License for the specific language governing permissions and |
13 | 13 | // limitations under the License. |
14 | 14 |
|
15 | | -#include "paddle/cinn/common/macros.h" |
16 | 15 | #include "paddle/cinn/ir/schedule/impl/ir_schedule.h" |
17 | 16 |
|
| 17 | +#include "paddle/cinn/common/integer_set.h" |
| 18 | +#include "paddle/cinn/common/macros.h" |
| 19 | + |
18 | 20 | /** \brief A macro that guards the beginning of each implementation of schedule |
19 | 21 | */ |
20 | 22 | #define CINN_IR_SCHEDULE_BEGIN() try { |
@@ -157,6 +159,63 @@ std::vector<Expr> DyScheduleImpl::Split(const Expr& loop, |
157 | 159 | return splited_loops; |
158 | 160 | } |
159 | 161 |
|
| 162 | +// TODO(@LiuYang): now -1 can't exsit in factors, |
| 163 | +std::vector<Expr> DyScheduleImpl::Split(const Expr& loop, |
| 164 | + const std::vector<Expr>& factors) { |
| 165 | + CHECK(loop.As<ir::For>()) |
| 166 | + << "Expr param of Split must be For node! Please check."; |
| 167 | + auto* for_node = loop.As<ir::For>(); |
| 168 | + CHECK(common::is_zero(for_node->min)) |
| 169 | + << "The For node must start with 0! Please check."; |
| 170 | + CHECK(!factors.empty()) |
| 171 | + << "The factors param of Split should not be empty! Please check."; |
| 172 | + CHECK(!loop.As<ir::For>()->extent.is_constant()) |
| 173 | + << "Can't Split a loop with constant extent but with variable in " |
| 174 | + "factors!"; |
| 175 | + Expr tot_extent = for_node->extent; |
| 176 | + |
| 177 | + VLOG(3) << "Try Split loop from (" << for_node->loop_var->name << ", 0, " |
| 178 | + << tot_extent << ") to (" << cinn::utils::Join(factors, ", ") |
| 179 | + << ") at loop:\n" |
| 180 | + << loop; |
| 181 | + |
| 182 | + std::vector<Expr> process_factors(factors); |
| 183 | + Expr prod_size(1); |
| 184 | + for (auto factor : factors) prod_size = prod_size * Expr(factor); |
| 185 | + cinn::common::SymbolicExprAnalyzer analyzer({}); |
| 186 | + CHECK(analyzer.ProveEQ(tot_extent, prod_size).value_or(false)) |
| 187 | + << "Product of factors can't be proved to be equal to the extent of " |
| 188 | + "current for loop!"; |
| 189 | + |
| 190 | + std::vector<Var> new_loop_vars; |
| 191 | + Expr substitute_value(0); |
| 192 | + for (int i = 0; i < process_factors.size(); ++i) { |
| 193 | + Var temp_var(common::UniqName(for_node->loop_var->name)); |
| 194 | + substitute_value = Expr(temp_var) + substitute_value * process_factors[i]; |
| 195 | + new_loop_vars.push_back(temp_var); |
| 196 | + } |
| 197 | + substitute_value = cinn::common::AutoSimplify(substitute_value); |
| 198 | + Expr new_node = ir::ir_utils::IRCopy(for_node->body); |
| 199 | + ReplaceExpr(&new_node, {for_node->loop_var}, {substitute_value}); |
| 200 | + std::vector<Expr> splited_loops; |
| 201 | + splited_loops.resize(process_factors.size()); |
| 202 | + |
| 203 | + for (int i = process_factors.size() - 1; i >= 0; i--) { |
| 204 | + if (!new_node.As<ir::Block>()) new_node = Block::Make({new_node}); |
| 205 | + new_node = For::Make(new_loop_vars[i], |
| 206 | + Expr(0), |
| 207 | + process_factors[i], |
| 208 | + for_node->for_type(), |
| 209 | + for_node->device_api, |
| 210 | + new_node); |
| 211 | + splited_loops[i] = new_node; |
| 212 | + } |
| 213 | + |
| 214 | + this->Replace(loop, new_node); |
| 215 | + VLOG(3) << "After Split, ir is:\n" << splited_loops.at(0); |
| 216 | + return splited_loops; |
| 217 | +} |
| 218 | + |
160 | 219 | Expr DyScheduleImpl::Fuse(const std::vector<Expr>& loops) { |
161 | 220 | VLOG(3) << "Tring to fuse:\n" << cinn::utils::Join(loops, "\n"); |
162 | 221 | std::vector<const ir::For*> for_nodes; |
@@ -370,6 +429,12 @@ std::vector<Expr> StScheduleImpl::Split(const Expr& loop, |
370 | 429 | return splited_loops; |
371 | 430 | } |
372 | 431 |
|
| 432 | +std::vector<Expr> StScheduleImpl::Split(const Expr& loop, |
| 433 | + const std::vector<Expr>& factors) { |
| 434 | + CHECK(false) << "Static shape schedule don't support Split with some " |
| 435 | + "variables in factors"; |
| 436 | +} |
| 437 | + |
373 | 438 | Expr StScheduleImpl::Fuse(const std::vector<Expr>& loops) { |
374 | 439 | VLOG(3) << "Tring to fuse:\n" << cinn::utils::Join(loops, "\n"); |
375 | 440 | std::vector<const ir::For*> for_nodes; |
|
0 commit comments