Skip to content

Commit b20f50f

Browse files
authored
Merge pull request PaddlePaddle#20 from Superjomn/fea/add-ast-gen
add ast build
2 parents ea6cdb5 + bebbacb commit b20f50f

File tree

8 files changed

+144
-0
lines changed

8 files changed

+144
-0
lines changed

cinn/poly/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ cc_library(poly SRCS
1010

1111
cc_test(test_poly_element SRCS element_test.cc DEPS poly)
1212
cc_test(test_schedule SRCS schedule_test.cc DEPS poly)
13+
cc_test(test_ast_gen SRCS ast_gen_test.cc DEPS poly)

cinn/poly/ast_gen.cc

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,46 @@
1+
#include "cinn/poly/ast_gen.h"
12

3+
namespace cinn {
4+
namespace poly {
5+
6+
isl::ast_node AstGen::operator()(const std::vector<Element> &elements, const Scheduler &scheduler) {
7+
// Collect domains.
8+
auto sets = utils::Map<std::vector<Element>, isl::set>(elements, [](const Element &e) { return e.domain(); });
9+
isl::union_set domain = SetsToUnionSet(sets);
10+
11+
isl::ctx ctx = elements.front().domain().ctx();
12+
13+
// Collect schedule from scheduler.
14+
auto schedules = scheduler.BuildSchedule();
15+
std::vector<isl::map> maps;
16+
for (auto &ele : elements) {
17+
auto it = schedules.find(ele.id());
18+
CHECK(it != std::end(schedules));
19+
maps.push_back(it->second);
20+
}
21+
auto schedule = MapsToUnionMap(maps);
22+
23+
// Build it.
24+
auto build = isl::ast_build::from_context(context_);
25+
// Set iterators.
26+
if (!iterator_names_.empty()) {
27+
auto iterator_names = scheduler.WrapIteratorNames(iterator_names_);
28+
isl::id_list ids = isl::manage(isl_id_list_alloc(ctx.get(), iterator_names.size()));
29+
for (int i = 0; i < iterator_names.size(); i++) {
30+
ids = isl::manage(isl_id_list_add(ids.release(), isl_id_alloc(ctx.get(), iterator_names[i].c_str(), nullptr)));
31+
}
32+
build = isl::manage(isl_ast_build_set_iterators(build.release(), ids.release()));
33+
}
34+
35+
auto ast = build.node_from_schedule_map(schedule.intersect_domain(domain));
36+
VLOG(2) << "\n" << isl_ast_node_to_C_str(ast.get());
37+
return ast;
38+
}
39+
40+
AstGen &AstGen::SetIteratorNames(const std::vector<std::string> &names) {
41+
iterator_names_ = names;
42+
return *this;
43+
}
44+
45+
} // namespace poly
46+
} // namespace cinn

cinn/poly/ast_gen.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,30 @@
11
#pragma once
2+
#include <isl/cpp.h>
3+
#include "cinn/poly/element.h"
4+
#include "cinn/poly/isl_utils.h"
5+
#include "cinn/poly/schedule.h"
6+
#include "cinn/utils/functional.h"
7+
8+
namespace cinn {
9+
namespace poly {
10+
11+
class AstGen {
12+
public:
13+
AstGen(const isl::set& context) : context_(context) {}
14+
15+
/**
16+
* Set forloop iterator names.
17+
* @param names
18+
* @return AstGen itself.
19+
*/
20+
AstGen& SetIteratorNames(const std::vector<std::string>& names);
21+
22+
isl::ast_node operator()(const std::vector<Element>& elements, const Scheduler& scheduler);
23+
24+
private:
25+
isl::set context_;
26+
std::vector<std::string> iterator_names_;
27+
};
28+
29+
} // namespace poly
30+
} // namespace cinn

cinn/poly/ast_gen_test.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#include "cinn/poly/ast_gen.h"
2+
#include <gtest/gtest.h>
3+
4+
namespace cinn {
5+
namespace poly {
6+
7+
TEST(ast_gen, basic) {
8+
isl::ctx ctx(isl_ctx_alloc());
9+
Element A(isl::set(ctx, "{ A[i,j,k]: 0<i,j,k<100 }"));
10+
Element B(isl::set(ctx, "{ B[i,j,k]: 0<i,j,k<100 }"));
11+
12+
Scheduler scheduler;
13+
scheduler.RegisterElement(A);
14+
scheduler.RegisterElement(B);
15+
scheduler.After(A, B, 2);
16+
17+
AstGen gen(isl::set(ctx, "{:}"));
18+
gen.SetIteratorNames({"i", "j", "k"});
19+
gen({A, B}, scheduler);
20+
}
21+
22+
} // namespace poly
23+
} // namespace cinn

cinn/poly/isl_utils.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,23 @@ void SetDimNames(isl::set *set, const std::vector<std::string> &names) {
3939
}
4040
}
4141

