-
Notifications
You must be signed in to change notification settings - Fork 6k
Refine ParallelGraph Execution #15716
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
f3463ec
88d3dc9
73005ee
ecdd116
bd0d44a
7cd6de3
642fd68
5677c9d
0f8bd73
d5090c8
4b193db
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,8 +35,8 @@ static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) { | |
| // Should fix the allreduce op order if scheduling | ||
| // them in multiple threads or processes to avoid hang. | ||
| return (!strategy.enable_sequential_execution_ && | ||
| strategy.num_trainers_ > 1) || | ||
| strategy.enable_parallel_graph_; | ||
| strategy.num_trainers_ > 1) && | ||
| !strategy.enable_parallel_graph_; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. PG needs to execute this pass on each graph. |
||
| } | ||
|
|
||
| class ParallelExecutorPassBuilder : public ir::PassBuilder { | ||
|
|
@@ -118,7 +118,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { | |
| } | ||
|
|
||
| // Verify that the graph is correct for multi-device executor. | ||
| AppendPass("multi_devices_check_pass"); | ||
| auto multi_devices_pass = AppendPass("multi_devices_check_pass"); | ||
|
|
||
| if (SeqOnlyAllReduceOps(strategy)) { | ||
| AppendPass("all_reduce_deps_pass"); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,11 +36,6 @@ namespace framework { | |
| namespace details { | ||
|
|
||
| namespace { | ||
| // TODO(panyx0718): Clean this up as well. | ||
| // all operators. NOTE that even we use a vector here, the operators is | ||
| // unordered. | ||
| typedef std::vector<OpHandleBase *> GraphOps; | ||
| const char kGraphOps[] = "ops"; | ||
|
|
||
| bool OpHaveRole(const ir::Node &node, const framework::OpRole &role) { | ||
| return boost::get<int>( | ||
|
|
@@ -226,7 +221,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl( | |
| * Only variables should be the leaves of graph. | ||
| */ | ||
| AddOutputToLeafOps(&result); | ||
| result.Erase(kGraphOps); | ||
| // result.Erase(kGraphOps); | ||
|
||
| return graph; | ||
| } | ||
|
|
||
|
|
@@ -392,19 +387,33 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result, | |
|
|
||
| void MultiDevSSAGraphBuilderBase::CreateAllReduceOp( | ||
| ir::Graph *result, const std::string &og) const { | ||
| OpHandleBase *op_handle = nullptr; | ||
|
|
||
| auto append_allreduce_op = [&]( | ||
| std::vector<Scope *> &scopes, | ||
| std::vector<platform::Place> &places) -> OpHandleBase * { | ||
| #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) | ||
| result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle( | ||
| result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), | ||
| local_scopes_, places_, nccl_ctxs_)); | ||
| result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle( | ||
| result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), | ||
| scopes, places, nccl_ctxs_)); | ||
| #else | ||
| result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle( | ||
| result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), | ||
| local_scopes_, places_)); | ||
| result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle( | ||
| result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), | ||
| scopes, places)); | ||
| #endif | ||
| auto *op_handle = result->Get<GraphOps>(kGraphOps).back(); | ||
| return result->Get<GraphOps>(kGraphOps).back(); | ||
| }; | ||
|
|
||
| if (!strategy_.enable_parallel_graph_) | ||
| op_handle = append_allreduce_op(local_scopes_, places_); | ||
|
|
||
| for (size_t i = 0; i < places_.size(); ++i) { | ||
| auto &p = places_[i]; | ||
| auto p = places_[i]; | ||
| std::vector<Scope *> ss{local_scopes_[i]}; | ||
| std::vector<platform::Place> ps{p}; | ||
|
||
| if (strategy_.enable_parallel_graph_) | ||
| op_handle = append_allreduce_op(ss, ps); | ||
|
|
||
| SetCommunicationContext(op_handle, p); | ||
| auto &vars = result->Get<GraphVars>(kGraphVars)[i][og]; | ||
| PADDLE_ENFORCE(!vars.empty()); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,13 +36,20 @@ namespace details { | |
| // map from variable name to variables. The variables, who have the same name, | ||
| // will have a different version. The offset in the | ||
| // `std::vector<VarHandle*>` is the version of variables. | ||
| typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle*>>> | ||
| typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle *>>> | ||
| GraphVars; | ||
| const char kGraphVars[] = "vars"; | ||
|
|
||
| // aux variables to represent dependency. Useful to resolve data hazard. | ||
| typedef std::unordered_set<VarHandleBase*> GraphDepVars; | ||
| typedef std::unordered_set<VarHandleBase *> GraphDepVars; | ||
| const char kGraphDepVars[] = "dep_vars"; | ||
|
|
||
| // TODO(panyx0718): Clean this up as well. | ||
| // all operators. NOTE that even we use a vector here, the operators is | ||
| // unordered. | ||
| typedef std::vector<OpHandleBase *> GraphOps; | ||
| const char kGraphOps[] = "ops"; | ||
|
||
|
|
||
| } // namespace details | ||
| } // namespace framework | ||
| } // namespace paddle | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,11 +13,70 @@ | |
| // limitations under the License. | ||
|
|
||
| #include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h" | ||
| #include "paddle/fluid/framework/ir/graph_helper.h" | ||
|
|
||
| namespace paddle { | ||
| namespace framework { | ||
| namespace details { | ||
|
|
||
| std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph( | ||
|
||
| const std::vector<platform::Place> &places, | ||
| std::unique_ptr<ir::Graph> graph) { | ||
| std::vector<std::unique_ptr<ir::Graph>> graphs; | ||
| graphs.reserve(places.size()); | ||
| for (size_t i = 0; i < places.size(); ++i) { | ||
| ProgramDesc empty; | ||
| graphs.emplace_back(std::unique_ptr<ir::Graph>(new ir::Graph(empty))); | ||
| auto &g = graphs.back(); | ||
| g->Set(kGraphVars, new GraphVars(1UL)); | ||
| g->Set(kGraphDepVars, new GraphDepVars); | ||
| g->Set(kGraphOps, new GraphOps); | ||
| } | ||
|
|
||
| for (auto &op : graph->Get<GraphOps>(kGraphOps)) { | ||
| auto &dev_ctx = op->DeviceContext(); | ||
| auto &p = dev_ctx.begin()->first; | ||
| int dev_id = boost::get<platform::CUDAPlace>(p).device; | ||
| auto &dev_ops = graphs[dev_id]->Get<GraphOps>(kGraphOps); | ||
| auto &dev_dummys = graphs[dev_id]->Get<GraphDepVars>(kGraphDepVars); | ||
| dev_ops.emplace_back(op); | ||
| graphs[dev_id]->AddNode(graph->RemoveNode(op->Node()).release()); | ||
|
|
||
| for (auto &var : op->Inputs()) { | ||
| auto dummy_ptr = dynamic_cast<DummyVarHandle *>(var); | ||
| if (dummy_ptr) { | ||
| dev_dummys.insert(var); | ||
| if (graph->Nodes().count(var->Node())) | ||
| graphs[dev_id]->AddNode(graph->RemoveNode(var->Node()).release()); | ||
| } | ||
| } | ||
| for (auto &var : op->Outputs()) { | ||
| auto dummy_ptr = dynamic_cast<DummyVarHandle *>(var); | ||
| if (dummy_ptr) { | ||
| dev_dummys.insert(var); | ||
| if (graph->Nodes().count(var->Node())) | ||
| graphs[dev_id]->AddNode(graph->RemoveNode(var->Node()).release()); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| for (size_t dev_id = 0; dev_id < places.size(); ++dev_id) { | ||
| auto &dev_vars = graphs[dev_id]->Get<GraphVars>(kGraphVars)[0]; | ||
| auto &origin_vars = graph->Get<GraphVars>(kGraphVars)[dev_id]; | ||
| for (auto &name_pair : origin_vars) { | ||
| dev_vars.emplace(name_pair.first, name_pair.second); | ||
| for (auto &version_pair : name_pair.second) { | ||
| if (graph->Nodes().count(version_pair->Node())) { | ||
| graphs[dev_id]->AddNode( | ||
| graph->RemoveNode(version_pair->Node()).release()); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| return graphs; | ||
| } | ||
|
|
||
| ParallelSSAGraphExecutor::ParallelSSAGraphExecutor( | ||
| const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes, | ||
| const std::vector<platform::Place> &places, | ||
|
|
@@ -37,7 +96,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor( | |
| << " to run the operators of the graph on each device."; | ||
| for (size_t i = 0; i < places.size(); ++i) { | ||
| executors_.emplace_back(new details::ThreadedSSAGraphExecutor( | ||
| strategy_, {local_scopes_[i]}, {places_[i]}, std::move(graphs_[i]))); | ||
| strategy_, local_scopes_, {places_[i]}, std::move(graphs_.at(i)))); | ||
| } | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This also exists in memory_optimizer_helper.h — unify them?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, moved to graph.h