-
Notifications
You must be signed in to change notification settings - Fork 6k
Rewrite inplace pass and fix gc bug #17126
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
ee49695
b72ce96
fefb1c4
9ab045a
0a47464
4765b61
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -118,82 +118,6 @@ class ShrinkDepsOpFunctor { | |
| const OpGraphView graph_; | ||
| }; | ||
|
|
||
| /** | ||
| * Breadth-first search for the nearest downstream computation op handle
| * whose scope index equals scope_idx, starting from the given op.
| *
| * If the given op is itself a ComputationOpHandle on scope_idx, it is
| * returned directly. Otherwise the search walks op->Outputs() and the
| * PendingOps() of each output var, so it only ever moves downstream.
| * A ComputationOpHandle on a different scope index is not returned, but
| * the search continues through its outputs.
| *
| * Returns nullptr when no matching computation op is reachable.
| *
| * Note: inside the loop the local variable op shadows the parameter op,
| * and the starting op is never inserted into the visited set.
| */ | ||
| static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself( | ||
| OpHandleBase *op, size_t scope_idx) { | ||
| std::queue<OpHandleBase *> q; | ||
| std::unordered_set<OpHandleBase *> visited; | ||
| q.push(op); | ||
| do { | ||
| auto *op = q.front(); | ||
| q.pop(); | ||
| auto *compute_op = dynamic_cast<ComputationOpHandle *>(op); | ||
| if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) { | ||
| return compute_op; | ||
| } | ||
| for (auto *out_var : op->Outputs()) { | ||
| for (auto *pending_op : out_var->PendingOps()) { | ||
| if (visited.count(pending_op)) continue; | ||
| visited.insert(pending_op); | ||
| q.push(pending_op); | ||
| } | ||
| } | ||
| } while (!q.empty()); | ||
| return nullptr; | ||
| } | ||
|
|
||
| static std::unordered_set<ComputationOpHandle *> | ||
| ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx, | ||
| const ShrinkDepsOpFunctor &shrink_func, | ||
| bool *ok) { | ||
| // Stage one: collect the candidate last ops of this variable version. | ||
| std::unordered_set<OpHandleBase *> candidates; | ||
| { | ||
| if (var->PendingOps().empty() && var->GeneratedOp()) { | ||
| // No operator reads this variable, so the last op to touch it is the op
| // that generates it. | ||
| candidates.emplace(var->GeneratedOp()); | ||
| } else { | ||
| candidates = var->PendingOps(); | ||
| } | ||
|
|
||
| // Empty only when there are no pending ops and the generated op is nullptr. | ||
| if (candidates.empty()) { | ||
| *ok = false; | ||
| return {}; | ||
| } | ||
| } | ||
|
|
||
| // Stage two: map every candidate to a computation op on scope_idx. | ||
| // Report failure (*ok = false) if any candidate has none downstream. | ||
| // | ||
| // Not every kind of op handle can act as the last living op, and the reason
| // is: | ||
| // some op handles may operate on many DeviceContexts, whereas the garbage
| // collector can only wait on one DeviceContext for now. So currently we wait
| // on the nearest computation op instead. | ||
| std::unordered_set<ComputationOpHandle *> computation_op; | ||
| { | ||
| for (auto *op : candidates) { | ||
| auto *compute_op = | ||
| FindNextComputationOpHandleOrReturnItself(op, scope_idx); | ||
| if (compute_op == nullptr) { | ||
| *ok = false; | ||
| return {}; | ||
| } | ||
| computation_op.emplace(compute_op); | ||
| } | ||
| } | ||
|
|
||
| // Stage three: hand the set to shrink_func, which shrinks it when the ops | ||
| // depend on each other (presumably keeping the latest ones; TODO confirm). | ||
| *ok = true; | ||
| return shrink_func(computation_op); | ||
| } | ||
|
|
||
| /** | ||
| * Shrink op dependencies according to no need buffer vars. | ||
| * | ||
|
|
@@ -267,6 +191,99 @@ static bool ShrinkNoNeedBufferVarOpDependency( | |
| } | ||
| } | ||
|
|
||
| /** | ||
| * Breadth-first search for the nearest downstream computation op handle
| * whose scope index equals scope_idx, starting from the given op.
| *
| * If the given op is itself a ComputationOpHandle on scope_idx, it is
| * returned directly. Otherwise the search walks op->Outputs() and the
| * PendingOps() of each output var, so it only ever moves downstream.
| * A ComputationOpHandle on a different scope index is not returned, but
| * the search continues through its outputs.
| *
| * Returns nullptr when no matching computation op is reachable.
| *
| * Note: inside the loop the local variable op shadows the parameter op,
| * and the starting op is never inserted into the visited set.
| */ | ||
| static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself( | ||
| OpHandleBase *op, size_t scope_idx) { | ||
| std::queue<OpHandleBase *> q; | ||
| std::unordered_set<OpHandleBase *> visited; | ||
| q.push(op); | ||
| do { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Prefer "while" loop than "do while" loop. There is q.push(op) before the while loop. It's no harm to write
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
| auto *op = q.front(); | ||
| q.pop(); | ||
| auto *compute_op = dynamic_cast<ComputationOpHandle *>(op); | ||
| if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) { | ||
| return compute_op; | ||
| } | ||
| for (auto *out_var : op->Outputs()) { | ||
| for (auto *pending_op : out_var->PendingOps()) { | ||
| if (visited.count(pending_op)) continue; | ||
| visited.insert(pending_op); | ||
| q.push(pending_op); | ||
| } | ||
| } | ||
| } while (!q.empty()); | ||
| return nullptr; | ||
| } | ||
|
|
||
| enum LastLiveOpSearchStatus { kSuccess, kFailure, kShouldPrecede }; | ||
|
|
||
| static std::unordered_set<ComputationOpHandle *> | ||
| ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx, | ||
| const std::string &var_name, | ||
| const ShrinkDepsOpFunctor &shrink_func, | ||
| LastLiveOpSearchStatus *status) { | ||
| // Stage one: collect the candidate last ops of this variable version. | ||
| std::unordered_set<OpHandleBase *> candidates; | ||
| { | ||
| if (var->PendingOps().empty() && var->GeneratedOp()) { | ||
| // No operator reads this variable, so the last op to touch it is the op
| // that generates it. | ||
| candidates.emplace(var->GeneratedOp()); | ||
| } else { | ||
| candidates = var->PendingOps(); | ||
| } | ||
|
|
||
| // Empty only when there are no pending ops and the generated op is nullptr. | ||
| if (candidates.empty()) { | ||
| *status = LastLiveOpSearchStatus::kFailure; | ||
| return {}; | ||
| } | ||
| } | ||
|
|
||
| // Stage two: map every candidate to a computation op on scope_idx. | ||
| // Report kFailure through *status if any candidate has none downstream. | ||
| // | ||
| // Not every kind of op handle can act as the last living op, and the reason
| // is: | ||
| // some op handles may operate on many DeviceContexts, whereas the garbage
| // collector can only wait on one DeviceContext for now. So currently we wait
| // on the nearest computation op instead. | ||
| std::unordered_set<ComputationOpHandle *> computation_op; | ||
| { | ||
| for (auto *op : candidates) { | ||
| auto *compute_op = | ||
| FindNextComputationOpHandleOrReturnItself(op, scope_idx); | ||
| if (compute_op == nullptr) { | ||
| *status = LastLiveOpSearchStatus::kFailure; | ||
| return {}; | ||
| } | ||
| computation_op.emplace(compute_op); | ||
| } | ||
| } | ||
|
|
||
| // Stage three: try to drop the computation ops that do not need the | ||
| // buffer of var_name (via ShrinkNoNeedBufferVarOpDependency). | ||
| // If none of the computation ops needs the buffer of var_name, | ||
| // return an empty set and mark the status as kShouldPrecede, | ||
| // which means that the last living ops of var_name should be | ||
| // looked up in the previous version of var_name. | ||
| if (ShrinkNoNeedBufferVarOpDependency(var_name, &computation_op)) { | ||
| *status = LastLiveOpSearchStatus::kShouldPrecede; | ||
| return {}; | ||
| } | ||
|
|
||
| PADDLE_ENFORCE(!computation_op.empty(), | ||
| "Computation ops should not be empty"); | ||
|
|
||
| // Stage four: hand the set to shrink_func, which shrinks it when the ops | ||
| // depend on each other (presumably keeping the latest ones; TODO confirm). | ||
| *status = LastLiveOpSearchStatus::kSuccess; | ||
| return shrink_func(computation_op); | ||
| } | ||
|
|
||
| void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { | ||
| auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount); | ||
| auto &last_live_ops_of_vars = | ||
|
|
@@ -284,12 +301,12 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { | |
| ShrinkDepsOpFunctor shrink_func( | ||
| ir::FilterByNodeWrapper<OpHandleBase>(*graph)); | ||
|
|
||
| VLOG(1) << "Place number: " << vars.size(); | ||
| for (size_t i = 0; i < vars.size(); ++i) { | ||
| for (auto &name_var_pair : vars[i]) { | ||
| // Whether this variable can be reused or deleted? If not, we do not | ||
| // compute reference counts and dependencies. | ||
| VarDesc *var_desc = TryGetLatestVarDesc(name_var_pair.second); | ||
|
|
||
| if (var_desc == nullptr || var_desc->Persistable()) { | ||
| continue; | ||
| } | ||
|
|
@@ -305,34 +322,33 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { | |
| auto &var_name = name_var_pair.first; | ||
| auto &var_handles = name_var_pair.second; | ||
|
|
||
| PADDLE_ENFORCE_EQ(var_desc->Name(), var_name); | ||
|
|
||
| for (auto iter = var_handles.rbegin(); iter != var_handles.rend(); | ||
| ++iter) { | ||
| bool ok; | ||
| auto result = | ||
| ExtractComputationOpFromLastLivedVar(*iter, i, shrink_func, &ok); | ||
| VLOG(10) << "Try to find last living ops of " << var_name << " " | ||
| << (iter - var_handles.rbegin()) << " time"; | ||
| LastLiveOpSearchStatus status = LastLiveOpSearchStatus::kFailure; | ||
| auto result = ExtractComputationOpFromLastLivedVar( | ||
| *iter, i, var_name, shrink_func, &status); | ||
|
|
||
| // Seldomly, some vars may have no pending or preceding computation ops | ||
| // Just break; | ||
| if (!ok) break; | ||
| VLOG(10) << "Extract " << result.size() << " ops of var " << var_name; | ||
| if (status == LastLiveOpSearchStatus::kFailure) { | ||
| break; | ||
| } | ||
|
|
||
| size_t original_op_deps = result.size(); | ||
| // If all ops do not need buffer of var_name, calculate reference count | ||
| // of the previous version of var_name. | ||
| if (ShrinkNoNeedBufferVarOpDependency(var_name, &result)) { | ||
| if (status == LastLiveOpSearchStatus::kShouldPrecede) { | ||
| VLOG(10) << "Try to precede reference count computing at var " | ||
| << var_name; | ||
| continue; | ||
| } | ||
|
|
||
| size_t final_op_deps = result.size(); | ||
| if (final_op_deps < original_op_deps) { | ||
| VLOG(5) << "Shrink op deps from " << original_op_deps << " to " | ||
| << final_op_deps; | ||
| } | ||
|
|
||
| PADDLE_ENFORCE_EQ(status, LastLiveOpSearchStatus::kSuccess); | ||
| PADDLE_ENFORCE(!result.empty(), "Last living ops of %s cannot be empty", | ||
| var_name); | ||
|
|
||
| VLOG(10) << "Extract " << result.size() << " ops of var " << var_name; | ||
| ref_cnts[i].emplace(var_name, result.size()); | ||
| last_live_ops_of_vars[i].emplace(var_name, std::move(result)); | ||
| break; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Prefer a "while" loop to a "do while" loop.
q.push(op) is called before the loop, so the queue is never empty on entry and there is no harm in writing
while (!q.empty()) {
...
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.