-
Notifications
You must be signed in to change notification settings - Fork 6k
Add async ssa graph executor communicator #16172
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
92a6c7a
afda840
ea66979
88d71fa
69484f7
7021979
9958775
b5aefc8
f3210b6
be72940
ca5d96b
1866d2d
74040cb
1edc042
ada43e8
fab8457
a66115b
62549e0
be738a6
9da96ab
7e145b7
02dab46
4a17261
c7e3868
657a4f9
249f48e
d6c0dca
381f383
16af1db
b1fe8d4
741b7cf
4356f18
5c36eb8
5cf0092
a0585d0
a804a2a
a715261
fbd186b
8bda4ab
e72637d
84367cf
c4ded17
2171aa7
cc71e89
31a05d3
9465c3d
7f3be09
12f6b8c
f4f4816
ecedd53
b5b8e6c
10393dd
b8491bf
43c8237
cf0511f
dab7f36
ff01d70
f768fbf
49f2f4f
02425b2
847e4f4
3691a46
9573d61
3c6b733
c2cce6b
5060150
13e8b5b
e70b172
8744f9a
fab1b54
8c38aca
b2c082c
e92ad8a
f28c258
c09477b
4e218da
5e8de51
255b36d
7d5dc4e
a0bb18b
a23f1ee
446fdf9
fe6a840
3225e19
ff8054c
c0e5941
63cd70a
0a828fe
eb6af30
ad5a2b3
43378ad
d3a1437
23d3929
9b74707
0fcdae8
c567deb
347178b
065b68b
ea0df4e
039d783
3061840
37f6b9a
d640c6c
392e97a
b542639
33be014
b68f840
34890fd
d8974e6
61912e8
a1821a0
df45c8c
8342f12
9db1a9e
baf0232
adf272b
fb6cc3a
9861a92
4031c1a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,203 @@ | ||
| // Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #include "paddle/fluid/framework/details/async_ssa_graph_executor.h" | ||
|
|
||
| #include "paddle/fluid/framework/variable_helper.h" | ||
|
|
||
| #ifdef PADDLE_WITH_DISTRIBUTE | ||
| #include "paddle/fluid/operators/distributed/communicator.h" | ||
| #endif | ||
|
|
||
| namespace paddle { | ||
| namespace framework { | ||
| namespace details { | ||
|
|
||
| inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos, | ||
| Scope *scope) { | ||
| VLOG(3) << "NewTempScopeAndInitVars"; | ||
| Scope &local_scope = scope->NewScope(); | ||
| *scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() = | ||
| &local_scope; | ||
|
|
||
| for (auto &info : var_infos) { | ||
| if (scope->FindVar(info.name_) != nullptr) { | ||
| continue; | ||
| } | ||
|
|
||
| if (info.persistable_) { // Persistable | ||
| InitializeVariable(scope->Var(info.name_), info.type_); | ||
| } else { | ||
| InitializeVariable(local_scope.Var(info.name_), info.type_); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // get RpcContext and remote send and recv op | ||
| void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) { | ||
| #ifdef PADDLE_WITH_DISTRIBUTE | ||
| using RpcCtxMap = operators::distributed::RpcCtxMap; | ||
| VLOG(3) << "ProcessGraph"; | ||
| RpcCtxMap send_varname_to_ctx; | ||
| RpcCtxMap recv_varname_to_ctx; | ||
| for (auto i = 0; i < graphs.size(); ++i) { | ||
| std::vector<ir::Node *> nodes_to_delete; | ||
| for (auto &node : graphs[i]->Nodes()) { | ||
| VLOG(3) << "node name " << node->Name(); | ||
| if (node && node->IsOp()) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why the node maybe
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have met this problem before, so I add more check to ensure it works right. |
||
| if (node->Name() == "send") { | ||
| auto send_var_name = node->Op()->Input("X")[0]; | ||
| auto send_varnames = boost::get<std::vector<std::string>>( | ||
| node->Op()->GetNullableAttr("send_varnames")); | ||
| auto epmap = boost::get<std::vector<std::string>>( | ||
| node->Op()->GetNullableAttr("epmap")); | ||
| auto height_section = boost::get<std::vector<int64_t>>( | ||
| node->Op()->GetNullableAttr("sections")); | ||
| send_varname_to_ctx[send_var_name] = | ||
| operators::distributed::RpcContext(send_var_name, send_varnames, | ||
| epmap, height_section); | ||
| VLOG(3) << "find and init an send op: " | ||
| << send_varname_to_ctx[send_var_name]; | ||
| } else if (node->Name() == "recv") { | ||
| auto recv_var_name = node->Op()->Output("Out")[0]; | ||
| auto recv_varnames = boost::get<std::vector<std::string>>( | ||
| node->Op()->GetNullableAttr("recv_varnames")); | ||
| auto epmap = boost::get<std::vector<std::string>>( | ||
| node->Op()->GetNullableAttr("epmap")); | ||
| recv_varname_to_ctx[recv_var_name] = | ||
| operators::distributed::RpcContext(recv_var_name, recv_varnames, | ||
| epmap, {}); | ||
| nodes_to_delete.push_back(node); | ||
| VLOG(3) << "find and remove an recv op: " | ||
| << recv_varname_to_ctx[recv_var_name]; | ||
| } | ||
| } | ||
| } | ||
| } | ||
| // init communicator here | ||
| if (send_varname_to_ctx.size() > 0) { | ||
| VLOG(3) << "this is distribute mode, will use communicator"; | ||
| operators::distributed::Communicator::Init(send_varname_to_ctx, | ||
| recv_varname_to_ctx, scope); | ||
| operators::distributed::Communicator::GetInstance()->Start(); | ||
| } | ||
| #endif | ||
| } | ||
|
|
||
| AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( | ||
| const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes, | ||
| const std::vector<platform::Place> &places, std::vector<ir::Graph *> graphs) | ||
| : strategy_(std::move(strategy)), | ||
| local_scopes_(std::move(local_scopes)), | ||
| pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr), | ||
| places_(std::move(places)), | ||
| graphs_(std::move(graphs)) { | ||
| VLOG(3) << "build AsyncSSAGraphExecutor"; | ||
| PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size()); | ||
|
|
||
| // set the correct size of thread pool to each device. | ||
| strategy_.num_threads_ = strategy_.num_threads_ < places_.size() | ||
| ? 1UL | ||
| : strategy_.num_threads_ / places_.size(); | ||
| VLOG(1) << "set num_threads: " << strategy_.num_threads_ | ||
| << " to run the operators of the graph on each device."; | ||
| for (size_t i = 0; i < places.size(); ++i) { | ||
| executors_.emplace_back(new details::ThreadedSSAGraphExecutor( | ||
| strategy_, {local_scopes_[i]}, {places_[i]}, graphs_[i])); | ||
| } | ||
|
|
||
| for (auto &node : graphs_[0]->Nodes()) { | ||
| if (node->IsVar() && !node->IsCtrlVar() && node->Var()) { | ||
| var_infos_.emplace_back(); | ||
| var_infos_.back().name_ = node->Var()->Name(); | ||
| var_infos_.back().type_ = node->Var()->GetType(); | ||
| var_infos_.back().persistable_ = node->Var()->Persistable(); | ||
| } | ||
| } | ||
| for (auto *scope : local_scopes_) { | ||
| NewTempScopeAndInitVars(var_infos_, scope); | ||
| } | ||
| ProcessGraph(graphs_, local_scopes_[0]); | ||
| } | ||
|
|
||
| void AsyncSSAGraphExecutor::StartOffPythonTrainLoop() { | ||
| VLOG(3) << "StartOffPythonTrainLoop size = " << places_.size(); | ||
| for (size_t i = 1; i < places_.size(); ++i) { | ||
| auto call = [this, i]() -> void { | ||
| VLOG(3) << "start off python thread " << i; | ||
| try { | ||
| while (true) { | ||
| executors_[i]->Run({}); | ||
| } | ||
| } catch (...) { | ||
| exception_holder_.Catch(std::current_exception()); | ||
| VLOG(3) << "get exception type = " << exception_holder_.Type(); | ||
| } | ||
| VLOG(3) << "thread " << i << " exited!"; | ||
| }; | ||
| run_futures_.emplace_back(pool_->enqueue(std::move(call))); | ||
| } | ||
| } | ||
|
|
||
| void AsyncSSAGraphExecutor::HandleException() { | ||
| if (exception_holder_.IsCaught()) { | ||
| for (auto &f : run_futures_) { | ||
| VLOG(3) << "wait future"; | ||
| f.wait(); | ||
| } | ||
| VLOG(3) << "caught exception " << exception_holder_.Type() | ||
| << ", rethrow it"; | ||
| run_futures_.clear(); | ||
| exception_holder_.ReThrow(); | ||
| } | ||
| } | ||
|
|
||
| FeedFetchList AsyncSSAGraphExecutor::Run( | ||
| const std::vector<std::string> &fetch_tensors) { | ||
| // init once | ||
| if (run_futures_.size() == 0 && places_.size() > 1) { | ||
| exception_holder_.Clear(); | ||
| StartOffPythonTrainLoop(); | ||
| } | ||
|
|
||
| if (places_.size() == 1) { | ||
| exception_holder_.Clear(); | ||
| } else { | ||
| HandleException(); | ||
| } | ||
|
|
||
| FeedFetchList fetch_data; | ||
| fetch_data.reserve(fetch_tensors.size()); | ||
|
|
||
| try { | ||
| fetch_data = executors_[0]->Run(fetch_tensors); | ||
| } catch (...) { | ||
| exception_holder_.Catch(std::current_exception()); | ||
| } | ||
|
|
||
| HandleException(); | ||
|
|
||
| FeedFetchList ret; | ||
| for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) { | ||
| std::vector<const LoDTensor *> lodtensor_ptrs; | ||
| lodtensor_ptrs.push_back(&fetch_data.at(fetch_idx)); | ||
| ret.emplace_back(); | ||
| ret.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace()); | ||
| } | ||
| return ret; | ||
| } | ||
|
|
||
| } // namespace details | ||
| } // namespace framework | ||
| } // namespace paddle | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| // Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <memory> | ||
| #include <string> | ||
| #include <utility> | ||
| #include <vector> | ||
|
|
||
| #include "ThreadPool.h" | ||
| #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" | ||
|
|
||
| namespace paddle { | ||
| namespace framework { | ||
| namespace details { | ||
|
|
||
// Minimal description of a graph variable, gathered from graph 0 at
// construction time and used to (re)create the variable in each scope.
struct VarInfo {
  std::string name_;           // variable name within the scope
  proto::VarType::Type type_;  // framework variable type used by InitializeVariable
  bool persistable_;           // true: created in the parent scope, survives runs
};
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Duplicate with here. |
||
|
|
||
// Runs one ThreadedSSAGraphExecutor per place for async training.
// Device 0 is driven synchronously by Run(); the remaining devices loop
// forever in thread-pool tasks started by StartOffPythonTrainLoop().
// Exceptions from background loops are parked in exception_holder_ and
// rethrown on the Run() caller's thread.
class AsyncSSAGraphExecutor : public SSAGraphExecutor {
 public:
  AsyncSSAGraphExecutor(const ExecutionStrategy &strategy,
                        const std::vector<Scope *> &local_scopes,
                        const std::vector<platform::Place> &places,
                        std::vector<ir::Graph *> graphs);
  ~AsyncSSAGraphExecutor() final = default;
  // Graph 0 stands in for the executor's graph view.
  const ir::Graph &Graph() const override { return *graphs_[0]; }

  FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;

 private:
  // Enqueue the endless Run loops for devices 1..N-1 onto pool_.
  void StartOffPythonTrainLoop();
  // If a background loop caught an exception: wait for all futures,
  // clear them, and rethrow on the current thread.
  void HandleException();

 private:
  ExecutionStrategy strategy_;
  std::vector<Scope *> local_scopes_;
  std::unique_ptr<::ThreadPool> pool_{nullptr};  // null when single device
  std::vector<platform::Place> places_;
  std::vector<ir::Graph *> graphs_;  // one graph per device; not owned here

  std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
  ExceptionHolder exception_holder_;
  std::vector<std::future<void>> run_futures_;  // background device loops
  std::vector<VarInfo> var_infos_;  // variables of graph 0, per-scope init
};
|
|
||
| } // namespace details | ||
| } // namespace framework | ||
| } // namespace paddle | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,6 +31,8 @@ struct ExecutionStrategy { | |
| size_t num_iteration_per_drop_scope_{1}; | ||
| ExecutorType type_{kDefault}; | ||
| bool dry_run_{false}; | ||
| size_t num_iteration_per_run_{1}; // only use with async_ssa_graph_executor | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's this mean?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. https://github.com/PaddlePaddle/Paddle/pull/16172/files#diff-bcb7058cf667aba60603c4448e6180c8R131 |
||
| // and pyreader with data queue | ||
| }; | ||
|
|
||
| } // namespace details | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2018 -> 2019