diff --git a/paddle/fluid/eager/to_static/run_program_op_func.h b/paddle/fluid/eager/to_static/run_program_op_func.h
index f6b8e21cd8b17f..c767ad0b6106c7 100644
--- a/paddle/fluid/eager/to_static/run_program_op_func.h
+++ b/paddle/fluid/eager/to_static/run_program_op_func.h
@@ -20,9 +20,12 @@
 #include "paddle/fluid/eager/eager_tensor.h"
 #include "paddle/fluid/eager/to_static/run_program_op_node.h"
 #include "paddle/fluid/eager/utils.h"
+#include "paddle/fluid/framework/tensor_ref_array.h"
 #include "paddle/fluid/memory/allocation/allocator.h"
 #include "paddle/pir/include/core/block.h"
+#include "paddle/pir/include/core/builtin_type.h"
 #include "paddle/pir/include/core/value.h"
+#include "paddle/pir/include/dialect/control_flow/ir/cf_type.h"
 
 // Filter params without grads in global block. In this case, we will
 // tag its AutogradMeta with stop_gradient = True to avoid fault from
@@ -244,8 +247,9 @@ inline void pir_run_program_ad_func(
       trace_backward, &p_autograd_x, &p_autograd_params);
 
   // Create Middle Output for GradNode.
-  auto middle_size =
-      PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fm")).size();
+  auto middle_values =
+      PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fm"));
+  auto middle_size = middle_values.size();
   auto output_size =
       PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fo")).size();
   auto middles = std::vector<paddle::Tensor*>();
@@ -264,8 +268,14 @@ inline void pir_run_program_ad_func(
     grad_node->GetMiddle().resize(middle_size);
     grad_node->GetOutputs().resize(output_size);
     for (size_t i = 0; i < middle_size; ++i) {
-      grad_node->GetMiddle()[i] =
-          paddle::Tensor(std::make_shared<phi::DenseTensor>());
+      auto middle_value = middle_values[i];
+      if (middle_value.type().isa<pir::DenseTensorType>()) {
+        grad_node->GetMiddle()[i] =
+            paddle::Tensor(std::make_shared<phi::DenseTensor>());
+      } else if (middle_value.type().isa<pir::StackType>()) {
+        grad_node->GetMiddle()[i] = paddle::Tensor(
+            std::make_shared<paddle::framework::VariableRefArray>());
+      }
       middles.push_back(&grad_node->GetMiddle()[i]);
     }
 
diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h
index fdebfbb1e3771f..da04f129c01aa7 100644
--- a/paddle/fluid/eager/to_static/run_program_op_node.h
+++ b/paddle/fluid/eager/to_static/run_program_op_node.h
@@ -19,6 +19,7 @@
 #include "paddle/fluid/eager/tensor_wrapper.h"
 #include "paddle/fluid/framework/executor_cache.h"
 #include "paddle/fluid/framework/new_executor/interpretercore.h"
+#include "paddle/fluid/framework/tensor_ref_array.h"
 #include "paddle/fluid/framework/variable_helper.h"
 #include "paddle/fluid/ir_adaptor/translator/program_translator.h"
 #include "paddle/fluid/operators/run_program_op.h"
@@ -120,10 +121,20 @@ static void CheckOutputVarStatus(const paddle::framework::Variable &src_var,
                           "RunProgram(Grad)Op's internal scope holds "
                           "wrong type. Expect type is SelectedRows",
                           name));
+  } else if (paddle::framework::VariableRefArray::classof(
+                 dst_tensor.impl().get())) {
+    auto &src_tensor = src_var.Get<paddle::framework::VariableRefArray>();
+    PADDLE_ENFORCE_EQ(paddle::framework::VariableRefArray::classof(&src_tensor),
+                      true,
+                      paddle::platform::errors::InvalidArgument(
+                          "The output tensor %s get from "
+                          "RunProgram(Grad)Op's internal scope holds "
+                          "wrong type. Expect type is VariableRefArray",
+                          name));
   } else {
     PADDLE_THROW(paddle::platform::errors::InvalidArgument(
         "The RunProgram(Grad)Op only support output "
-        "variable of type LoDTensor or SelectedRows",
+        "variable of type DenseTensor, SelectedRows or VariableRefArray",
         name));
   }
 }
@@ -320,6 +331,17 @@ static void ShareTensorsFromScopeByValue(
       auto *dst_tensor = const_cast<phi::SelectedRows *>(
           dynamic_cast<const phi::SelectedRows *>(tensors[i]->impl().get()));
       *dst_tensor = src_tensor;
+    } else if (var->IsType<paddle::framework::VariableRefArray>()) {
+      auto &src_tensor = var->Get<paddle::framework::VariableRefArray>();
+      auto *dst_tensor = const_cast<paddle::framework::VariableRefArray *>(
+          dynamic_cast<const paddle::framework::VariableRefArray *>(
+              tensors[i]->impl().get()));
+      *dst_tensor = src_tensor;
+    } else {
+      PADDLE_THROW(paddle::platform::errors::InvalidArgument(
+          "The RunProgram(Grad)Op only support output "
+          "variable of type DenseTensor, SelectedRows or VariableRefArray",
+          name));
     }
   }
 }
diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc
index 35c19c8f00c76e..a24afc3585a717 100644
--- a/paddle/fluid/pybind/pir.cc
+++ b/paddle/fluid/pybind/pir.cc
@@ -970,11 +970,14 @@ AnalysisMiddleVariable(const Program &program,
       program.block(),
       forward_range,
       [&middle_values, &backward_inputs, &x_or_param](Operation *op) {
-        for (auto &t : op->results()) {
-          auto v = Value(t.Value::impl());
-          if (backward_inputs.count(v) && !x_or_param.count(v))
-            middle_values.push_back(v);
-        }
+        pir::Walk(op, [&](Operation *inner_op) {
+          for (auto &t : inner_op->results()) {
+            auto v = Value(t.Value::impl());
+            if (backward_inputs.count(v) && !x_or_param.count(v)) {
+              middle_values.push_back(v);
+            }
+          }
+        });
       });
   return std::make_pair(middle_values, backward_inputs);
 }
diff --git a/test/dygraph_to_static/test_ifelse.py b/test/dygraph_to_static/test_ifelse.py
index a05f3d07510e9c..fef4c48d495125 100644
--- a/test/dygraph_to_static/test_ifelse.py
+++ b/test/dygraph_to_static/test_ifelse.py
@@ -23,7 +23,6 @@
     enable_to_static_guard,
     test_ast_only,
     test_legacy_and_pt_and_pir,
-    test_legacy_only,
     test_pir_only,
 )
 from ifelse_simple_func import (
@@ -338,7 +337,7 @@ def _run(self, to_static=False):
         ret = net(x_v)
         return ret.numpy()
 
-    @test_legacy_only
+    @test_legacy_and_pt_and_pir
     def test_ast_to_func(self):
         self.assertTrue((self._run_dygraph() == self._run_static()).all())
 