Skip to content

Commit 83439e8

Browse files
authored
Merge pull request PaddlePaddle#40 from Superjomn/fix/lower
make simple lower test case works
2 parents 99da6f8 + a8167a1 commit 83439e8

File tree

15 files changed

+201
-27
lines changed

15 files changed

+201
-27
lines changed

cinn/ir/ir.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,11 @@ std::vector<const Expr *> PolyFor::expr_fields() const { return {&init, &conditi
294294
bool Var::operator==(const Var &o) const { return o->name == operator->()->name; }
295295
bool Var::operator!=(const Var &o) const { return !(*this == o); }
296296

297+
Var &Var::operator=(_Var_ *x) {
298+
*this = Var(x);
299+
return *this;
300+
}
301+
297302
} // namespace ir
298303

299304
namespace common {

cinn/ir/ir.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,8 @@ struct Var : public IrNodeRef {
278278
bool operator==(const Var& o) const;
279279
bool operator!=(const Var& o) const;
280280

281+
Var& operator=(_Var_* x);
282+
281283
const _Var_* operator->() const { return get(); }
282284
_Var_* operator->() { return get(); }
283285
const _Var_* get() const { return static_cast<const _Var_*>(ptr()); }

cinn/ir/ir_operators.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#pragma once
12
#include "cinn/ir/ir.h"
23

34
namespace cinn {
@@ -80,12 +81,12 @@ Expr operator>=(POD a, Expr b) {
8081
}
8182

8283
//--
83-
Expr operator+(Expr a, Expr b) { return Add::Make(a, b); }
84-
Expr operator-(Expr a, Expr b) { return Sub::Make(a, b); }
85-
Expr operator*(Expr a, Expr b) { return Mul::Make(a, b); }
86-
Expr operator/(Expr a, Expr b) { return Div::Make(a, b); }
87-
Expr operator&&(Expr a, Expr b) { return And::Make(Expr(a), Expr(b)); }
88-
Expr operator||(Expr a, Expr b) { return Or::Make(Expr(a), Expr(b)); }
84+
inline Expr operator+(Expr a, Expr b) { return Add::Make(a, b); }
85+
inline Expr operator-(Expr a, Expr b) { return Sub::Make(a, b); }
86+
inline Expr operator*(Expr a, Expr b) { return Mul::Make(a, b); }
87+
inline Expr operator/(Expr a, Expr b) { return Div::Make(a, b); }
88+
inline Expr operator&&(Expr a, Expr b) { return And::Make(Expr(a), Expr(b)); }
89+
inline Expr operator||(Expr a, Expr b) { return Or::Make(Expr(a), Expr(b)); }
8990

9091
} // namespace ir
9192
} // namespace cinn

cinn/lang/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ endforeach()
99
cc_test(test_compute SRCS compute_test.cc DEPS core)
1010
cc_test(test_placeholder SRCS placeholder_test.cc DEPS core)
1111
cc_test(test_tensor SRCS tensor_test.cc DEPS core)
12+
cc_test(test_lower SRCS lower_test.cc DEPS core)

cinn/lang/compute.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <vector>
66

77
#include "cinn/ir/ir.h"
8+
#include "cinn/ir/ir_operators.h"
89
#include "cinn/lang/placeholder.h"
910
#include "cinn/poly/schedule.h"
1011

cinn/lang/lower.cc

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
#include "cinn/lang/lower.h"
22

3+
#include <map>
4+
#include <set>
5+
36
#include "cinn/ir/ir_printer.h"
7+
#include "cinn/optim/remove_nested_block.h"
48
#include "cinn/optim/replace_call_with_expr.h"
59
#include "cinn/poly/ast_gen.h"
610

@@ -78,19 +82,28 @@ std::vector<LoweredFunc> Lower(const std::string& name, const std::vector<Tensor
7882
auto schedule = poly::CreateSchedule(stages);
7983

8084
// generate the expressions for each group.
81-
std::vector<Expr> block;
85+
std::vector<Expr> exprs;
86+
CHECK_GT(schedule->gened_groups().size(), 0) << "no group is generated";
8287
for (auto& group : schedule->gened_groups()) {
88+
CHECK_GT(group.nodes.size(), 0) << "group is empty";
8389
std::map<std::string, Expr> tuple_to_expr;
8490
for (auto& node : group.nodes) {
85-
auto& tensor = tensor_dic.at(node->id());
86-
tuple_to_expr[tensor->name] = tensor->body();
91+
auto& tensor = tensor_dic.at(node->id());
92+
// NOTE here just schedule the compute node.
93+
if (!tensor->is_compute_node()) continue;
94+
95+
tuple_to_expr[tensor->name] = tensor->tensor_store_expanded_body();
8796
}
8897

8998
Expr group_expr = LowerGroup(group, tuple_to_expr);
9099
VLOG(3) << "group expr: " << group_expr;
91-
block.push_back(group_expr);
100+
exprs.push_back(group_expr);
92101
}
93102

103+
Expr block = ir::Block::Make(exprs);
104+
// call passes
105+
optim::RemoveNestedBlock(&block);
106+
94107
// prepare arguments
95108
std::vector<Argument> arguments;
96109
for (auto& arg : args) {

cinn/lang/lower.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
*/
44

55
#pragma once
6+
#include <string>
7+
#include <vector>
8+
69
#include "cinn/ir/function.h"
710
#include "cinn/ir/ir.h"
811
#include "cinn/lang/module.h"

cinn/lang/lower_test.cc

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#include "cinn/lang/lower.h"
2+
3+
#include <gtest/gtest.h>
4+
5+
#include "cinn/lang/compute.h"
6+
#include "cinn/lang/placeholder.h"
7+
#include "cinn/utils/string.h"
8+
9+
namespace cinn {
10+
namespace lang {
11+
12+
TEST(lower, basic) {
13+
const int M = 100;
14+
const int N = 15;
15+
16+
Placeholder<float> A("A", {Expr(M), Expr(N)});
17+
18+
auto B = Compute(
19+
{M, N}, [=](Var i, Var j) -> Expr { return A(i, j) + 1.f; }, "B");
20+
21+
auto lower_funcs = Lower("cal_B", {A, B});
22+
23+
LOG(INFO) << "lower_size " << lower_funcs.size();
24+
25+
#define TEST_SOUTPUT(x, out) LOG(INFO) << "\n" << x; \
26+
EXPECT_EQ(utils::GetStreamCnt(x), utils::Trim(out));
27+
28+
auto out = R"ROC(
29+
{
30+
poly_for (0, (c1 <= 99), 1)
31+
{
32+
poly_for (0, (c3 <= 14), 1)
33+
{
34+
A(c1, c3)
35+
B[((c1 * 15) + c3)] = (A(c1, c3) + 1)
36+
}
37+
}
38+
}
39+
)ROC";
40+
TEST_SOUTPUT(lower_funcs.front().body, out);
41+
}
42+
43+
} // namespace lang
44+
} // namespace cinn

cinn/lang/placeholder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ class Placeholder {
3333
Expr operator()(const std::vector<Expr> &indices) const;
3434
// @}
3535

36+
operator ir::Tensor() { return tensor_; }
37+
3638
private:
3739
ir::Tensor tensor_;
3840
};

cinn/lang/tensor.cc

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,44 @@
33
#include <cstring>
44

55
#include "cinn/common/common.h"
6+
#include "cinn/ir/ir_operators.h"
67
#include "cinn/ir/ir_visitor.h"
78
#include "cinn/ir/operation.h"
89
#include "cinn/poly/stage.h"
910

1011
namespace cinn {
1112
namespace ir {
1213

14+
namespace detail {
15+
16+
Expr ExpandTo1DIndice(const std::vector<Expr> &shape, const std::vector<Expr> &indices) {
17+
CHECK_EQ(shape.size(), indices.size());
18+
Expr res = indices.front() * shape[1];
19+
for (int i = 1; i < shape.size() - 1; i++) {
20+
res = res + indices[i] * shape[i + 1];
21+
}
22+
if (shape.size() > 1) res = res + indices.back();
23+
return res;
24+
}
25+
26+
Expr ExpandTo1DIndice(const std::vector<int> &shape, const std::vector<Expr> &indices) {
27+
std::vector<Expr> shape_;
28+
for (int v : shape) shape_.push_back(Expr(v));
29+
return ExpandTo1DIndice(shape, indices);
30+
}
31+
32+
} // namespace detail
33+
1334
Tensor _Tensor_::Make(const std::string &name, const std::vector<Expr> &shape, FunctionRef fn) {
35+
CHECK(!shape.empty()) << "Tensor shape is set empty";
36+
CHECK(!name.empty()) << "Tensor name is set empty";
1437
auto n = make_shared<_Tensor_>();
1538
n->name = name;
1639
n->shape = shape;
1740
n->operaion = fn;
1841
n->InitStage();
42+
n->InitAxis();
43+
n->SetDefaultBindedBuffer();
1944
return Tensor(n);
2045
}
2146

@@ -26,8 +51,13 @@ Tensor _Tensor_::Make(const std::string &name,
2651
Type dtype,
2752
const std::map<std::string, IrNodeRef> &attrs,
2853
const std::vector<Expr> &body) {
54+
CHECK(!shape.empty()) << "Tensor shape is set empty";
55+
CHECK(!name.empty()) << "Tensor name is set empty";
56+
2957
auto op = ComputeOp::Make(name, tag, attrs, axis, body, shape);
3058
auto *compute_op = const_cast<ComputeOp *>(op->As<ComputeOp>());
59+
60+
CHECK_EQ(axis.size(), shape.size()) << "axis not match the dimension in shape";
3161
compute_op->axis = axis;
3262

3363
auto n = make_shared<_Tensor_>();
@@ -36,6 +66,7 @@ Tensor _Tensor_::Make(const std::string &name,
3666
n->shape = shape;
3767
n->set_type(dtype);
3868
n->InitStage();
69+
n->SetDefaultBindedBuffer();
3970
return Tensor(n);
4071
}
4172

@@ -76,6 +107,12 @@ void _Tensor_::InitStage() {
76107
}
77108
}
78109

110+
void _Tensor_::InitAxis() {
111+
CHECK(!shape.empty());
112+
CHECK(axis.empty()) << "duplicate init axis";
113+
axis = common::GenDefaultAxis(shape.size());
114+
}
115+
79116
isl::set _Tensor_::GenerateIslDomain() {
80117
CHECK(!shape.empty()) << "shape should be set";
81118
std::vector<poly::Dim> dims;
@@ -138,5 +175,12 @@ Expr _Tensor_::body() const {
138175
NOT_IMPLEMENTED;
139176
}
140177

178+
Expr _Tensor_::tensor_store_expanded_body() const {
179+
CHECK(!is_placeholder_node()) << "placeholder should not expand store";
180+
std::vector<Expr> axis_;
181+
for (auto &a : axis) axis_.push_back(Expr(a));
182+
return ir::Store::Make(buffer_var, body(), detail::ExpandTo1DIndice(shape, axis_));
183+
}
184+
141185
} // namespace ir
142186
} // namespace cinn

0 commit comments

Comments
 (0)