42+
isl::union_map MapsToUnionMap(const std::vector<isl::map> &maps) {
43+
CHECK(!maps.empty());
44+
isl::union_map umap = isl::manage(isl_union_map_from_map(maps.front().copy()));
45+
for (int i = 1; i < maps.size(); i++) {
46+
umap = isl::manage(isl_union_map_add_map(umap.release(), maps[i].copy()));
47+
}
48+
return umap;
49+
}
50+
51+
isl::union_set SetsToUnionSet(const std::vector<isl::set> &sets) {
52+
CHECK(!sets.empty());
53+
isl::union_set uset = isl::manage(isl_union_set_from_set(sets.front().copy()));
54+
for (int i = 1; i < sets.size(); i++) {
55+
uset = isl::manage(isl_union_set_add_set(uset.release(), sets[i].copy()));
56+
}
57+
return uset;
58+
}
59+
4260
} // namespace poly
4361
} // namespace cinn

cinn/poly/isl_utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,9 @@ std::vector<std::string> GetDimNames(const isl::map& x, isl_dim_type dim_type);
1616
void SetDimNames(isl::set* set, const std::vector<std::string>& names);
1717
void SetDimNames(isl::map* map, isl_dim_type dim_type, const std::vector<std::string>& names);
1818

19+
//! Convert a list of isl::map to isl::union_map
20+
isl::union_map MapsToUnionMap(const std::vector<isl::map>& maps);
21+
isl::union_set SetsToUnionSet(const std::vector<isl::set>& sets);
22+
1923
} // namespace poly
2024
} // namespace cinn

cinn/poly/schedule.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,5 +150,15 @@ std::map<std::string, isl::map> Scheduler::BuildSchedule() const {
150150
return res;
151151
}
152152

153+
std::vector<std::string> Scheduler::WrapIteratorNames(const std::vector<std::string> &names) const {
154+
CHECK_EQ(names.size(), space_size());
155+
std::vector<std::string> res;
156+
for (int i = 0; i < space_size(); i++) {
157+
res.push_back(""); // fake name for time space.
158+
res.push_back(names[i]); // name for the corresponding iterator.
159+
}
160+
return res;
161+
}
162+
153163
} // namespace poly
154164
} // namespace cinn

cinn/poly/schedule.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,11 @@ class Scheduler {
8989
*/
9090
void FinalizeRegistration();
9191

92+
/**
93+
* Tell whether the registration is finalized.
94+
*/
95+
bool finalized() const { return registration_finalized_; }
96+
9297
/**
9398
* Mark this should schedule after another.
9499
*
@@ -108,6 +113,15 @@ class Scheduler {
108113
*/
109114
std::map<std::string, isl::map> BuildSchedule() const;
110115

116+
/**
117+
* Wrap the iterator names with time space.
118+
* @param names the original iterator names.
119+
* @return the iterator names with time space included.
120+
*/
121+
std::vector<std::string> WrapIteratorNames(const std::vector<std::string> &names) const;
122+
123+
int space_size() const { return space_size_; }
124+
111125
private:
112126
/**
113127
* The polyhedral schedule, any schedule is performed on it.

0 commit comments

Comments
 (0)