@@ -82,7 +82,7 @@ TEST(ProgramDesc, GetInputsOutputsInBlock) {
8282 while_op->SetType (" while" );
8383 while_op->SetAttr (" sub_block" , sub_blocks[0 ]);
8484
85- auto * while_x = sub_blocks[ 0 ] ->Var (" While_X" );
85+ auto * while_x = global_block ->Var (" While_X" );
8686 while_x->SetType (proto::VarType::LOD_TENSOR);
8787 while_x->SetLoDLevel (0 );
8888 while_x->SetDataType (proto::VarType::FP32);
@@ -91,13 +91,13 @@ TEST(ProgramDesc, GetInputsOutputsInBlock) {
9191 while_op->SetInput (" kX" , {while_x->Name ()});
9292 while_op->SetInput (" kCondition" , {less_than_1_out->Name ()});
9393
94- auto * while_out = sub_blocks[ 0 ] ->Var (" While_Out" );
94+ auto * while_out = global_block ->Var (" While_Out" );
9595 while_out->SetType (proto::VarType::LOD_TENSOR);
9696 while_out->SetLoDLevel (0 );
9797 while_out->SetDataType (proto::VarType::FP32);
9898 while_out->SetShape ({1 });
9999
100- auto * steps = sub_blocks[ 0 ] ->Var (" StepScopes" );
100+ auto * steps = global_block ->Var (" StepScopes" );
101101
102102 while_op->SetOutput (" kOutputs" , {while_out->Name ()});
103103 while_op->SetOutput (" kStepScopes" , {steps->Name ()});
@@ -148,7 +148,7 @@ TEST(ProgramDesc, GetInputsOutputsInBlock) {
148148 cond_op->SetType (" conditional_block" );
149149 cond_op->SetAttr (" sub_block" , sub_blocks[1 ]);
150150
151- auto * cond_x = global_block ->Var (" Cond_X" );
151+ auto * cond_x = sub_blocks[ 0 ] ->Var (" Cond_X" );
152152 cond_x->SetType (proto::VarType::LOD_TENSOR);
153153 cond_x->SetLoDLevel (0 );
154154 cond_x->SetDataType (proto::VarType::FP32);
@@ -157,13 +157,13 @@ TEST(ProgramDesc, GetInputsOutputsInBlock) {
157157 cond_op->SetInput (" kInputs" , {cond_x->Name ()});
158158 cond_op->SetInput (" kCondition" , {less_than_2_out->Name ()});
159159
160- auto * cond_out = global_block ->Var (" Out5 " );
160+ auto * cond_out = sub_blocks[ 0 ] ->Var (" Cond_Out " );
161161 cond_out->SetType (proto::VarType::LOD_TENSOR);
162162 cond_out->SetLoDLevel (0 );
163163 cond_out->SetDataType (proto::VarType::FP32);
164164 cond_out->SetShape ({1 });
165165
166- auto * scope = global_block ->Var (" Scope" );
166+ auto * scope = sub_blocks[ 0 ] ->Var (" Scope" );
167167 scope->SetType (proto::VarType::STEP_SCOPES);
168168
169169 cond_op->SetOutput (" kOutputs" , {cond_out->Name ()});
@@ -200,8 +200,8 @@ TEST(ProgramDesc, GetInputsOutputsInBlock) {
200200 VLOG (3 ) << " inner_inputs().size():" << inner_inputs.size ();
201201 VLOG (3 ) << " inner_outputs().size():" << inner_outputs.size ();
202202
203- ASSERT_EQ (5UL , inner_inputs.size ());
204- ASSERT_EQ (4UL , inner_outputs.size ());
203+ ASSERT_EQ (4UL , inner_inputs.size ());
204+ ASSERT_EQ (2UL , inner_outputs.size ());
205205
206206 VLOG (3 ) << " Before AddDependency, while op's input kX size:"
207207 << while_op->Input (" kX" ).size ();
@@ -214,8 +214,16 @@ TEST(ProgramDesc, GetInputsOutputsInBlock) {
214214 VLOG (3 ) << " After AddDependency, while op's output kOutPuts size:"
215215 << while_op->Output (" kOutputs" ).size ();
216216
217- ASSERT_EQ (8UL , while_op->Input (" kX" ).size ());
218- ASSERT_EQ (6UL , while_op->Output (" kOutputs" ).size ());
217+ ASSERT_EQ (7UL , while_op->Input (" kX" ).size ());
218+ ASSERT_EQ (4UL , while_op->Output (" kOutputs" ).size ());
219+
220+ // auto var_input_vec = {"Less_than_1_Out", "While_X", "Less_than_2_X",
221+ // "Less_than_2_Y", "Less_than_2_out", "Mul_3_X", "Mul_3_Y"};
222+ // auto var_output_vec = {"While_Out", "Mul_2_Out", "Less_than_2_out",
223+ // "Mul_3_Out"};
224+
225+ // ASSERT_EQ(var_input_vec, while_op->Input("kX"));
226+ // ASSERT_EQ(var_output_vec, while_op->Output("kOutputs"));
219227}
220228} // namespace framework
221229} // namespace paddle
0 commit comments