Skip to content

Commit 45ed5ba

Browse files
committed
merge branch zhhsplendid:ci_test_convert_all_block PaddlePaddle#34289
1 parent 4fa0bdf commit 45ed5ba

File tree

7 files changed

+143
-44
lines changed

7 files changed

+143
-44
lines changed

paddle/fluid/framework/ir/graph.cc

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,12 @@ Graph::Graph(const ProgramDesc &program, const int64_t start_op_index,
5656
// sub_graph.
5757
std::unique_ptr<Graph> first_sub_graph = std::make_unique<Graph>(
5858
program_.Block(0), this, start_op_index, end_op_index);
59+
first_sub_graph->block_id_ = 0;
5960
sub_graphs_.push_back(std::move(first_sub_graph));
6061
for (size_t idx = 1; idx < program_.Size(); ++idx) {
6162
std::unique_ptr<Graph> sub_graph =
6263
std::make_unique<Graph>(program_.Block(idx), this);
64+
sub_graph->block_id_ = idx;
6365
sub_graphs_.push_back(std::move(sub_graph));
6466
}
6567
} else {
@@ -90,15 +92,32 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
9092
std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
9193
const BlockDesc &block, const int64_t start_op_index,
9294
const int64_t end_op_index) {
93-
std::unordered_map<std::string, VarDesc *> all_vars;
95+
std::unordered_map<std::string, std::pair<VarDesc *, int>>
96+
name_to_desc_block_id;
97+
98+
const BlockDesc *block_var_visible = &block;
99+
while (block_var_visible != nullptr) {
100+
for (auto *var : block_var_visible->AllVars()) {
101+
name_to_desc_block_id.emplace(
102+
var->Name(), std::make_pair(var, block_var_visible->ID()));
103+
}
104+
const BlockDesc *forward_block = block_var_visible->ForwardBlock();
105+
if (forward_block != nullptr) {
106+
for (auto *var : forward_block->AllVars()) {
107+
name_to_desc_block_id.emplace(var->Name(),
108+
std::make_pair(var, forward_block->ID()));
109+
}
110+
}
111+
block_var_visible = block_var_visible->ParentBlock();
112+
}
94113
// var nodes for each var name, will have multiple versions in SSA
95114
std::map<std::string, std::vector<ir::Node *>> var_nodes;
115+
std::unordered_map<std::string, VarDesc *> not_visited_vars;
96116
for (auto *var : block.AllVars()) {
97-
all_vars.emplace(var->Name(), var);
117+
not_visited_vars.emplace(var->Name(), var);
98118
}
99119

100120
int desc_order = 0;
101-
auto not_visited_vars = all_vars;
102121
auto all_ops = block.AllOps();
103122
PADDLE_ENFORCE_LE(
104123
end_op_index, all_ops.size(),
@@ -119,8 +138,9 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
119138
ir::Node *var = nullptr;
120139
if (var_nodes.find(each_var_name) != var_nodes.end()) {
121140
var = var_nodes.at(each_var_name).back();
122-
} else if (all_vars.count(each_var_name) != 0) {
123-
var = CreateVarNode(all_vars.at(each_var_name));
141+
} else if (name_to_desc_block_id.count(each_var_name) != 0) {
142+
auto desc_and_block_id = name_to_desc_block_id.at(each_var_name);
143+
var = CreateVarNode(desc_and_block_id.first, desc_and_block_id.second);
124144
var_nodes[each_var_name].push_back(var);
125145
} else {
126146
// Operation input var can be optional (dispensable). Which means
@@ -146,8 +166,9 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
146166
}
147167

148168
ir::Node *var = nullptr;
149-
if (all_vars.count(each_var_name) != 0) {
150-
var = CreateVarNode(all_vars.at(each_var_name));
169+
if (name_to_desc_block_id.count(each_var_name) != 0) {
170+
auto desc_and_block_id = name_to_desc_block_id.at(each_var_name);
171+
var = CreateVarNode(desc_and_block_id.first, desc_and_block_id.second);
151172
} else {
152173
// Operation output vars can be @EMPTY@. For example, while_grad
153174
// can have multi @EMPTY@ outputs with no VarDesc.
@@ -273,6 +294,7 @@ std::shared_ptr<Graph> Graph::Clone() {
273294
auto cloned_graph = std::make_shared<Graph>(this->program_);
274295
cloned_graph->ReleaseNodes();
275296
cloned_graph->num_node_created_ = 0;
297+
cloned_graph->block_id_ = this->block_id_;
276298
std::unordered_map<ir::Node *, ir::Node *> origin_to_cloned;
277299
for (auto *n : this->node_set_) {
278300
PADDLE_ENFORCE_NOT_NULL(n, platform::errors::InvalidArgument(
@@ -316,6 +338,7 @@ std::unique_ptr<Graph> Graph::CloneSubGraph(const size_t idx) {
316338
std::make_unique<Graph>(this->program_.Block(idx), this);
317339
cloned_sub_graph->ReleaseNodes();
318340
cloned_sub_graph->num_node_created_ = 0;
341+
cloned_sub_graph->block_id_ = idx;
319342
std::unordered_map<ir::Node *, ir::Node *> origin_to_cloned;
320343
for (auto *n : this->sub_graphs_.at(idx)->Nodes()) {
321344
PADDLE_ENFORCE_NOT_NULL(n, platform::errors::InvalidArgument(

paddle/fluid/framework/ir/graph.h

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,14 @@ class Graph {
104104
attr_dels_.clear();
105105
}
106106

107-
bool IsConstructedByPartialProgram() const { return is_partial_; }
107+
bool IsConstructedByPartialProgram() const {
108+
if (FLAGS_convert_all_blocks) {
109+
if (IsMainGraph()) {
110+
return GetSubGraph(0)->IsConstructedByPartialProgram();
111+
}
112+
}
113+
return is_partial_;
114+
}
108115

109116
bool Has(const std::string &attr_name) const {
110117
if (FLAGS_convert_all_blocks) {
@@ -210,7 +217,7 @@ class Graph {
210217
}
211218

212219
// Create a normal variable with non-null VarDesc.
213-
ir::Node *CreateVarNode(VarDesc *var_desc) {
220+
ir::Node *CreateVarNode(VarDesc *var_desc, int block_id = -1) {
214221
if (FLAGS_convert_all_blocks) {
215222
if (IsMainGraph()) {
216223
return GetSubGraph(0)->CreateVarNode(var_desc);
@@ -219,7 +226,8 @@ class Graph {
219226
PADDLE_ENFORCE_NOT_NULL(
220227
var_desc, platform::errors::InvalidArgument(
221228
"The VarDesc used to create variable node is null."));
222-
auto *x = AddNode(new ir::Node(var_desc));
229+
auto *x =
230+
AddNode(new ir::Node(var_desc, block_id == -1 ? block_id_ : block_id));
223231
x->SetId(num_node_created_++);
224232
return x;
225233
}
@@ -252,7 +260,7 @@ class Graph {
252260
const std::string name = string::Sprintf(
253261
"%s@%llu", static_cast<const char *>(ir::Node::kControlDepVarName),
254262
num_node_created_);
255-
auto *x = AddNode(new ir::Node(name, ir::Node::Type::kVariable));
263+
auto *x = AddNode(new ir::Node(name, ir::Node::Type::kVariable, block_id_));
256264
x->SetId(num_node_created_++);
257265
return x;
258266
}
@@ -265,7 +273,7 @@ class Graph {
265273
return GetSubGraph(0)->CreateEmptyNode(name, type);
266274
}
267275
}
268-
auto *x = AddNode(new ir::Node(name, type));
276+
auto *x = AddNode(new ir::Node(name, type, block_id_));
269277
x->SetId(num_node_created_++);
270278
return x;
271279
}
@@ -365,6 +373,15 @@ class Graph {
365373
return sub_graphs_.at(idx).get();
366374
}
367375

376+
int GetBlockId() const {
377+
if (FLAGS_convert_all_blocks) {
378+
if (IsMainGraph()) {
379+
return GetSubGraph(0)->block_id_;
380+
}
381+
}
382+
return block_id_;
383+
}
384+
368385
size_t SubGraphsSize() const {
369386
PADDLE_ENFORCE_EQ(
370387
this->IsMainGraph(), true,
@@ -394,6 +411,9 @@ class Graph {
394411
PADDLE_ENFORCE_EQ(
395412
this->IsMainGraph(), true,
396413
platform::errors::InvalidArgument("This graph is not main_graph"));
414+
PADDLE_ENFORCE_EQ(sub_graphs_.size(), sub_graph->block_id_,
415+
platform::errors::InvalidArgument(
416+
"sub_graph idx is not equal to block_id_"));
397417
sub_graphs_.push_back(std::move(sub_graph));
398418
}
399419

@@ -416,6 +436,8 @@ class Graph {
416436
// parts: forward graph and backward graph, which can be executed
417437
// independently.
418438
bool is_partial_{false};
439+
// The block this SubGraph belongs to.
440+
int block_id_{0};
419441
};
420442

421443
bool IsControlDepVar(const ir::Node &var);

paddle/fluid/framework/ir/graph_to_program_pass.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ void GraphToProgramPass::GraphToBlock(const Graph* graph,
116116
for (ir::Node* n : graph->Nodes()) {
117117
if (n->IsVar()) {
118118
if (n->Var() && visited_vars.count(n->Var()->Name()) == 0 &&
119-
!vars2remove.count(n->Var()->Name())) {
119+
!vars2remove.count(n->Var()->Name()) &&
120+
n->GetVarNodeBlockId() == graph->GetBlockId()) {
120121
visited_vars.insert(n->Var()->Name());
121122
block->add_vars()->MergeFrom(*n->Var()->Proto());
122123
}

paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@ using OpVariant = operators::OpVariant;
2626
class WhileOpEagerDeletionPass : public ir::Pass {
2727
protected:
2828
void ApplyImpl(ir::Graph *graph) const override {
29+
if (!graph->IsMainGraph()) {
30+
// TODO(zhhsplendid): the WhileOpEagerDeletionPass is based on old Graph,
31+
// which only applies to the main block graph. The new Eager Deletion
32+
// Technical can be added after we write new while_op based on SubGraph
33+
// instead of SubBlock
34+
return;
35+
}
2936
auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);
3037

3138
// Find all while_op and while_grad_op. In case of @to_static, graph
@@ -47,6 +54,7 @@ class WhileOpEagerDeletionPass : public ir::Pass {
4754
}
4855
}
4956
if (graph->IsConstructedByPartialProgram()) {
57+
VLOG(4) << "Is Paritial Program";
5058
PADDLE_ENFORCE_LE(
5159
target_ops.size(), 1,
5260
platform::errors::InvalidArgument(
@@ -69,8 +77,11 @@ class WhileOpEagerDeletionPass : public ir::Pass {
6977
}
7078

7179
for (auto &ops_pair : target_ops) {
80+
VLOG(4) << "Scope Idx = " << ops_pair.first;
7281
auto &while_ops = ops_pair.second.first;
82+
VLOG(4) << "while_ops.size() = " << while_ops.size();
7383
auto &while_grad_ops = ops_pair.second.second;
84+
VLOG(4) << "while_grad_ops.size() = " << while_grad_ops.size();
7485
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
7586
graph->OriginProgram(), while_ops, while_grad_ops);
7687
}

paddle/fluid/framework/ir/node.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ std::unique_ptr<Node> CreateNodeForTest(const std::string &name,
3030
}
3131

3232
std::unique_ptr<Node> CreateNodeForTest(VarDesc *var_desc) {
33-
return std::unique_ptr<Node>(new Node(var_desc));
33+
return std::unique_ptr<Node>(new Node(var_desc, 0));
3434
}
3535

3636
std::unique_ptr<Node> CreateNodeForTest(OpDesc *op_desc) {

paddle/fluid/framework/ir/node.h

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,13 @@ class Node {
138138

139139
int DescOrder() const { return desc_order_; }
140140

141+
int GetVarNodeBlockId() const {
142+
PADDLE_ENFORCE_EQ(
143+
type_ == Type::kVariable && var_desc_, true,
144+
platform::errors::InvalidArgument("Node must be type of variable."));
145+
return block_id_;
146+
}
147+
141148
const std::string ToString() const {
142149
if (IsOp()) {
143150
std::string op_str(Name());
@@ -158,27 +165,53 @@ class Node {
158165
comparator);
159166
std::stable_sort(sorted_outputs.begin(), sorted_outputs.end(),
160167
comparator);
161-
for (const auto& input : sorted_inputs) {
162-
op_str.append(input->Name());
163-
}
168+
169+
std::string out_str = "{";
170+
std::string pre_str = "";
164171
for (const auto& output : sorted_outputs) {
165-
op_str.append(output->Name());
172+
out_str.append(pre_str + output->Name());
173+
pre_str = ", ";
174+
}
175+
out_str.append("} = ");
176+
177+
std::string in_str = "(";
178+
pre_str = "";
179+
for (const auto& input : sorted_inputs) {
180+
in_str.append(pre_str + input->Name());
181+
pre_str = ", ";
166182
}
183+
in_str.append(")");
184+
op_str = out_str + op_str + in_str;
167185
} else {
168186
// A normal Op, has OpDesc, create from ProgramDesc
169-
for (const auto& input : op->InputNames()) {
170-
op_str.append(input);
171-
for (const auto& arg : op->Input(input)) {
172-
op_str.append(arg);
187+
std::string out_str = "{";
188+
std::string outer_pre_str = "";
189+
for (const auto& output : op->OutputNames()) {
190+
out_str.append(outer_pre_str + output + "=[");
191+
std::string inner_pre_str = "";
192+
for (const auto& arg : op->Output(output)) {
193+
out_str.append(inner_pre_str + arg);
194+
inner_pre_str = " ,";
173195
}
196+
outer_pre_str = ", ";
197+
out_str.append("]");
174198
}
199+
out_str.append("} = ");
175200

176-
for (const auto& output : op->OutputNames()) {
177-
op_str.append(output);
178-
for (const auto& arg : op->Output(output)) {
179-
op_str.append(arg);
201+
std::string in_str = "(";
202+
outer_pre_str = "";
203+
for (const auto& input : op->InputNames()) {
204+
in_str.append(outer_pre_str + input + "=[");
205+
std::string inner_pre_str = "";
206+
for (const auto& arg : op->Input(input)) {
207+
in_str.append(inner_pre_str + arg);
208+
inner_pre_str = " ,";
180209
}
210+
outer_pre_str = " ,";
211+
in_str.append("]");
181212
}
213+
in_str.append(")");
214+
op_str = out_str + op_str + in_str;
182215
}
183216

184217
return op_str;
@@ -203,6 +236,7 @@ class Node {
203236
int id_;
204237

205238
int desc_order_;
239+
int block_id_{-1};
206240

207241
private:
208242
// ID can only set by a Graph.
@@ -218,19 +252,21 @@ class Node {
218252
friend std::unique_ptr<Node> CreateNodeForTest(VarDesc* var_desc);
219253
friend std::unique_ptr<Node> CreateNodeForTest(OpDesc* op_desc);
220254

221-
explicit Node(const std::string& name, Type type)
255+
explicit Node(const std::string& name, Type type, int block_id = 0)
222256
: name_(name),
223257
var_desc_(nullptr),
224258
op_desc_(nullptr),
225259
type_(type),
226-
desc_order_(NO_DESC_ORDER) {}
260+
desc_order_(NO_DESC_ORDER),
261+
block_id_(block_id) {}
227262

228-
explicit Node(VarDesc* var_desc)
263+
explicit Node(VarDesc* var_desc, int block_id)
229264
: name_(var_desc->Name()),
230265
var_desc_(new VarDesc(*var_desc)),
231266
op_desc_(nullptr),
232267
type_(Type::kVariable),
233-
desc_order_(NO_DESC_ORDER) {}
268+
desc_order_(NO_DESC_ORDER),
269+
block_id_(block_id) {}
234270

235271
explicit Node(OpDesc* op_desc)
236272
: name_(op_desc->Type()),

paddle/fluid/framework/ir/node_test.cc

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414

1515
#include "paddle/fluid/framework/ir/node.h"
1616
#include "gtest/gtest.h"
17+
#include "paddle/fluid/framework/var_desc.h"
1718

1819
namespace paddle {
1920
namespace framework {
@@ -76,24 +77,29 @@ TEST(NodeTest, Basic) {
7677
}
7778

7879
TEST(NodeTest, ToString) {
80+
VarDesc var_desc("n2");
81+
OpDesc op_desc;
82+
op_desc.SetType("test_op");
83+
op_desc.SetInput("X", {"x1", "x2", "x3"});
84+
op_desc.SetOutput("Y", {"y1", "y2"});
85+
7986
std::unique_ptr<Node> n1(CreateNodeForTest("n1", Node::Type::kVariable));
87+
std::unique_ptr<Node> n2(CreateNodeForTest(&var_desc));
88+
std::unique_ptr<Node> n3(CreateNodeForTest("n3", Node::Type::kOperation));
89+
std::unique_ptr<Node> n4(CreateNodeForTest(&op_desc));
90+
8091
EXPECT_EQ(n1->ToString(), "n1");
92+
EXPECT_EQ(n2->ToString(), "n2");
8193

82-
std::unique_ptr<Node> op1(CreateNodeForTest("op1", Node::Type::kOperation));
94+
EXPECT_EQ(n3->Op(), nullptr);
95+
EXPECT_EQ(n3->ToString(), "{} = n3()");
96+
EXPECT_NE(n4->Op(), nullptr);
97+
EXPECT_EQ(n4->ToString(), "{Y=[y1 ,y2]} = test_op(X=[x1 ,x2 ,x3])");
8398

84-
std::unique_ptr<Node> n2(CreateNodeForTest("n2", Node::Type::kVariable));
85-
std::unique_ptr<Node> n3(CreateNodeForTest("n3", Node::Type::kVariable));
86-
87-
op1->inputs.emplace_back(n2.get());
88-
op1->outputs.emplace_back(n3.get());
89-
EXPECT_EQ(op1->ToString(), "op1n2n3");
90-
91-
OpDesc desc;
92-
desc.SetType("op2");
93-
desc.SetInput("X", {"arg1"});
94-
desc.SetOutput("Out", {"res1"});
95-
std::unique_ptr<Node> op2(CreateNodeForTest(&desc));
96-
EXPECT_EQ(op2->ToString(), "op2Xarg1Outres1");
99+
n3->inputs.push_back(n1.get());
100+
n3->outputs.push_back(n2.get());
101+
EXPECT_EQ(n3->Op(), nullptr);
102+
EXPECT_EQ(n3->ToString(), "{n2} = n3(n1)");
97103
}
98104

99105
} // namespace ir

0 commit comments

Comments
 (0)