Skip to content

Commit f72d000

Browse files
committed
Add cond op in unittest.
1 parent 7426b29 commit f72d000

File tree

3 files changed

+35
-5
lines changed

3 files changed

+35
-5
lines changed

paddle/fluid/framework/program_processing.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -29,12 +29,12 @@ bool ProgramProcessor::IsControlFlowBlock(ProgramDesc *program,
2929
inner_inputs.end())
3030
inner_inputs.push_back(iname);
3131
}
32-
BlockDesc parent_block = program->Block(current_block.Parent());
3332
for (auto in_var_name : inner_inputs) {
34-
VarDesc *parent_block_var = parent_block.FindVarRecursive(in_var_name);
33+
VarDesc *parent_block_var =
34+
program->Block(current_block.Parent()).FindVarRecursive(in_var_name);
3535
VarDesc *current_block_var;
3636
if (current_block.HasVar(in_var_name)) {
37-
current_block_var = current_block.Var(in_var_name);
37+
current_block_var = current_block.FindVar(in_var_name);
3838
}
3939
if (parent_block_var == nullptr && current_block_var)
4040
removed_inner_inputs.push_back(in_var_name);
@@ -51,6 +51,8 @@ void ProgramProcessor::SSAProgram(ProgramDesc *program) {
5151
if (IsControlFlowBlock(program, program->Block(i))) {
5252
VLOG(3) << "Block ID with whlie op:" << program->Block(i).ID();
5353
// ssa_processing(program, cur_block);
54+
} else {
55+
VLOG(3) << "Not a ControlFlow Block";
5456
}
5557
}
5658
}

paddle/fluid/framework/program_processing.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class ProgramProcessor {
3232

3333
// explicit ProgramProcessor(const ProgramDesc &program);
3434

35-
bool IsControlFlowBlock(ProgramDesc *program, BlockDesc *current_block);
35+
bool IsControlFlowBlock(ProgramDesc *program, const BlockDesc &current_block);
3636

3737
void SSAProgram(ProgramDesc *program);
3838
};

paddle/fluid/framework/program_processing_test.cc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,34 @@ TEST(ProgramDesc, SSAprogram) {
6161
VLOG(3) << "sub_blocks_Parent:" << sub_blocks[0]->Parent();
6262
op->SetAttr("sub_blocks", sub_blocks);
6363

64+
// building cond op such as less_than
65+
BlockDesc* parent_block = program.MutableBlock(sub_blocks[0]->Parent());
66+
op = parent_block->AppendOp();
67+
op->SetType("less_than");
68+
auto* x = parent_block->Var("X");
69+
x->SetType(proto::VarType::LOD_TENSOR);
70+
x->SetLoDLevel(0);
71+
x->SetDataType(proto::VarType::FP32);
72+
x->SetShape({1});
73+
74+
auto* y = parent_block->Var("Y");
75+
y->SetType(proto::VarType::LOD_TENSOR);
76+
y->SetLoDLevel(0);
77+
y->SetDataType(proto::VarType::FP32);
78+
y->SetShape({1});
79+
80+
op->SetInput("X", {x->Name()});
81+
op->SetInput("Y", {y->Name()});
82+
83+
auto* out = parent_block->Var("Out");
84+
out->SetType(proto::VarType::BOOL);
85+
op->SetOutput("Out", {out->Name()});
86+
87+
// building while op
88+
// BlockDesc* parent_block = program.MutableBlock(sub_blocks[0]->Parent());
89+
op = parent_block->AppendOp();
90+
op->SetType("while");
91+
6492
ProgramProcessor program_processor;
6593
program_processor.SSAProgram(&program);
6694
}

0 commit comments

Comments
 (0)