Skip to content

Commit 50275fb

Browse files
authored
Add Instruction and Function class. (PaddlePaddle#17)
* Add Instruction data structure. * Add more methods for Instruction. * Update the cmake dependency of note_ir. * Add Function class for Note IR. * Print attribute values of instruction. * Add the assert statement in FunctionToString UT. * Add the unit test for Function and Instruction. * Add the default action for VisitParameter. * Fix errors of layout_test.cc and shape_test.cc. * Add comments for Instruction and Function. * Add non-const Accept for Instruction class. * Add comments of constructors in Instruction and Function.
1 parent 0422002 commit 50275fb

File tree

17 files changed

+1211
-66
lines changed

17 files changed

+1211
-66
lines changed

paddle/fluid/compiler/piano/backends/note_visitor_base.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@
1111
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
14+
1415
#pragma once
1516

1617
namespace paddle {
1718
namespace piano {
1819

1920
namespace note {
2021
class Instruction;
21-
}
22+
} // namespace note
2223

2324
namespace backends {
2425

@@ -28,6 +29,8 @@ class NoteVisitorBase {
2829

2930
// Scalar op
3031
virtual void VisitConstant(const note::Instruction&) = 0;
32+
// TODO(sunli): use the pure virtual function instead
33+
virtual void VisitParameter(const note::Instruction&) {}
3134

3235
// ops can be replaced by library
3336
virtual void VisitBatchNormGrad(const note::Instruction&) = 0;

paddle/fluid/compiler/piano/layout.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ std::string Layout::ToString() const {
4141
std::back_inserter(dim_names),
4242
[](const auto& dim) { return std::to_string(dim); });
4343
return paddle::string::format_string(
44-
"{%s}", paddle::string::join_strings(dim_names, ',').c_str());
44+
"{%s}", paddle::string::join_strings(dim_names, ", ").c_str());
4545
}
4646

4747
bool Layout::Valid() const { return !minor_to_major().empty(); }

paddle/fluid/compiler/piano/layout_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ TEST_F(LayoutTest, LayoutTransWithProto) {
3636
}
3737

3838
TEST_F(LayoutTest, LayoutToString) {
39-
ASSERT_EQ("{3,2,1,0}", layout_.ToString());
39+
ASSERT_EQ("{3, 2, 1, 0}", layout_.ToString());
4040
}
4141

4242
} // namespace piano

paddle/fluid/compiler/piano/note/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,6 @@ cc_test(note_opcode_test SRCS opcode_test.cc DEPS note_opcode)
33

44
proto_library(note_proto SRCS note.proto)
55
target_compile_options(note_proto PUBLIC "-Wno-extra")
6+
7+
cc_library(note_ir SRCS instruction.cc function.cc DEPS note_opcode note_proto piano_data_description)
8+
cc_test(note_ir_test SRCS note_ir_test.cc DEPS note_ir)
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/compiler/piano/note/function.h"
16+
#include <memory>
17+
#include <sstream>
18+
#include <unordered_map>
19+
#include <vector>
20+
#include "paddle/fluid/compiler/piano/note/instruction.h"
21+
#include "paddle/fluid/compiler/piano/note/opcode.h"
22+
#include "paddle/fluid/string/string_helper.h"
23+
24+
namespace paddle {
25+
namespace piano {
26+
namespace note {
27+
28+
Function::Function(
29+
const FunctionProto& proto,
30+
const std::unordered_map<std::int64_t, Function*>& func_index)
31+
: name_(proto.name()),
32+
signature_(proto.signature()),
33+
global_id_(proto.id()) {
34+
// the map used to record `id -> Instruction*`
35+
std::unordered_map<std::int64_t, Instruction*> instr_index;
36+
37+
// the map used to record `Instruction* -> id`, which is opposite
38+
// to the instr_index map
39+
std::unordered_map<Instruction*, std::int64_t> inverted_index;
40+
for (const auto& instr_proto : proto.instructions()) {
41+
auto instr =
42+
std::make_unique<Instruction>(instr_proto, instr_index, func_index);
43+
instr->set_parent(this);
44+
// set parameter(input) instructions field
45+
if (instr->opcode() == OpCode::kParameter) {
46+
param_instrs_.push_back(instr.get());
47+
}
48+
instr_index[instr_proto.id()] = instr.get();
49+
inverted_index[instr.get()] = instr_proto.id();
50+
instructions_.emplace_back(std::move(instr));
51+
}
52+
PADDLE_ENFORCE_EQ(
53+
proto.return_id() >= 0 && instr_index.count(proto.return_id()), true,
54+
platform::errors::PreconditionNotMet(
55+
"The return instruction id is %ld, and it is not "
56+
"included in this function.",
57+
proto.return_id()));
58+
59+
// set the returned instruction field
60+
return_instr_ = instr_index[proto.return_id()];
61+
std::sort(instructions_.begin(), instructions_.end(),
62+
[&inverted_index](const std::unique_ptr<Instruction>& l,
63+
const std::unique_ptr<Instruction>& r) {
64+
return inverted_index[l.get()] < inverted_index[r.get()];
65+
});
66+
}
67+
68+
FunctionProto Function::ToProto() const {
69+
FunctionProto proto;
70+
proto.set_name(name_);
71+
*proto.mutable_signature() = signature_.ToProto();
72+
proto.set_id(global_id_);
73+
proto.set_return_id(return_instr_->global_id());
74+
// serialize instruction protos
75+
for (const auto& instr : instructions_) {
76+
*proto.add_instructions() = instr->ToProto();
77+
}
78+
return proto;
79+
}
80+
81+
std::string Function::ToString() const {
82+
std::ostringstream out_str;
83+
// get the function name and signature
84+
out_str << "def %" << name_ << signature_.ToString() << " {\n";
85+
const std::string tab(2, ' ');
86+
87+
// get the string value of each instruction
88+
std::size_t num = instructions_.size();
89+
for (decltype(instructions_.size()) i = 0; i < num; i++) {
90+
if (num - 1 == i) {
91+
out_str << tab << "return ";
92+
} else {
93+
out_str << tab;
94+
}
95+
out_str << instructions_[i]->ToString() << "\n";
96+
}
97+
out_str << "}";
98+
return out_str.str();
99+
}
100+
101+
} // namespace note
102+
} // namespace piano
103+
} // namespace paddle
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <cstddef>
18+
#include <cstdint>
19+
#include <iterator>
20+
#include <memory>
21+
#include <string>
22+
#include <vector>
23+
#include "paddle/fluid/compiler/piano/note/note.pb.h"
24+
#include "paddle/fluid/compiler/piano/shape.h"
25+
26+
namespace paddle {
27+
namespace piano {
28+
namespace note {
29+
30+
class Instruction;
31+
// class Module;
32+
33+
class Function {
34+
public:
35+
// Construct a Function object with a given FunctionProto value.
36+
// 'func_index' is used to transform function id into Function pointer,
37+
// which is used to construct instructions in this function.
38+
Function(const FunctionProto &proto,
39+
const std::unordered_map<std::int64_t, Function *> &func_index);
40+
41+
FunctionProto ToProto() const;
42+
43+
std::string ToString() const;
44+
45+
// return the name of this function
46+
const std::string &name() const { return name_; }
47+
48+
// return instructions owned by this function
49+
std::vector<Instruction *> instructions() const {
50+
std::vector<Instruction *> instrs;
51+
instrs.reserve(instructions_.size());
52+
std::transform(
53+
instructions_.cbegin(), instructions_.cend(),
54+
std::back_inserter(instrs),
55+
[](const std::unique_ptr<Instruction> &instr) { return instr.get(); });
56+
return instrs;
57+
}
58+
59+
const Instruction *instruction(std::int64_t idx) const {
60+
return instructions_.at(idx).get();
61+
}
62+
63+
Instruction *mutable_instruction(std::int64_t idx) {
64+
return instructions_.at(idx).get();
65+
}
66+
67+
// return the function signature
68+
const Signature &signature() const { return signature_; }
69+
70+
Signature *mutable_signature() { return &signature_; }
71+
72+
std::int64_t global_id() const { return global_id_; }
73+
74+
// return the returned instruction of this function
75+
const Instruction *return_instr() const { return return_instr_; }
76+
77+
// const Module *parent() const { return parent_; }
78+
79+
// Module *mutable_parent() { return parent_; }
80+
81+
// void set_parent(Module *module) { parent_ = module; }
82+
83+
const std::vector<Instruction *> &param_instrs() const {
84+
return param_instrs_;
85+
}
86+
87+
// return parameter instructions of this function
88+
const Instruction *param_instr(std::int64_t idx) const {
89+
return param_instrs_.at(idx);
90+
}
91+
92+
// return the parameter(input) number of this function
93+
std::size_t params_num() const { return param_instrs_.size(); }
94+
95+
private:
96+
// the name of this function
97+
std::string name_;
98+
// instructions owned by this function
99+
std::vector<std::unique_ptr<Instruction>> instructions_;
100+
// the function signature, including parameter and return types
101+
Signature signature_;
102+
// the global id of this function in a module
103+
std::int64_t global_id_;
104+
// the returned instruction of this function
105+
Instruction *return_instr_;
106+
107+
// TODO(wzzju): Add Module class.
108+
// the module where this function is contained
109+
// Module *parent_{nullptr};
110+
111+
// parameter instructions of this function,
112+
// which denote input parameters
113+
std::vector<Instruction *> param_instrs_;
114+
115+
DISABLE_COPY_AND_ASSIGN(Function);
116+
};
117+
118+
} // namespace note
119+
} // namespace piano
120+
} // namespace paddle

0 commit comments

Comments
 (0)