Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions paddle/fluid/eager/to_static/run_program_op_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/eager/to_static/run_program_op_node.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/framework/tensor_ref_array.h"
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/pir/include/core/block.h"
#include "paddle/pir/include/core/builtin_type.h"
#include "paddle/pir/include/core/value.h"
#include "paddle/pir/include/dialect/control_flow/ir/cf_type.h"

// Filter params without grads in global block. In this case, we will
// tag its AutogradMeta with stop_gradient = True to avoid fault from
Expand Down Expand Up @@ -244,8 +247,9 @@ inline void pir_run_program_ad_func(
trace_backward, &p_autograd_x, &p_autograd_params);

// Create Middle Output for GradNode.
auto middle_size =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fm")).size();
auto middle_values =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fm"));
auto middle_size = middle_values.size();
auto output_size =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fo")).size();
auto middles = std::vector<paddle::Tensor*>();
Expand All @@ -264,8 +268,14 @@ inline void pir_run_program_ad_func(
grad_node->GetMiddle().resize(middle_size);
grad_node->GetOutputs().resize(output_size);
for (size_t i = 0; i < middle_size; ++i) {
grad_node->GetMiddle()[i] =
paddle::Tensor(std::make_shared<phi::DenseTensor>());
auto middle_value = middle_values[i];
if (middle_value.type().isa<pir::DenseTensorType>()) {
grad_node->GetMiddle()[i] =
paddle::Tensor(std::make_shared<phi::DenseTensor>());
} else if (middle_value.type().isa<pir::OutletType>()) {
grad_node->GetMiddle()[i] = paddle::Tensor(
std::make_shared<paddle::framework::VariableRefArray>());
}
middles.push_back(&grad_node->GetMiddle()[i]);
}

Expand Down
24 changes: 23 additions & 1 deletion paddle/fluid/eager/to_static/run_program_op_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "paddle/fluid/eager/tensor_wrapper.h"
#include "paddle/fluid/framework/executor_cache.h"
#include "paddle/fluid/framework/new_executor/interpretercore.h"
#include "paddle/fluid/framework/tensor_ref_array.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/ir_adaptor/translator/program_translator.h"
#include "paddle/fluid/operators/run_program_op.h"
Expand Down Expand Up @@ -120,10 +121,20 @@ static void CheckOutputVarStatus(const paddle::framework::Variable &src_var,
"RunProgram(Grad)Op's internal scope holds "
"wrong type. Expect type is SelectedRows",
name));
} else if (paddle::framework::VariableRefArray::classof(
dst_tensor.impl().get())) {
auto &src_tensor = src_var.Get<paddle::framework::VariableRefArray>();
PADDLE_ENFORCE_EQ(paddle::framework::VariableRefArray::classof(&src_tensor),
true,
paddle::platform::errors::InvalidArgument(
"The output tensor %s get from "
"RunProgram(Grad)Op's internal scope holds "
"wrong type. Expect type is VariableRefArray",
name));
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"The RunProgram(Grad)Op only support output "
"variable of type LoDTensor or SelectedRows",
"variable of type DenseTensor, SelectedRows or VariableRefArray",
name));
}
}
Expand Down Expand Up @@ -320,6 +331,17 @@ static void ShareTensorsFromScopeByValue(
auto *dst_tensor = const_cast<phi::SelectedRows *>(
dynamic_cast<const phi::SelectedRows *>(tensors[i]->impl().get()));
*dst_tensor = src_tensor;
} else if (var->IsType<paddle::framework::VariableRefArray>()) {
auto &src_tensor = var->Get<paddle::framework::VariableRefArray>();
auto *dst_tensor = const_cast<paddle::framework::VariableRefArray *>(
dynamic_cast<const paddle::framework::VariableRefArray *>(
tensors[i]->impl().get()));
*dst_tensor = src_tensor;
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"The RunProgram(Grad)Op only support output "
"variable of type DenseTensor, SelectedRows or VariableRefArray",
name));
}
}
}
Expand Down
13 changes: 8 additions & 5 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -970,11 +970,14 @@ AnalysisMiddleVariable(const Program &program,
program.block(),
forward_range,
[&middle_values, &backward_inputs, &x_or_param](Operation *op) {
for (auto &t : op->results()) {
auto v = Value(t.Value::impl());
if (backward_inputs.count(v) && !x_or_param.count(v))
middle_values.push_back(v);
}
pir::Walk(op, [&](Operation *inner_op) {
for (auto &t : inner_op->results()) {
auto v = Value(t.Value::impl());
if (backward_inputs.count(v) && !x_or_param.count(v)) {
middle_values.push_back(v);
}
}
});
});
return std::make_pair(middle_values, backward_inputs);
}
Expand Down
3 changes: 1 addition & 2 deletions test/dygraph_to_static/test_ifelse.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
enable_to_static_guard,
test_ast_only,
test_legacy_and_pt_and_pir,
test_legacy_only,
test_pir_only,
)
from ifelse_simple_func import (
Expand Down Expand Up @@ -338,7 +337,7 @@ def _run(self, to_static=False):
ret = net(x_v)
return ret.numpy()

@test_legacy_only
@test_legacy_and_pt_and_pir
def test_ast_to_func(self):
self.assertTrue((self._run_dygraph() == self._run_static()).all())

Expand Down