Skip to content

Commit c1e113a

Browse files
committed
add some warning code and fix some wrong logic.
1 parent 5adc9c3 commit c1e113a

File tree

2 files changed

+41
-31
lines changed

2 files changed

+41
-31
lines changed

paddle/fluid/framework/program_processing.cc

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ void ProgramProcessor::GetInputsOutputsInBlock(
2929
// Step1: update inner_inputs and inner_outputs
3030
// NOTE: Here assumes that all variables are input or output of Ops,
3131

32-
std::set<std::string> removed_inner_inputs;
33-
3432
for (OpDesc *op : current_block.AllOps()) {
3533
for (auto iname : op->InputNames()) {
3634
for (auto in_var_name : op->Input(iname)) {
@@ -52,26 +50,27 @@ void ProgramProcessor::GetInputsOutputsInBlock(
5250

5351
// Step2: Remove variable created in current control flow block.
5452
BlockDesc *parent_block = current_block.ParentBlock();
55-
VarDesc *current_block_var;
5653

5754
if (parent_block) {
58-
for (auto in_var_name : *inner_inputs) {
59-
VLOG(3) << "recursively find var:" << in_var_name;
60-
VarDesc *parent_block_var = parent_block->FindVarRecursive(in_var_name);
61-
if (current_block.HasVar(in_var_name))
62-
current_block_var = current_block.FindVar(in_var_name);
63-
if (parent_block_var == nullptr && current_block_var &&
64-
current_block_var->GetType() == proto::VarType::LOD_TENSOR) {
65-
VLOG(3) << "remove inner var:" << in_var_name;
66-
removed_inner_inputs.insert(in_var_name);
55+
for (auto iter = inner_inputs->begin(); iter != inner_inputs->end();) {
56+
const std::string &in_var_name = *iter;
57+
if (current_block.HasVar(in_var_name)) {
58+
VLOG(3) << "remove inner intput var:" << in_var_name;
59+
iter = inner_inputs->erase(iter);
60+
} else {
61+
++iter;
62+
}
63+
}
64+
65+
for (auto iter = inner_outputs->begin(); iter != inner_outputs->end();) {
66+
const std::string &out_var_name = *iter;
67+
if (current_block.HasVar(out_var_name)) {
68+
VLOG(3) << "remove inner output var:" << out_var_name;
69+
iter = inner_outputs->erase(iter);
70+
} else {
71+
++iter;
6772
}
6873
}
69-
std::set<std::string> inner_inputs_;
70-
std::set_difference(inner_inputs->begin(), inner_inputs->end(),
71-
removed_inner_inputs.begin(),
72-
removed_inner_inputs.end(),
73-
inserter(inner_inputs_, inner_inputs_.begin()));
74-
inner_inputs->swap(inner_inputs_);
7574
}
7675
}
7776

@@ -95,13 +94,16 @@ void ProgramProcessor::AddDepToBlockOp(const BlockDesc &block) {
9594
auto *op_inputs = op->MutableInputs();
9695
std::vector<std::string> *op_input_var_vec;
9796
VLOG(3) << "op_type:>>>>>>" << op_type;
98-
if (op_type.compare("while") == 0)
97+
if (op_type.compare("while") == 0) {
9998
op_input_var_vec = &((*op_inputs)["kX"]);
100-
else if (op_type.compare("conditional_block") == 0)
99+
} else if (op_type.compare("conditional_block") == 0) {
101100
op_input_var_vec = &((*op_inputs)["kInputs"]);
102-
else
101+
} else {
103102
// Only support while_op and conditinal_block_op now
103+
throw std::invalid_argument(
104+
"Currently, only support while_op and conditinal_block_op.\n");
104105
continue;
106+
}
105107

106108
for (auto sub_input : sub_inputs) {
107109
if (std::find(op_input_var_vec->begin(), op_input_var_vec->end(),

paddle/fluid/framework/program_processing_test.cc

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ TEST(ProgramDesc, GetInputsOutputsInBlock) {
8282
while_op->SetType("while");
8383
while_op->SetAttr("sub_block", sub_blocks[0]);
8484

85-
auto* while_x = sub_blocks[0]->Var("While_X");
85+
auto* while_x = global_block->Var("While_X");
8686
while_x->SetType(proto::VarType::LOD_TENSOR);
8787
while_x->SetLoDLevel(0);
8888
while_x->SetDataType(proto::VarType::FP32);
@@ -91,13 +91,13 @@ TEST(ProgramDesc, GetInputsOutputsInBlock) {
9191
while_op->SetInput("kX", {while_x->Name()});
9292
while_op->SetInput("kCondition", {less_than_1_out->Name()});
9393

94-
auto* while_out = sub_blocks[0]->Var("While_Out");
94+
auto* while_out = global_block->Var("While_Out");
9595
while_out->SetType(proto::VarType::LOD_TENSOR);
9696
while_out->SetLoDLevel(0);
9797
while_out->SetDataType(proto::VarType::FP32);
9898
while_out->SetShape({1});
9999

100-
auto* steps = sub_blocks[0]->Var("StepScopes");
100+
auto* steps = global_block->Var("StepScopes");
101101

102102
while_op->SetOutput("kOutputs", {while_out->Name()});
103103
while_op->SetOutput("kStepScopes", {steps->Name()});
@@ -148,7 +148,7 @@ TEST(ProgramDesc, GetInputsOutputsInBlock) {
148148
cond_op->SetType("conditional_block");
149149
cond_op->SetAttr("sub_block", sub_blocks[1]);
150150

151-
auto* cond_x = global_block->Var("Cond_X");
151+
auto* cond_x = sub_blocks[0]->Var("Cond_X");
152152
cond_x->SetType(proto::VarType::LOD_TENSOR);
153153
cond_x->SetLoDLevel(0);
154154
cond_x->SetDataType(proto::VarType::FP32);
@@ -157,13 +157,13 @@ TEST(ProgramDesc, GetInputsOutputsInBlock) {
157157
cond_op->SetInput("kInputs", {cond_x->Name()});
158158
cond_op->SetInput("kCondition", {less_than_2_out->Name()});
159159

160-
auto* cond_out = global_block->Var("Out5");
160+
auto* cond_out = sub_blocks[0]->Var("Cond_Out");
161161
cond_out->SetType(proto::VarType::LOD_TENSOR);
162162
cond_out->SetLoDLevel(0);
163163
cond_out->SetDataType(proto::VarType::FP32);
164164
cond_out->SetShape({1});
165165

166-
auto* scope = global_block->Var("Scope");
166+
auto* scope = sub_blocks[0]->Var("Scope");
167167
scope->SetType(proto::VarType::STEP_SCOPES);
168168

169169
cond_op->SetOutput("kOutputs", {cond_out->Name()});
@@ -200,8 +200,8 @@ TEST(ProgramDesc, GetInputsOutputsInBlock) {
200200
VLOG(3) << "inner_inputs().size():" << inner_inputs.size();
201201
VLOG(3) << "inner_outputs().size():" << inner_outputs.size();
202202

203-
ASSERT_EQ(5UL, inner_inputs.size());
204-
ASSERT_EQ(4UL, inner_outputs.size());
203+
ASSERT_EQ(4UL, inner_inputs.size());
204+
ASSERT_EQ(2UL, inner_outputs.size());
205205

206206
VLOG(3) << "Before AddDependency, while op's input kX size:"
207207
<< while_op->Input("kX").size();
@@ -214,8 +214,16 @@ TEST(ProgramDesc, GetInputsOutputsInBlock) {
214214
VLOG(3) << "After AddDependency, while op's output kOutPuts size:"
215215
<< while_op->Output("kOutputs").size();
216216

217-
ASSERT_EQ(8UL, while_op->Input("kX").size());
218-
ASSERT_EQ(6UL, while_op->Output("kOutputs").size());
217+
ASSERT_EQ(7UL, while_op->Input("kX").size());
218+
ASSERT_EQ(4UL, while_op->Output("kOutputs").size());
219+
220+
// auto var_input_vec = {"Less_than_1_Out", "While_X", "Less_than_2_X",
221+
// "Less_than_2_Y", "Less_than_2_out", "Mul_3_X", "Mul_3_Y"};
222+
// auto var_output_vec = {"While_Out", "Mul_2_Out", "Less_than_2_out",
223+
// "Mul_3_Out"};
224+
225+
// ASSERT_EQ(var_input_vec, while_op->Input("kX"));
226+
// ASSERT_EQ(var_output_vec, while_op->Output("kOutputs"));
219227
}
220228
} // namespace framework
221229
} // namespace paddle

0 commit comments

Comments
 (0)