Skip to content

Commit 167523e

Browse files
authored
graph_to_program topology sort (#33949)
See #33949 for details
1 parent f1654de commit 167523e

11 files changed

Lines changed: 762 additions & 33 deletions

File tree

paddle/fluid/framework/ir/graph.cc

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,12 @@ Graph::Graph(const ProgramDesc &program, const int64_t start_op_index,
5656
// sub_graph.
5757
std::unique_ptr<Graph> first_sub_graph = std::make_unique<Graph>(
5858
program_.Block(0), this, start_op_index, end_op_index);
59+
first_sub_graph->block_id_ = 0;
5960
sub_graphs_.push_back(std::move(first_sub_graph));
6061
for (size_t idx = 1; idx < program_.Size(); ++idx) {
6162
std::unique_ptr<Graph> sub_graph =
6263
std::make_unique<Graph>(program_.Block(idx), this);
64+
sub_graph->block_id_ = idx;
6365
sub_graphs_.push_back(std::move(sub_graph));
6466
}
6567
} else {
@@ -90,14 +92,32 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
9092
std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
9193
const BlockDesc &block, const int64_t start_op_index,
9294
const int64_t end_op_index) {
93-
std::unordered_map<std::string, VarDesc *> all_vars;
95+
std::unordered_map<std::string, std::pair<VarDesc *, int>>
96+
name_to_desc_block_id;
97+
98+
const BlockDesc *block_var_visible = &block;
99+
while (block_var_visible != nullptr) {
100+
for (auto *var : block_var_visible->AllVars()) {
101+
name_to_desc_block_id.emplace(
102+
var->Name(), std::make_pair(var, block_var_visible->ID()));
103+
}
104+
const BlockDesc *forward_block = block_var_visible->ForwardBlock();
105+
if (forward_block != nullptr) {
106+
for (auto *var : forward_block->AllVars()) {
107+
name_to_desc_block_id.emplace(var->Name(),
108+
std::make_pair(var, forward_block->ID()));
109+
}
110+
}
111+
block_var_visible = block_var_visible->ParentBlock();
112+
}
94113
// var nodes for each var name, will have multiple versions in SSA
95114
std::map<std::string, std::vector<ir::Node *>> var_nodes;
115+
std::unordered_map<std::string, VarDesc *> not_visited_vars;
96116
for (auto *var : block.AllVars()) {
97-
all_vars.emplace(var->Name(), var);
117+
not_visited_vars.emplace(var->Name(), var);
98118
}
99119

100-
auto not_visited_vars = all_vars;
120+
int desc_order = 0;
101121
auto all_ops = block.AllOps();
102122
PADDLE_ENFORCE_LE(
103123
end_op_index, all_ops.size(),
@@ -109,15 +129,18 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
109129
auto *op = all_ops[i];
110130
VLOG(3) << "create OpNode by " << op->Type();
111131
ir::Node *node = CreateOpNode(op);
132+
node->SetDescOrder(desc_order);
133+
++desc_order;
112134
// For input args, reuse the same var name if it was created before.
113135
// Otherwise, create a new one.
114136
for (auto &each_var_name : op->InputArgumentNames()) {
115137
not_visited_vars.erase(each_var_name);
116138
ir::Node *var = nullptr;
117139
if (var_nodes.find(each_var_name) != var_nodes.end()) {
118140
var = var_nodes.at(each_var_name).back();
119-
} else if (all_vars.count(each_var_name) != 0) {
120-
var = CreateVarNode(all_vars.at(each_var_name));
141+
} else if (name_to_desc_block_id.count(each_var_name) != 0) {
142+
auto desc_and_block_id = name_to_desc_block_id.at(each_var_name);
143+
var = CreateVarNode(desc_and_block_id.first, desc_and_block_id.second);
121144
var_nodes[each_var_name].push_back(var);
122145
} else {
123146
// Operation input var can be optional (dispensable). Which means
@@ -143,8 +166,9 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
143166
}
144167

145168
ir::Node *var = nullptr;
146-
if (all_vars.count(each_var_name) != 0) {
147-
var = CreateVarNode(all_vars.at(each_var_name));
169+
if (name_to_desc_block_id.count(each_var_name) != 0) {
170+
auto desc_and_block_id = name_to_desc_block_id.at(each_var_name);
171+
var = CreateVarNode(desc_and_block_id.first, desc_and_block_id.second);
148172
} else {
149173
// Operation output vars can be @EMPTY@. For example, while_grad
150174
// can have multi @EMPTY@ outputs with no VarDesc.
@@ -270,6 +294,7 @@ std::shared_ptr<Graph> Graph::Clone() {
270294
auto cloned_graph = std::make_shared<Graph>(this->program_);
271295
cloned_graph->ReleaseNodes();
272296
cloned_graph->num_node_created_ = 0;
297+
cloned_graph->block_id_ = this->block_id_;
273298
std::unordered_map<ir::Node *, ir::Node *> origin_to_cloned;
274299
for (auto *n : this->node_set_) {
275300
PADDLE_ENFORCE_NOT_NULL(n, platform::errors::InvalidArgument(
@@ -313,6 +338,7 @@ std::unique_ptr<Graph> Graph::CloneSubGraph(const size_t idx) {
313338
std::make_unique<Graph>(this->program_.Block(idx), this);
314339
cloned_sub_graph->ReleaseNodes();
315340
cloned_sub_graph->num_node_created_ = 0;
341+
cloned_sub_graph->block_id_ = idx;
316342
std::unordered_map<ir::Node *, ir::Node *> origin_to_cloned;
317343
for (auto *n : this->sub_graphs_.at(idx)->Nodes()) {
318344
PADDLE_ENFORCE_NOT_NULL(n, platform::errors::InvalidArgument(

paddle/fluid/framework/ir/graph.h

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,14 @@ class Graph {
104104
attr_dels_.clear();
105105
}
106106

107-
bool IsConstructedByPartialProgram() const { return is_partial_; }
107+
bool IsConstructedByPartialProgram() const {
108+
if (FLAGS_convert_all_blocks) {
109+
if (IsMainGraph()) {
110+
return GetSubGraph(0)->IsConstructedByPartialProgram();
111+
}
112+
}
113+
return is_partial_;
114+
}
108115

109116
bool Has(const std::string &attr_name) const {
110117
if (FLAGS_convert_all_blocks) {
@@ -210,7 +217,7 @@ class Graph {
210217
}
211218

212219
// Create a normal variable with non-null VarDesc.
213-
ir::Node *CreateVarNode(VarDesc *var_desc) {
220+
ir::Node *CreateVarNode(VarDesc *var_desc, int block_id = -1) {
214221
if (FLAGS_convert_all_blocks) {
215222
if (IsMainGraph()) {
216223
return GetSubGraph(0)->CreateVarNode(var_desc);
@@ -219,7 +226,8 @@ class Graph {
219226
PADDLE_ENFORCE_NOT_NULL(
220227
var_desc, platform::errors::InvalidArgument(
221228
"The VarDesc used to create variable node is null."));
222-
auto *x = AddNode(new ir::Node(var_desc));
229+
auto *x =
230+
AddNode(new ir::Node(var_desc, block_id == -1 ? block_id_ : block_id));
223231
x->SetId(num_node_created_++);
224232
return x;
225233
}
@@ -252,7 +260,7 @@ class Graph {
252260
const std::string name = string::Sprintf(
253261
"%s@%llu", static_cast<const char *>(ir::Node::kControlDepVarName),
254262
num_node_created_);
255-
auto *x = AddNode(new ir::Node(name, ir::Node::Type::kVariable));
263+
auto *x = AddNode(new ir::Node(name, ir::Node::Type::kVariable, block_id_));
256264
x->SetId(num_node_created_++);
257265
return x;
258266
}
@@ -265,7 +273,7 @@ class Graph {
265273
return GetSubGraph(0)->CreateEmptyNode(name, type);
266274
}
267275
}
268-
auto *x = AddNode(new ir::Node(name, type));
276+
auto *x = AddNode(new ir::Node(name, type, block_id_));
269277
x->SetId(num_node_created_++);
270278
return x;
271279
}
@@ -365,6 +373,15 @@ class Graph {
365373
return sub_graphs_.at(idx).get();
366374
}
367375

376+
int GetBlockId() const {
377+
if (FLAGS_convert_all_blocks) {
378+
if (IsMainGraph()) {
379+
return GetSubGraph(0)->block_id_;
380+
}
381+
}
382+
return block_id_;
383+
}
384+
368385
size_t SubGraphsSize() const {
369386
PADDLE_ENFORCE_EQ(
370387
this->IsMainGraph(), true,
@@ -394,6 +411,9 @@ class Graph {
394411
PADDLE_ENFORCE_EQ(
395412
this->IsMainGraph(), true,
396413
platform::errors::InvalidArgument("This graph is not main_graph"));
414+
PADDLE_ENFORCE_EQ(sub_graphs_.size(), sub_graph->block_id_,
415+
platform::errors::InvalidArgument(
416+
"sub_graph idx is not equal to block_id_"));
397417
sub_graphs_.push_back(std::move(sub_graph));
398418
}
399419

@@ -416,6 +436,8 @@ class Graph {
416436
// parts: forward graph and backward graph, which can be executed
417437
// independently.
418438
bool is_partial_{false};
439+
// The block this SubGraph belongs to.
440+
int block_id_{0};
419441
};
420442

421443
bool IsControlDepVar(const ir::Node &var);

paddle/fluid/framework/ir/graph_helper.cc

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/framework/ir/graph_helper.h"
16+
#include <queue>
1617
#include <stack>
1718

1819
DEFINE_string(print_sub_graph_dir, "",
@@ -395,6 +396,85 @@ std::vector<Node *> TopologyVarientSort(const Graph &graph,
395396
}
396397
}
397398

399+
class DescOrderComparator {
400+
public:
401+
bool operator()(const Node *n1, const Node *n2) {
402+
return (n1->DescOrder() > n2->DescOrder()) ||
403+
((n1->DescOrder() == n2->DescOrder()) &&
404+
(n1->ToString() > n2->ToString()));
405+
}
406+
};
407+
408+
std::vector<ir::Node *> TopologySortGraphByDescOrder(const Graph &graph) {
409+
std::vector<ir::Node *> sorted_ops;
410+
std::priority_queue<Node *, std::vector<Node *>, DescOrderComparator> q;
411+
std::unordered_map<Node *, std::unordered_set<Node *>> in_ops;
412+
std::unordered_map<Node *, std::unordered_set<Node *>> out_ops;
413+
414+
// ensure all op node in 'in_ops' and 'out_ops'
415+
for (const auto &n : graph.Nodes()) {
416+
if (!n->IsOp()) continue;
417+
418+
in_ops.emplace(n, std::unordered_set<Node *>());
419+
out_ops.emplace(n, std::unordered_set<Node *>());
420+
}
421+
422+
// record all op's input op and output op
423+
for (const auto &n : graph.Nodes()) {
424+
if (!n->IsOp()) continue;
425+
426+
// traverse all input op
427+
for (const auto &var : n->inputs) {
428+
for (const auto &in : var->inputs) {
429+
// use at instead of [] to prevent no unrecorded op node
430+
in_ops.at(n).insert(in);
431+
out_ops.at(in).insert(n);
432+
}
433+
}
434+
}
435+
436+
// find topology entrance
437+
for (const auto &n : graph.Nodes()) {
438+
if (!n->IsOp()) continue;
439+
440+
if (in_ops.at(n).empty()) {
441+
q.push(n);
442+
}
443+
}
444+
445+
// topological sorting
446+
while (!q.empty()) {
447+
// Do not get by reference!!! The element will pop later.
448+
const auto cur_op = q.top();
449+
q.pop();
450+
451+
sorted_ops.push_back(cur_op);
452+
for (const auto &out : out_ops.at(cur_op)) {
453+
PADDLE_ENFORCE_GT(in_ops.at(out).count(cur_op), 0,
454+
platform::errors::InvalidArgument(
455+
"We find %s in %s's output list, "
456+
"but cannot find %s in %s's input list. "
457+
"Please ensure graph completely.",
458+
out->Name().c_str(), cur_op->Name().c_str(),
459+
cur_op->Name().c_str(), out->Name().c_str()));
460+
in_ops.at(out).erase(cur_op);
461+
462+
// push if in-degree is 0
463+
if (in_ops.at(out).empty()) {
464+
q.push(out);
465+
}
466+
}
467+
}
468+
469+
PADDLE_ENFORCE_EQ(
470+
sorted_ops.size(), in_ops.size(),
471+
platform::errors::InvalidArgument("Topological sorting incompletely, "
472+
"only sorted %zd op but total %zd.",
473+
sorted_ops.size(), in_ops.size()));
474+
475+
return sorted_ops;
476+
}
477+
398478
} // namespace ir
399479
} // namespace framework
400480
} // namespace paddle

paddle/fluid/framework/ir/graph_helper.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ std::vector<T *> FilterByNodeWrapper(const Graph &graph) {
8787
return ret;
8888
}
8989

90+
std::vector<ir::Node *> TopologySortGraphByDescOrder(const Graph &graph);
91+
9092
} // namespace ir
9193
} // namespace framework
9294
} // namespace paddle

0 commit comments

Comments
 (0)