Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
723 changes: 304 additions & 419 deletions paddle/fluid/framework/details/inplace_op_pass.cc

Large diffs are not rendered by default.

99 changes: 0 additions & 99 deletions paddle/fluid/framework/details/inplace_op_pass.h

This file was deleted.

1 change: 1 addition & 0 deletions paddle/fluid/framework/details/op_graph_view.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ bool OpGraphView::VisitAllPendingOps(OpHandleBase *op,
if (!callback(pending_op)) {
return false;
}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prefer a "while" loop to a "do-while" loop.

Since q.push(op) is called before the loop, the queue is never empty on entry, so there is no harm in writing:
while (!q.empty()) {
...
}

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

q.push(pending_op);
}
}
} while (!q.empty());
Expand Down
200 changes: 108 additions & 92 deletions paddle/fluid/framework/details/reference_count_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,82 +118,6 @@ class ShrinkDepsOpFunctor {
const OpGraphView graph_;
};

/**
 * Breadth-first search for a computation op handle on the given scope.
 *
 * The search starts at `op` itself and then follows every output variable
 * of the visited op to that variable's pending ops. The first visited op
 * that is a ComputationOpHandle whose scope index equals `scope_idx` is
 * returned, so a matching `op` is returned unchanged.
 *
 * Returns nullptr when every reachable op has been visited without a match.
 */
static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
    OpHandleBase *op, size_t scope_idx) {
  std::unordered_set<OpHandleBase *> seen;
  std::queue<OpHandleBase *> frontier;
  frontier.push(op);

  // The queue is seeded above, so the body always runs at least once.
  while (!frontier.empty()) {
    OpHandleBase *current = frontier.front();
    frontier.pop();

    auto *as_compute = dynamic_cast<ComputationOpHandle *>(current);
    if (as_compute != nullptr && as_compute->GetScopeIdx() == scope_idx) {
      return as_compute;
    }

    for (auto *output : current->Outputs()) {
      for (auto *next_op : output->PendingOps()) {
        // insert().second is false when next_op has been enqueued before.
        if (seen.insert(next_op).second) {
          frontier.push(next_op);
        }
      }
    }
  }
  return nullptr;
}

/**
 * Collect the computation ops that are the last users of `var` on scope
 * `scope_idx`.
 *
 * On success `*ok` is set to true and the set produced by `shrink_func` is
 * returned. On failure `*ok` is set to false and an empty set is returned;
 * this happens when `var` has neither pending ops nor a generating op, or
 * when one of the candidate ops has no downstream computation op.
 */
static std::unordered_set<ComputationOpHandle *>
ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
                                     const ShrinkDepsOpFunctor &shrink_func,
                                     bool *ok) {
  // Step 1: pick the candidate "last" ops of the variable.
  std::unordered_set<OpHandleBase *> last_ops;
  if (!var->PendingOps().empty() || !var->GeneratedOp()) {
    last_ops = var->PendingOps();
  } else {
    // Nothing consumes this variable, so the op that generates it is the
    // last op touching it.
    last_ops.emplace(var->GeneratedOp());
  }

  // Neither a pending op nor a generating op exists.
  if (last_ops.empty()) {
    *ok = false;
    return {};
  }

  // Step 2: map every candidate to its nearest computation op, giving up
  // (*ok = false) as soon as one candidate has none.
  //
  // Arbitrary op handles cannot serve as the last living op because an op
  // handle may operate on many DeviceContexts, whereas the garbage collector
  // can currently wait on only one DeviceContext. Hence the nearest compute
  // op is used instead.
  std::unordered_set<ComputationOpHandle *> compute_ops;
  for (auto *candidate : last_ops) {
    auto *nearest =
        FindNextComputationOpHandleOrReturnItself(candidate, scope_idx);
    if (nearest == nullptr) {
      *ok = false;
      return {};
    }
    compute_ops.emplace(nearest);
  }

  // Step 3: let shrink_func reduce the set when the computation ops depend
  // on each other.
  *ok = true;
  return shrink_func(compute_ops);
}

/**
* Shrink op dependencies according to no need buffer vars.
*
Expand Down Expand Up @@ -267,6 +191,99 @@ static bool ShrinkNoNeedBufferVarOpDependency(
}
}

/**
* Find the nearest downstream computation op handle. If the op is a
* computation op, just return itself.
*/
static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
OpHandleBase *op, size_t scope_idx) {
std::queue<OpHandleBase *> q;
std::unordered_set<OpHandleBase *> visited;
q.push(op);
do {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prefer a "while" loop to a "do-while" loop.

Since q.push(op) is called before the loop, the queue is never empty on entry, so there is no harm in writing:
while (!q.empty()) {
...
}

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

auto *op = q.front();
q.pop();
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) {
return compute_op;
}
for (auto *out_var : op->Outputs()) {
for (auto *pending_op : out_var->PendingOps()) {
if (visited.count(pending_op)) continue;
visited.insert(pending_op);
q.push(pending_op);
}
}
} while (!q.empty());
return nullptr;
}

// Result of searching for the last living ops of one version of a variable.
//   kSuccess:       the last living computation ops were found.
//   kFailure:       the variable has no candidate ops, or a candidate has no
//                   downstream computation op; the search is abandoned.
//   kShouldPrecede: no found computation op needs the buffer of the variable,
//                   so the last living ops should be looked up in the
//                   previous version of the variable.
enum LastLiveOpSearchStatus { kSuccess, kFailure, kShouldPrecede };

