54 changes: 30 additions & 24 deletions paddle/fluid/eager/to_static/run_program_op_node.h
@@ -91,39 +91,45 @@ static bool IsVariableRefArray(const Tensor &tensor) {
 
 static auto GetNameFromValue(const ::pir::Block *block,
                              const std::vector<::pir::Value> &values,
-                             bool is_input) {
+                             bool allow_input,
+                             bool allow_output) {
+  PADDLE_ENFORCE_EQ(
+      allow_input || allow_output,
+      true,
+      paddle::platform::errors::InvalidArgument(
+          "GetNameFromValue should allow input or output at least one."));
   // we use name here, later value is used directly.
   std::unordered_map<::pir::Value, std::string> value2name;
-  if (is_input) {
+  if (allow_input) {
     for (auto &kwarg : block->kwargs()) {
       value2name[kwarg.second] = kwarg.first;
     }
   }
   for (auto &op : *block) {
     std::string name;
-    if (is_input && op.name() == "pd_op.data") {
+    if (allow_input && op.name() == "pd_op.data") {
       name =
           op.attributes().at("name").dyn_cast<pir::StrAttribute>().AsString();
       value2name[op.results()[0].Value::impl()] = name;
-    } else if (!is_input && op.name() == "builtin.set_parameter") {
+    } else if (allow_output && op.name() == "builtin.set_parameter") {
       name = op.attributes()
                  .at("parameter_name")
                  .dyn_cast<pir::StrAttribute>()
                  .AsString();
       value2name[op.operand(0).source()] = name;
-    } else if (!is_input && op.name() == "builtin.shadow_output") {
+    } else if (allow_output && op.name() == "builtin.shadow_output") {
       name = op.attributes()
                  .at("output_name")
                  .dyn_cast<pir::StrAttribute>()
                  .AsString();
       value2name[op.operand(0).source()] = name;
-    } else if (is_input && op.name() == "builtin.parameter") {
+    } else if (allow_input && op.name() == "builtin.parameter") {
       name = op.attributes()
                  .at("parameter_name")
                  .dyn_cast<pir::StrAttribute>()
                  .AsString();
       value2name[op.result(0).Value::impl()] = name;
-    } else if (is_input && op.name() == "builtin.constant") {
+    } else if (allow_input && op.name() == "builtin.constant") {
       if (op.isa<pir::ConstantTensorOp>()) {
         name = op.dyn_cast<pir::ConstantTensorOp>().tensor_name();
         value2name[op.result(0).Value::impl()] = name;
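
This hunk splits the single `is_input` flag into independent `allow_input` and `allow_output` flags, so a call site can match input-defining ops, output-binding ops, or both. A minimal standalone sketch of the lookup semantics, using toy types in place of Paddle's `::pir` API (`ToyOp` and the integer value IDs are illustrative assumptions, not the real types):

```cpp
// Toy model of the two-flag name lookup: ops that *define* a named value
// are matched when allow_input is set; ops that *bind* a value to an
// output name are matched when allow_output is set.
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

struct ToyOp {
  std::string kind;   // stand-in for op.name(), e.g. "pd_op.data"
  std::string name;   // the variable name carried by the op's attribute
  int value;          // stand-in for ::pir::Value
};

std::unordered_map<int, std::string> GetNameFromValue(
    const std::vector<ToyOp>& block, bool allow_input, bool allow_output) {
  assert(allow_input || allow_output);  // mirrors the new PADDLE_ENFORCE_EQ
  std::unordered_map<int, std::string> value2name;
  for (const auto& op : block) {
    const bool defines = op.kind == "pd_op.data" ||
                         op.kind == "builtin.parameter" ||
                         op.kind == "builtin.constant";
    const bool binds = op.kind == "builtin.set_parameter" ||
                       op.kind == "builtin.shadow_output";
    if ((allow_input && defines) || (allow_output && binds)) {
      value2name[op.value] = op.name;
    }
  }
  return value2name;
}

int main() {
  std::vector<ToyOp> block = {{"pd_op.data", "x", 0},
                              {"builtin.shadow_output", "out", 1}};
  auto inputs_only = GetNameFromValue(block, true, false);   // {0: "x"}
  auto outputs_only = GetNameFromValue(block, false, true);  // {1: "out"}
  assert(inputs_only.at(0) == "x" && outputs_only.at(1) == "out");
}
```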
@@ -248,12 +254,7 @@ static void ShareTensorsIntoScopeByValue(
     const std::vector<Tensor> &tensors,
     const std::vector<::pir::Value> &values,
     paddle::framework::Scope *scope) {
-  auto names = GetNameFromValue(block, values, true);
-  if (VLOG_IS_ON(4)) {
-    for (auto &s : names) {
-      VLOG(4) << "ShareTensorIntoScopeByValue name: " << s;
-    }
-  }
+  auto names = GetNameFromValue(block, values, true, false);
   ShareTensorsIntoScopeWithName(tensors, names, scope);
 }
 
@@ -262,11 +263,16 @@ static void ShareTensorsFromScopeByValue(
     const std::vector<Tensor *> &tensors,
     const std::vector<::pir::Value> &values,
     paddle::framework::Scope *scope) {
-  auto names = GetNameFromValue(block, values, false);
+  // NOTE(SigureMo): If the program has an inplace chain connecting
+  // an input value to an output value, the output value will be
+  // replaced with the input value, so we set the `allow_input` to
+  // `true` in `GetNameFromValue`
+  auto names = GetNameFromValue(block, values, true, true);
   for (size_t i = 0; i < tensors.size(); ++i) {
     auto &name = names[i];
     auto &value = values[i];
-    VLOG(2) << "share " << name << " from scope";
+    VLOG(4) << "Share Tensor From Scope: " << name;
 
     if (value.impl() == nullptr) {
       // skip stop_gradient.
       continue;
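
The NOTE in this hunk explains why the output-side lookup now enables `allow_input` as well: after an inplace chain, an output value is literally the input value, so no output-binding op names it. Reusing the toy sketch from above, this usage example shows the miss and the fix (the single-op block is an assumed toy setup):

```cpp
int main() {
  // Inplace chain: the output reuses the input's value, so the block
  // contains only the input-defining op for value 0.
  std::vector<ToyOp> block = {{"pd_op.data", "x", 0}};
  auto outputs_only = GetNameFromValue(block, /*allow_input=*/false,
                                       /*allow_output=*/true);
  assert(outputs_only.count(0) == 0);  // output-only lookup misses value 0
  auto both = GetNameFromValue(block, /*allow_input=*/true,
                               /*allow_output=*/true);
  assert(both.at(0) == "x");           // resolved via the input-side rule
}
```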
@@ -524,20 +530,20 @@ inline void PirRunProgramAPI(
   // *backward_program);
 
   // update interpretercore skip_gc_var
-  auto skip_names =
-      details::GetNameFromValue(forward_global_block, middle_values, false);
+  auto skip_names = details::GetNameFromValue(
+      forward_global_block, middle_values, false, true);
   auto skip_names_set =
       std::set<std::string>(skip_names.begin(), skip_names.end());
   auto no_need_buffer_values = PADDLE_GET_CONST(std::vector<::pir::Value>,
                                                 attrs.at("no_need_buffers"));
   auto no_need_buffer_names = details::GetNameFromValue(
-      forward_global_block, no_need_buffer_values, false);
+      forward_global_block, no_need_buffer_values, false, true);
   for (auto &name : no_need_buffer_names) {
     VLOG(4) << "Find no need buffer vars with name:" << name;
     skip_names_set.erase(name);
   }
-  skip_names =
-      details::GetNameFromValue(forward_global_block, output_values, false);
+  skip_names = details::GetNameFromValue(
+      forward_global_block, output_values, false, true);
   skip_names_set.insert(skip_names.begin(), skip_names.end());
   details::print_collection(skip_names_set);
   interpreter_core->SetSkipGcVars(skip_names_set);
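
The skip-GC bookkeeping in this hunk reduces to a small set computation: protect the middle and output variables from garbage collection, minus those marked no-need-buffer. A minimal standalone sketch with hypothetical variable names:

```cpp
// Skip-GC set = (middles - no_need_buffer) + outputs.
#include <iostream>
#include <set>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> middle_names = {"tmp_0", "tmp_1"};
  std::vector<std::string> no_need_buffer_names = {"tmp_1"};
  std::vector<std::string> output_names = {"out_0"};

  std::set<std::string> skip_gc(middle_names.begin(), middle_names.end());
  for (const auto& n : no_need_buffer_names) skip_gc.erase(n);
  skip_gc.insert(output_names.begin(), output_names.end());

  for (const auto& n : skip_gc) std::cout << n << "\n";  // out_0, tmp_0
}
```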
@@ -1127,11 +1133,11 @@ inline void PirRunProgramGradAPI(
 
   // get all eager gc vars
   std::set<std::string> skip_eager_delete_vars;
-  auto skip_names =
-      details::GetNameFromValue(backward_global_block, x_grad_values, false);
+  auto skip_names = details::GetNameFromValue(
+      backward_global_block, x_grad_values, false, true);
   skip_eager_delete_vars.insert(skip_names.begin(), skip_names.end());
-  skip_names =
-      details::GetNameFromValue(backward_global_block, p_grad_values, false);
+  skip_names = details::GetNameFromValue(
+      backward_global_block, p_grad_values, false, true);
   skip_eager_delete_vars.insert(skip_names.begin(), skip_names.end());
   interpreter_core->SetSkipGcVars(skip_eager_delete_vars);
   cache.UpdateSkipEagerDeleteVars(program_id,