From 92b42f3dc15c51eca4dd890e8d1139b0ff00a86f Mon Sep 17 00:00:00 2001 From: huangjiyi <947613776@qq.com> Date: Wed, 7 Feb 2024 16:31:21 +0800 Subject: [PATCH] fix_test_write_python_container --- .../control_flow/while_instruction.cc | 22 +++++++++++-------- .../test_write_python_container.py | 1 + 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.cc b/paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.cc index 54e77262788b53..c30b5561810f16 100644 --- a/paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.cc @@ -205,20 +205,11 @@ void WhileInstruction::ShareDatasToOutputs() { auto& out_var_name = body_outputs_[i + 1]; auto* out_var = body_inter_->local_scope()->GetVar(out_var_name); VLOG(6) << "share data from " << out_var_name << " -> " << i << " output"; - if (out_var->IsType()) { outputs_[i]->GetMutable()->ShareDataWith( out_var->Get()); VLOG(6) << "share data from " << out_var_name << "[" << out_var << "]" << " -> " << i << " output[" << outputs_[i] << "]"; - - // NOTE(zhangbo): Delete the input of the yield operator, except for the - // external vars of the block. - if (external_input_names_.count(out_var_name) == 0) { - VLOG(6) << "clear internel input " << out_var_name; - out_var->GetMutable()->clear(); - } - } else if (out_var->IsType()) { const auto& inner_array = out_var->Get(); auto* output_array = outputs_[i]->GetMutable(); @@ -230,6 +221,19 @@ void WhileInstruction::ShareDatasToOutputs() { VLOG(6) << "done"; } + + for (size_t i = 0; i < outputs_.size(); ++i) { + auto& out_var_name = body_outputs_[i + 1]; + auto* out_var = body_inter_->local_scope()->GetVar(out_var_name); + if (out_var->IsType()) { + // NOTE(zhangbo): Delete the input of the yield operator, except for the + // external vars of the block. + if (external_input_names_.count(out_var_name) == 0) { + VLOG(6) << "clear internel input " << out_var_name; + out_var->GetMutable()->clear(); + } + } + } } void WhileInstruction::Run() { diff --git a/test/dygraph_to_static/test_write_python_container.py b/test/dygraph_to_static/test_write_python_container.py index c7960e1cc87ca2..277372adca8cd5 100644 --- a/test/dygraph_to_static/test_write_python_container.py +++ b/test/dygraph_to_static/test_write_python_container.py @@ -127,6 +127,7 @@ def test_write_container_sot(self): self.assertEqual(out_static, out_dygraph) @test_ast_only + @test_legacy_and_pt_and_pir def test_write_container(self): func_static = paddle.jit.to_static(self.func) input = paddle.to_tensor([1, 2, 3])