-
Notifications
You must be signed in to change notification settings - Fork 6k
Refine ParallelGraph Execution #15716
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
f3463ec
88d3dc9
73005ee
ecdd116
bd0d44a
7cd6de3
642fd68
5677c9d
0f8bd73
d5090c8
4b193db
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,9 +34,11 @@ namespace details { | |
| static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) { | ||
| // Should fix the allreduce op order if scheduling | ||
| // them in multiple threads or processes to avoid hang. | ||
| // NOTE: ParallelExecutor would execute this pass on each graph, so | ||
| // don't need to append it here. | ||
| return (!strategy.enable_sequential_execution_ && | ||
| strategy.num_trainers_ > 1) || | ||
| strategy.enable_parallel_graph_; | ||
| strategy.num_trainers_ > 1) && | ||
| !strategy.enable_parallel_graph_; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. PG needs to execute this pass on each graph. |
||
| } | ||
|
|
||
| class ParallelExecutorPassBuilder : public ir::PassBuilder { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,11 +36,6 @@ namespace framework { | |
| namespace details { | ||
|
|
||
| namespace { | ||
| // TODO(panyx0718): Clean this up as well. | ||
| // all operators. NOTE that even we use a vector here, the operators is | ||
| // unordered. | ||
| typedef std::vector<OpHandleBase *> GraphOps; | ||
| const char kGraphOps[] = "ops"; | ||
|
|
||
| bool OpHaveRole(const ir::Node &node, const framework::OpRole &role) { | ||
| return boost::get<int>( | ||
|
|
@@ -226,7 +221,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl( | |
| * Only variables should be the leaves of graph. | ||
| */ | ||
| AddOutputToLeafOps(&result); | ||
| result.Erase(kGraphOps); | ||
| return graph; | ||
| } | ||
|
|
||
|
|
@@ -392,19 +386,33 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result, | |
|
|
||
| void MultiDevSSAGraphBuilderBase::CreateAllReduceOp( | ||
| ir::Graph *result, const std::string &og) const { | ||
| OpHandleBase *op_handle = nullptr; | ||
|
|
||
| auto append_allreduce_op = [&]( | ||
| std::vector<Scope *> &scopes, | ||
| std::vector<platform::Place> &places) -> OpHandleBase * { | ||
| #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) | ||
| result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle( | ||
| result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), | ||
| local_scopes_, places_, nccl_ctxs_)); | ||
| result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle( | ||
| result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), | ||
| scopes, places, nccl_ctxs_)); | ||
| #else | ||
| result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle( | ||
| result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), | ||
| local_scopes_, places_)); | ||
| result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle( | ||
| result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), | ||
| scopes, places)); | ||
| #endif | ||
| auto *op_handle = result->Get<GraphOps>(kGraphOps).back(); | ||
| return result->Get<GraphOps>(kGraphOps).back(); | ||
| }; | ||
|
|
||
| if (!strategy_.enable_parallel_graph_) | ||
| op_handle = append_allreduce_op(local_scopes_, places_); | ||
|
|
||
| for (size_t i = 0; i < places_.size(); ++i) { | ||
| auto &p = places_[i]; | ||
| auto p = places_[i]; | ||
| std::vector<Scope *> ss{local_scopes_[i]}; | ||
| std::vector<platform::Place> ps{p}; | ||
|
||
| if (strategy_.enable_parallel_graph_) | ||
| op_handle = append_allreduce_op(ss, ps); | ||
|
|
||
| SetCommunicationContext(op_handle, p); | ||
| auto &vars = result->Get<GraphVars>(kGraphVars)[i][og]; | ||
| PADDLE_ENFORCE(!vars.empty()); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,13 +36,20 @@ namespace details { | |
| // map from variable name to variables. The variables, who have the same name, | ||
| // will have a different version. The offset in the | ||
| // `std::vector<VarHandle*>` is the version of variables. | ||
| typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle*>>> | ||
| typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle *>>> | ||
| GraphVars; | ||
| const char kGraphVars[] = "vars"; | ||
|
|
||
| // aux variables to represent dependency. Useful to resolve data hazard. | ||
| typedef std::unordered_set<VarHandleBase*> GraphDepVars; | ||
| typedef std::unordered_set<VarHandleBase *> GraphDepVars; | ||
| const char kGraphDepVars[] = "dep_vars"; | ||
|
|
||
| // TODO(panyx0718): Clean this up as well. | ||
| // all operators. NOTE that even we use a vector here, the operators is | ||
| // unordered. | ||
| typedef std::vector<OpHandleBase *> GraphOps; | ||
| const char kGraphOps[] = "ops"; | ||
|
||
|
|
||
| } // namespace details | ||
| } // namespace framework | ||
| } // namespace paddle | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,6 +26,11 @@ limitations under the License. */ | |
|
|
||
| namespace paddle { | ||
| namespace framework { | ||
|
|
||
| namespace details { | ||
| constexpr char kAllOpDescs[] = "all_op_descs"; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Doc here that this is not recommended?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. |
||
| } // namespace details | ||
|
|
||
| namespace ir { | ||
|
|
||
| /* | ||
|
|
@@ -168,10 +173,13 @@ class Graph { | |
| return ret; | ||
| } | ||
|
|
||
| void RemoveNode(ir::Node *node) { | ||
| std::unique_ptr<ir::Node> RemoveNode(ir::Node *node) { | ||
| PADDLE_ENFORCE(node_set_.find(node) != node_set_.end()); | ||
| node_set_.erase(node); | ||
| std::unique_ptr<ir::Node> ret; | ||
| ret.reset(nodes_.at(node).release()); | ||
| nodes_.erase(node); | ||
| node_set_.erase(node); | ||
| return ret; | ||
| } | ||
|
|
||
| // NOTE low performance, but simple and secure. | ||
|
|
@@ -184,13 +192,6 @@ class Graph { | |
| return nullptr; | ||
| } | ||
|
|
||
| void ResolveHazard( | ||
| const std::map<std::string, std::vector<ir::Node *>> &var_nodes); | ||
|
|
||
| private: | ||
| std::map<std::string, std::vector<ir::Node *>> InitFromProgram( | ||
| const ProgramDesc &program); | ||
|
|
||
| // This method takes ownership of `node`. | ||
| ir::Node *AddNode(ir::Node *node) { | ||
| PADDLE_ENFORCE(node_set_.find(node) == node_set_.end()); | ||
|
|
@@ -199,6 +200,13 @@ class Graph { | |
| return node; | ||
| } | ||
|
|
||
| void ResolveHazard( | ||
| const std::map<std::string, std::vector<ir::Node *>> &var_nodes); | ||
|
|
||
| private: | ||
| std::map<std::string, std::vector<ir::Node *>> InitFromProgram( | ||
| const ProgramDesc &program); | ||
|
|
||
| // NOTE: program_ shouldn't be exposed to user. | ||
| const ProgramDesc program_; | ||
| std::map<std::string, boost::any> attrs_; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
ParallelGraphExecutor
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
done.