Skip to content

Commit 424c295

Browse files
authored
Merge pull request #4457 from Canpio/dev_move_desc_to_framework
move Protobuf desc to framework
2 parents 21f63ec + f78d759 commit 424c295

File tree

12 files changed

+645
-386
lines changed

12 files changed

+645
-386
lines changed

paddle/framework/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope)
1919
proto_library(framework_proto SRCS framework.proto)
2020

2121
cc_library(attribute SRCS attribute.cc DEPS framework_proto)
22+
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute)
2223
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
2324
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
2425
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)

paddle/framework/block_desc.cc

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/framework/block_desc.h"
16+
#include "paddle/framework/program_desc.h"
17+
18+
namespace paddle {
19+
namespace framework {
20+
21+
VarDescBind *BlockDescBind::NewVar(const std::string &name) {
22+
need_update_ = true;
23+
auto it = vars_.find(name);
24+
PADDLE_ENFORCE(it == vars_.end(), "Duplicated variable %s", name);
25+
auto var = new VarDescBind(name);
26+
vars_[name].reset(var);
27+
return var;
28+
}
29+
30+
VarDescBind *BlockDescBind::Var(const std::string &name) const {
31+
auto it = vars_.find(name);
32+
PADDLE_ENFORCE(it != vars_.end(),
33+
"Can not find variable %s in current block.", name);
34+
return it->second.get();
35+
}
36+
37+
std::vector<VarDescBind *> BlockDescBind::AllVars() const {
38+
std::vector<VarDescBind *> res;
39+
for (const auto &p : vars_) {
40+
res.push_back(p.second.get());
41+
}
42+
return res;
43+
}
44+
45+
OpDescBind *BlockDescBind::AppendOp() {
46+
need_update_ = true;
47+
ops_.emplace_back(new OpDescBind());
48+
return ops_.back().get();
49+
}
50+
51+
OpDescBind *BlockDescBind::PrependOp() {
52+
need_update_ = true;
53+
ops_.emplace_front(new OpDescBind());
54+
return ops_.front().get();
55+
}
56+
57+
std::vector<OpDescBind *> BlockDescBind::AllOps() const {
58+
std::vector<OpDescBind *> res;
59+
for (const auto &op : ops_) {
60+
res.push_back(op.get());
61+
}
62+
return res;
63+
}
64+
65+
void BlockDescBind::Sync() {
66+
if (need_update_) {
67+
auto &op_field = *this->desc_->mutable_ops();
68+
op_field.Clear();
69+
op_field.Reserve(static_cast<int>(ops_.size()));
70+
for (auto &op_desc : ops_) {
71+
op_field.AddAllocated(op_desc->Proto());
72+
}
73+
need_update_ = false;
74+
}
75+
}
76+
77+
BlockDescBind *BlockDescBind::ParentBlock() const {
78+
if (this->desc_->parent_idx() == -1) {
79+
return nullptr;
80+
}
81+
return prog_->Block(static_cast<size_t>(this->desc_->parent_idx()));
82+
}
83+
84+
void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) {
85+
BlockDesc *desc = block.RawPtr();
86+
this->attrs_[name] = desc;
87+
}
88+
} // namespace framework
89+
} // namespace paddle

paddle/framework/block_desc.h

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <deque>
18+
#include <unordered_map>
19+
#include <vector>
20+
#include "paddle/framework/op_desc.h"
21+
#include "paddle/framework/var_desc.h"
22+
23+
namespace paddle {
24+
namespace framework {
25+
26+
class ProgramDescBind;
27+
28+
// Each Protobuf Message, we provide a XXXBind class. In that class, we optimize
29+
// read/write speed. Only when we want the protobuf message, the local changes
30+
// will be synchronized (by `Sync` method).
31+
32+
class BlockDescBind {
33+
public:
34+
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
35+
: prog_(prog), desc_(desc), need_update_(false) {}
36+
37+
BlockDescBind(const BlockDescBind &o) = delete;
38+
BlockDescBind &operator=(const BlockDescBind &o) = delete;
39+
40+
int32_t ID() const { return desc_->idx(); }
41+
42+
int32_t Parent() const { return desc_->parent_idx(); }
43+
44+
VarDescBind *NewVar(const std::string &name_bytes);
45+
46+
VarDescBind *Var(const std::string &name_bytes) const;
47+
48+
std::vector<VarDescBind *> AllVars() const;
49+
50+
BlockDescBind *ParentBlock() const;
51+
52+
OpDescBind *AppendOp();
53+
54+
OpDescBind *PrependOp();
55+
56+
std::vector<OpDescBind *> AllOps() const;
57+
58+
void Sync();
59+
60+
BlockDesc *RawPtr() { return desc_; }
61+
62+
private:
63+
ProgramDescBind *prog_; // not_own
64+
BlockDesc *desc_; // not_own
65+
bool need_update_;
66+
67+
std::deque<std::unique_ptr<OpDescBind>> ops_;
68+
std::unordered_map<std::string, std::unique_ptr<VarDescBind>> vars_;
69+
};
70+
} // namespace framework
71+
} // namespace paddle