/**
 * Collect the computation ops that are the last users of `var` (named
 * `var_name`) on scope `scope_idx`, reporting the outcome through `*status`:
 *
 *  - kFailure: `var` has neither pending ops nor a generating op, or one of
 *    the candidate ops has no downstream computation op. An empty set is
 *    returned.
 *  - kShouldPrecede: ShrinkNoNeedBufferVarOpDependency returned true for the
 *    found ops, i.e. the last living ops have to be searched in the previous
 *    version of `var_name`. An empty set is returned.
 *  - kSuccess: the set produced by `shrink_func` is returned.
 */
static std::unordered_set<ComputationOpHandle *>
ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
                                     const std::string &var_name,
                                     const ShrinkDepsOpFunctor &shrink_func,
                                     LastLiveOpSearchStatus *status) {
  // Step 1: pick the candidate "last" ops of the variable.
  std::unordered_set<OpHandleBase *> last_ops;
  if (!var->PendingOps().empty() || !var->GeneratedOp()) {
    last_ops = var->PendingOps();
  } else {
    // Nothing consumes this variable, so the op that generates it is the
    // last op touching it.
    last_ops.emplace(var->GeneratedOp());
  }

  // Neither a pending op nor a generating op exists.
  if (last_ops.empty()) {
    *status = LastLiveOpSearchStatus::kFailure;
    return {};
  }

  // Step 2: map every candidate to its nearest computation op, giving up
  // (*status = kFailure) as soon as one candidate has none.
  //
  // Arbitrary op handles cannot serve as the last living op because an op
  // handle may operate on many DeviceContexts, whereas the garbage collector
  // can currently wait on only one DeviceContext. Hence the nearest compute
  // op is used instead.
  std::unordered_set<ComputationOpHandle *> compute_ops;
  for (auto *candidate : last_ops) {
    auto *nearest =
        FindNextComputationOpHandleOrReturnItself(candidate, scope_idx);
    if (nearest == nullptr) {
      *status = LastLiveOpSearchStatus::kFailure;
      return {};
    }
    compute_ops.emplace(nearest);
  }

  // Step 3: drop the computation ops that do not need the buffer of
  // var_name. When none of them needs it, the caller must look for the last
  // living ops in the previous version of var_name instead.
  const bool all_ops_skip_buffer =
      ShrinkNoNeedBufferVarOpDependency(var_name, &compute_ops);
  if (all_ops_skip_buffer) {
    *status = LastLiveOpSearchStatus::kShouldPrecede;
    return {};
  }

  PADDLE_ENFORCE(!compute_ops.empty(),
                 "Computation ops should not be empty");

  // Step 4: let shrink_func reduce the set when the computation ops depend
  // on each other.
  *status = LastLiveOpSearchStatus::kSuccess;
  return shrink_func(compute_ops);
}

void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
auto &last_live_ops_of_vars =
Expand All @@ -284,12 +301,12 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
ShrinkDepsOpFunctor shrink_func(
ir::FilterByNodeWrapper<OpHandleBase>(*graph));

VLOG(1) << "Place number: " << vars.size();
for (size_t i = 0; i < vars.size(); ++i) {
for (auto &name_var_pair : vars[i]) {
// Whether this variable can be reused or deleted? If not, we do not
// compute reference counts and dependencies.
VarDesc *var_desc = TryGetLatestVarDesc(name_var_pair.second);

if (var_desc == nullptr || var_desc->Persistable()) {
continue;
}
Expand All @@ -305,34 +322,33 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
auto &var_name = name_var_pair.first;
auto &var_handles = name_var_pair.second;

PADDLE_ENFORCE_EQ(var_desc->Name(), var_name);

for (auto iter = var_handles.rbegin(); iter != var_handles.rend();
++iter) {
bool ok;
auto result =
ExtractComputationOpFromLastLivedVar(*iter, i, shrink_func, &ok);
VLOG(10) << "Try to find last living ops of " << var_name << " "
<< (iter - var_handles.rbegin()) << " time";
LastLiveOpSearchStatus status = LastLiveOpSearchStatus::kFailure;
auto result = ExtractComputationOpFromLastLivedVar(
*iter, i, var_name, shrink_func, &status);

// Seldomly, some vars may have no pending or preceding computation ops
// Just break;
if (!ok) break;
VLOG(10) << "Extract " << result.size() << " ops of var " << var_name;
if (status == LastLiveOpSearchStatus::kFailure) {
break;
}

size_t original_op_deps = result.size();
// If all ops do not need buffer of var_name, calculate reference count
// of the previous version of var_name.
if (ShrinkNoNeedBufferVarOpDependency(var_name, &result)) {
if (status == LastLiveOpSearchStatus::kShouldPrecede) {
VLOG(10) << "Try to precede reference count computing at var "
<< var_name;
continue;
}

size_t final_op_deps = result.size();
if (final_op_deps < original_op_deps) {
VLOG(5) << "Shrink op deps from " << original_op_deps << " to "
<< final_op_deps;
}

PADDLE_ENFORCE_EQ(status, LastLiveOpSearchStatus::kSuccess);
PADDLE_ENFORCE(!result.empty(), "Last living ops of %s cannot be empty",
var_name);

VLOG(10) << "Extract " << result.size() << " ops of var " << var_name;
ref_cnts[i].emplace(var_name, result.size());
last_live_ops_of_vars[i].emplace(var_name, std::move(result));
break;
Expand Down
Loading