Skip to content

Commit 18a20ce

Browse files
authored
Merge pull request PaddlePaddle#48 from Superjomn/fea/make-lowered_func-ir
fea/make lowered func ir
2 parents d9a1874 + 785554f commit 18a20ce

File tree

18 files changed

+261
-161
lines changed

18 files changed

+261
-161
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@
55
cmake-build*
66
build*
77
.idea*
8+
*.html

cinn/backends/codegen_c.cc

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
#include "cinn/backends/codegen_c.h"
22

3+
#include "cinn/ir/lowered_func.h"
4+
35
namespace cinn {
46
namespace backends {
57

68
CodeGenC::CodeGenC(std::ostream &os, Target target) : ir::IrPrinter(os), target_(target) {}
79

810
void CodeGenC::Compile(const lang::Module &module) {}
9-
void CodeGenC::Compile(const lang::LoweredFunc &function) {
10-
os() << "void " << function.name;
11+
void CodeGenC::Compile(const ir::LoweredFunc &function) {
12+
os() << "void " << function->name;
1113

1214
// output arguments
1315
os() << "(";
1416

15-
auto print_arg = [&](const lang::Argument &arg) {
17+
auto print_arg = [&](const ir::Argument &arg) {
1618
if (arg.is_buffer()) {
1719
os() << "struct cinn_buffer_t *";
1820
} else if (arg.is_scalar()) {
@@ -22,20 +24,20 @@ void CodeGenC::Compile(const lang::LoweredFunc &function) {
2224
os() << arg.name;
2325
};
2426

25-
for (int i = 0; i < function.args.size() - 1; i++) {
26-
print_arg(function.args[i]);
27+
for (int i = 0; i < function->args.size() - 1; i++) {
28+
print_arg(function->args[i]);
2729
os() << ", ";
2830
}
29-
if (function.args.size() >= 1) {
30-
print_arg(function.args.back());
31+
if (function->args.size() >= 1) {
32+
print_arg(function->args.back());
3133
}
3234

3335
os() << ")";
3436

3537
DoIndent();
3638
os() << "{\n";
3739

38-
Print(function.body);
40+
Print(function->body);
3941

4042
DoIndent();
4143
os() << "}";

cinn/backends/codegen_c.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "cinn/ir/function.h"
88
#include "cinn/ir/ir.h"
99
#include "cinn/ir/ir_printer.h"
10+
#include "cinn/ir/lowered_func.h"
1011
#include "cinn/lang/module.h"
1112

1213
namespace cinn {
@@ -24,7 +25,7 @@ class CodeGenC : public ir::IrPrinter {
2425
void Compile(const lang::Module& module);
2526

2627
protected:
27-
void Compile(const lang::LoweredFunc& function);
28+
void Compile(const ir::LoweredFunc& function);
2829
void Compile(const ir::Buffer& buffer);
2930

3031
std::string PrintType(Type type);

cinn/ir/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ set(srcs
77
ir_mutator.cc
88
function.cc
99
function_definition.cc
10+
lowered_func.cc
1011
ir_operators.cc
1112
buffer.cc
1213
function_base.cc

cinn/ir/ir_mutator.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,5 +114,10 @@ void IRMutator::Visit(const _Tensor_ *expr, Expr *op) {
114114
}
115115
}
116116

117+
void IRMutator::Visit(const _LoweredFunc_ *expr, Expr *op) {
118+
auto *node = op->As<_LoweredFunc_>();
119+
IRVisitorBase::Visit(&node->body, &node->body);
120+
}
121+
117122
} // namespace ir
118123
} // namespace cinn

cinn/ir/ir_printer.cc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22

33
#include <vector>
44

5+
#include "cinn/ir/lowered_func.h"
6+
#include "cinn/lang/module.h"
57
#include "cinn/lang/tensor.h"
8+
#include "cinn/utils/string.h"
69

710
namespace cinn {
811
namespace ir {
@@ -173,6 +176,27 @@ void IrPrinter::Visit(const _Tensor_ *x) {
173176
}
174177
os_ << ")";
175178
}
179+
void IrPrinter::Visit(const _LoweredFunc_ *f) {
180+
os_ << "function " << f->name << " ";
181+
182+
std::vector<std::string> arg_names;
183+
for (auto &arg : f->args) {
184+
arg_names.push_back(arg.name);
185+
}
186+
os_ << "(" << utils::Join(arg_names, ", ");
187+
188+
DoIndent();
189+
os_ << "{";
190+
191+
IncIndent();
192+
193+
Print(f->body);
194+
195+
DecIndent();
196+
197+
DoIndent();
198+
os_ << "}";
199+
}
176200
std::ostream &operator<<(std::ostream &os, Expr a) {
177201
std::stringstream ss;
178202
IrPrinter printer(ss);
@@ -181,5 +205,9 @@ std::ostream &operator<<(std::ostream &os, Expr a) {
181205
return os;
182206
}
183207

208+
std::ostream &operator<<(std::ostream &os, const ir::LoweredFunc &f) {}
209+
210+
std::ostream &operator<<(std::ostream &os, const lang::Module &m);
211+
184212
} // namespace ir
185213
} // namespace cinn

cinn/ir/ir_printer.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@
77
#include "cinn/ir/ir_visitor.h"
88

99
namespace cinn {
10+
11+
namespace lang {
12+
class Module;
13+
class LoweredFunc;
14+
} // namespace lang
15+
1016
namespace ir {
1117

1218
struct IrPrinter : public IRVisitor {
@@ -72,6 +78,7 @@ struct IrPrinter : public IRVisitor {
7278
void Visit(const _IterVar_ *x) override {}
7379
void Visit(const _Buffer_ *x) override;
7480
void Visit(const _Tensor_ *x) override;
81+
void Visit(const _LoweredFunc_ *x) override;
7582

7683
private:
7784
std::ostream &os_;
@@ -80,6 +87,7 @@ struct IrPrinter : public IRVisitor {
8087
};
8188

8289
std::ostream &operator<<(std::ostream &os, Expr a);
90+
std::ostream &operator<<(std::ostream &os, const lang::Module &m);
8391

8492
} // namespace ir
8593
} // namespace cinn

cinn/ir/ir_visitor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include "cinn/ir/buffer.h"
66
#include "cinn/ir/ir.h"
7+
#include "cinn/ir/lowered_func.h"
78
#include "cinn/lang/tensor.h"
89

910
namespace cinn {

cinn/ir/lowered_func.cc

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#include "cinn/ir/lowered_func.h"
2+
3+
#include "cinn/common/common.h"
4+
5+
namespace cinn {
6+
namespace ir {
7+
8+
const _LoweredFunc_* LoweredFunc::operator->() const { return As<_LoweredFunc_>(); }
9+
_LoweredFunc_* LoweredFunc::operator->() { return As<_LoweredFunc_>(); }
10+
11+
LoweredFunc _LoweredFunc_::Make(const std::string& name, const std::vector<Argument>& args, const Expr& body) {
12+
auto* n = make_shared<_LoweredFunc_>();
13+
n->name = name;
14+
n->args = args;
15+
n->body = body;
16+
return LoweredFunc(n);
17+
}
18+
19+
LoweredFunc _LoweredFunc_::Make(const std::string& name,
20+
const std::vector<Argument>& args,
21+
const std::vector<Expr>& body) {
22+
CHECK_EQ(body.size(), 1);
23+
return Make(name, args, body.front());
24+
}
25+
26+
std::vector<Expr*> _LoweredFunc_::expr_fields() { return {&body}; }
27+
std::vector<const Expr*> _LoweredFunc_::expr_fields() const { return {&body}; }
28+
29+
} // namespace ir
30+
} // namespace cinn

cinn/ir/lowered_func.h

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#pragma once
2+
#include "cinn/ir/buffer.h"
3+
#include "cinn/ir/node.h"
4+
5+
namespace cinn {
6+
namespace ir {
7+
8+
class _LoweredFunc_;
9+
10+
/**
11+
* A struct representing an argument to a lowered function. Used for specifying the function signature of generated
12+
* code.
13+
*/
14+
struct Argument {
15+
//! The name of the argument.
16+
std::string name;
17+
18+
enum class Kind { kScalar = 0, kBuffer } kind{Kind::kScalar};
19+
20+
//! Number of the dimensions of buffer.
21+
uint32_t ndims{0};
22+
23+
//! The type of the buffer or scalar.
24+
Type type;
25+
26+
bool is_buffer() const { return kind == Kind::kBuffer; }
27+
bool is_scalar() const { return kind == Kind::kScalar; }
28+
29+
Argument() {}
30+
Argument(const std::string& name, Kind kind, const Type& type, int ndims)
31+
: name(name), kind(kind), type(type), ndims(ndims) {}
32+
33+
explicit Argument(const ir::Buffer& buffer) : name(buffer->name), type(buffer->type()), ndims(buffer->shape.size()) {}
34+
};
35+
36+
//! Wrapper for _LoweredFunc_
37+
class LoweredFunc : public IrNodeRef {
38+
public:
39+
LoweredFunc() = default;
40+
explicit LoweredFunc(IrNode* n) : IrNodeRef(n) {}
41+
42+
const _LoweredFunc_* operator->() const;
43+
_LoweredFunc_* operator->();
44+
};
45+
46+
/**
47+
* Definition of a lowered function. Note that, it should be functional.
48+
*/
49+
struct _LoweredFunc_ : ExprNode<_LoweredFunc_> {
50+
//! The name of this function.
51+
std::string name;
52+
53+
//! The Arguments used in the body of the function.
54+
std::vector<Argument> args;
55+
56+
//! Body of this function.
57+
Expr body;
58+
59+
static LoweredFunc Make(const std::string& name, const std::vector<Argument>& args, const Expr& body);
60+
61+
static LoweredFunc Make(const std::string& name, const std::vector<Argument>& args, const std::vector<Expr>& body);
62+
63+
std::vector<Expr*> expr_fields() override;
64+
std::vector<const Expr*> expr_fields() const override;
65+
66+
static const IrNodeTy _node_type_ = IrNodeTy::_LoweredFunc_;
67+
};
68+
69+
} // namespace ir
70+
} // namespace cinn

0 commit comments

Comments
 (0)