|
| 1 | +#include "cinn/poly/ast_gen.h" |
1 | 2 |
|
| 3 | +namespace cinn { |
| 4 | +namespace poly { |
| 5 | + |
| 6 | +isl::ast_node AstGen::operator()(const std::vector<Element> &elements, const Scheduler &scheduler) { |
| 7 | + // Collect domains. |
| 8 | + auto sets = utils::Map<std::vector<Element>, isl::set>(elements, [](const Element &e) { return e.domain(); }); |
| 9 | + isl::union_set domain = SetsToUnionSet(sets); |
| 10 | + |
| 11 | + isl::ctx ctx = elements.front().domain().ctx(); |
| 12 | + |
| 13 | + // Collect schedule from scheduler. |
| 14 | + auto schedules = scheduler.BuildSchedule(); |
| 15 | + std::vector<isl::map> maps; |
| 16 | + for (auto &ele : elements) { |
| 17 | + auto it = schedules.find(ele.id()); |
| 18 | + CHECK(it != std::end(schedules)); |
| 19 | + maps.push_back(it->second); |
| 20 | + } |
| 21 | + auto schedule = MapsToUnionMap(maps); |
| 22 | + |
| 23 | + // Build it. |
| 24 | + auto build = isl::ast_build::from_context(context_); |
| 25 | + // Set iterators. |
| 26 | + if (!iterator_names_.empty()) { |
| 27 | + auto iterator_names = scheduler.WrapIteratorNames(iterator_names_); |
| 28 | + isl::id_list ids = isl::manage(isl_id_list_alloc(ctx.get(), iterator_names.size())); |
| 29 | + for (int i = 0; i < iterator_names.size(); i++) { |
| 30 | + ids = isl::manage(isl_id_list_add(ids.release(), isl_id_alloc(ctx.get(), iterator_names[i].c_str(), nullptr))); |
| 31 | + } |
| 32 | + build = isl::manage(isl_ast_build_set_iterators(build.release(), ids.release())); |
| 33 | + } |
| 34 | + |
| 35 | + auto ast = build.node_from_schedule_map(schedule.intersect_domain(domain)); |
| 36 | + VLOG(2) << "\n" << isl_ast_node_to_C_str(ast.get()); |
| 37 | + return ast; |
| 38 | +} |
| 39 | + |
| 40 | +AstGen &AstGen::SetIteratorNames(const std::vector<std::string> &names) { |
| 41 | + iterator_names_ = names; |
| 42 | + return *this; |
| 43 | +} |
| 44 | + |
| 45 | +} // namespace poly |
| 46 | +} // namespace cinn |
0 commit comments