paddle/framework/op_desc.cc

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/framework/op_desc.h"
16+
#include "paddle/framework/block_desc.h"
17+
18+
namespace paddle {
19+
namespace framework {
20+
21+
OpDesc *OpDescBind::Proto() {
22+
Sync();
23+
return &op_desc_;
24+
}
25+
26+
const std::vector<std::string> &OpDescBind::Input(
27+
const std::string &name) const {
28+
auto it = inputs_.find(name);
29+
PADDLE_ENFORCE(it != inputs_.end(), "Input %s cannot be found in Op %s", name,
30+
Type());
31+
return it->second;
32+
}
33+
34+
std::vector<std::string> OpDescBind::InputNames() const {
35+
std::vector<std::string> retv;
36+
retv.reserve(this->inputs_.size());
37+
for (auto &ipt : this->inputs_) {
38+
retv.push_back(ipt.first);
39+
}
40+
return retv;
41+
}
42+
43+
void OpDescBind::SetInput(const std::string &param_name,
44+
const std::vector<std::string> &args) {
45+
need_update_ = true;
46+
inputs_[param_name] = args;
47+
}
48+
49+
const std::vector<std::string> &OpDescBind::Output(
50+
const std::string &name) const {
51+
auto it = outputs_.find(name);
52+
PADDLE_ENFORCE(it != outputs_.end(), "Output %s cannot be found in Op %s",
53+
name, Type());
54+
return it->second;
55+
}
56+
57+
std::vector<std::string> OpDescBind::OutputNames() const {
58+
std::vector<std::string> retv;
59+
retv.reserve(this->outputs_.size());
60+
for (auto &ipt : this->outputs_) {
61+
retv.push_back(ipt.first);
62+
}
63+
return retv;
64+
}
65+
66+
void OpDescBind::SetOutput(const std::string &param_name,
67+
const std::vector<std::string> &args) {
68+
need_update_ = true;
69+
this->outputs_[param_name] = args;
70+
}
71+
72+
AttrType OpDescBind::GetAttrType(const std::string &name) const {
73+
auto it = attrs_.find(name);
74+
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
75+
return static_cast<AttrType>(it->second.which() - 1);
76+
}
77+
78+
std::vector<std::string> OpDescBind::AttrNames() const {
79+
std::vector<std::string> retv;
80+
retv.reserve(attrs_.size());
81+
for (auto &attr : attrs_) {
82+
retv.push_back(attr.first);
83+
}
84+
return retv;
85+
}
86+
87+
void OpDescBind::SetAttr(const std::string &name, const Attribute &v) {
88+
this->attrs_[name] = v;
89+
need_update_ = true;
90+
}
91+
92+
Attribute OpDescBind::GetAttr(const std::string &name) const {
93+
auto it = attrs_.find(name);
94+
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
95+
return it->second;
96+
}
97+
98+
int OpDescBind::GetBlockAttr(const std::string &name) const {
99+
auto it = attrs_.find(name);
100+
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
101+
return boost::get<BlockDesc *>(it->second)->idx();
102+
}
103+
104+
void OpDescBind::Sync() {
105+
if (need_update_) {
106+
this->op_desc_.mutable_inputs()->Clear();
107+
for (auto &ipt : inputs_) {
108+
auto *input = op_desc_.add_inputs();
109+
input->set_parameter(ipt.first);
110+
VectorToRepeated(ipt.second, input->mutable_arguments());
111+
}
112+
113+
this->op_desc_.mutable_outputs()->Clear();
114+
for (auto &opt : outputs_) {
115+
auto *output = op_desc_.add_outputs();
116+
output->set_parameter(opt.first);
117+
VectorToRepeated(opt.second, output->mutable_arguments());
118+
}
119+
120+
this->op_desc_.mutable_attrs()->Clear();
121+
for (auto &attr : attrs_) {
122+
auto *attr_desc = op_desc_.add_attrs();
123+
attr_desc->set_name(attr.first);
124+
attr_desc->set_type(
125+
static_cast<framework::AttrType>(attr.second.which() - 1));
126+
boost::apply_visitor(SetAttrDescVisitor(attr_desc), attr.second);
127+
}
128+
129+
need_update_ = false;
130+
}
131+
}
132+
} // namespace framework
133+
} // namespace paddle

0 commit comments

Comments
 (0)