-
Notifications
You must be signed in to change notification settings - Fork 6k
Add async ssa graph executor #15409
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add async ssa graph executor #15409
Changes from 38 commits
92a6c7a
afda840
ea66979
88d71fa
69484f7
f3210b6
ada43e8
fab8457
a66115b
62549e0
be738a6
9da96ab
7e145b7
02dab46
4a17261
249f48e
d6c0dca
16af1db
b1fe8d4
e72637d
84367cf
c4ded17
2171aa7
cc71e89
31a05d3
9465c3d
7f3be09
12f6b8c
f4f4816
ecedd53
b5b8e6c
10393dd
43c8237
cf0511f
dab7f36
ff01d70
f768fbf
847e4f4
e70b172
8744f9a
b2c082c
e92ad8a
f28c258
c09477b
4e218da
5e8de51
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,146 @@ | ||
| // Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #include "paddle/fluid/framework/details/async_ssa_graph_executor.h" | ||
|
|
||
| #include "paddle/fluid/framework/variable_helper.h" | ||
|
|
||
| namespace paddle { | ||
| namespace framework { | ||
| namespace details { | ||
|
|
||
| inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos, | ||
| Scope *scope) { | ||
| Scope &local_scope = scope->NewScope(); | ||
| *scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() = | ||
| &local_scope; | ||
|
|
||
| for (auto &info : var_infos) { | ||
| if (scope->FindVar(info.name_) != nullptr) { | ||
| continue; | ||
| } | ||
|
|
||
| if (info.persistable_) { // Persistable | ||
| InitializeVariable(scope->Var(info.name_), info.type_); | ||
| } else { | ||
| InitializeVariable(local_scope.Var(info.name_), info.type_); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( | ||
| const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes, | ||
| const std::vector<platform::Place> &places, std::vector<ir::Graph *> graphs) | ||
| : strategy_(std::move(strategy)), | ||
| local_scopes_(std::move(local_scopes)), | ||
| pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr), | ||
| places_(std::move(places)), | ||
| graphs_(std::move(graphs)) { | ||
| VLOG(3) << "build AsyncSSAGraphExecutor"; | ||
| PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size()); | ||
|
|
||
| // set the correct size of thread pool to each device. | ||
| strategy_.num_threads_ = strategy_.num_threads_ < places_.size() | ||
| ? 1UL | ||
| : strategy_.num_threads_ / places_.size(); | ||
| VLOG(1) << "set num_threads: " << strategy_.num_threads_ | ||
| << " to run the operators of the graph on each device."; | ||
| for (size_t i = 0; i < places.size(); ++i) { | ||
| executors_.emplace_back(new details::ThreadedSSAGraphExecutor( | ||
| strategy_, {local_scopes_[i]}, {places_[i]}, graphs_[i])); | ||
| } | ||
|
|
||
| for (auto &node : graphs_[0]->Nodes()) { | ||
| if (node->IsVar() && !node->IsCtrlVar() && node->Var()) { | ||
| var_infos_.emplace_back(); | ||
| var_infos_.back().name_ = node->Var()->Name(); | ||
| var_infos_.back().type_ = node->Var()->GetType(); | ||
| var_infos_.back().persistable_ = node->Var()->Persistable(); | ||
| } | ||
| } | ||
| for (auto *scope : local_scopes_) { | ||
| NewTempScopeAndInitVars(var_infos_, scope); | ||
| } | ||
| } | ||
|
|
||
| void AsyncSSAGraphExecutor::StartOffPythonTrainLoop() { | ||
| VLOG(3) << "StartOffPythonTrainLoop size = " << places_.size(); | ||
| for (size_t i = 1; i < places_.size(); ++i) { | ||
| auto call = [this, i]() -> void { | ||
| VLOG(3) << "start off python thread " << i; | ||
| try { | ||
| while (true) { | ||
| executors_[i]->Run({}); | ||
| } | ||
| } catch (...) { | ||
| exception_holder_.Catch(std::current_exception()); | ||
| VLOG(3) << "get exception type = " << exception_holder_.Type(); | ||
| } | ||
| VLOG(3) << "thread " << i << " exited!"; | ||
| }; | ||
| run_futures_.emplace_back(pool_->enqueue(std::move(call))); | ||
| } | ||
| } | ||
|
|
||
| void AsyncSSAGraphExecutor::HandleException() { | ||
| if (exception_holder_.IsCaught()) { | ||
| for (auto &f : run_futures_) { | ||
| VLOG(3) << "wait future"; | ||
| f.wait(); | ||
| } | ||
| VLOG(3) << "caught exception " << exception_holder_.Type() | ||
| << ", rethrow it"; | ||
| run_futures_.clear(); | ||
| exception_holder_.ReThrow(); | ||
| } | ||
| } | ||
|
|
||
| FeedFetchList AsyncSSAGraphExecutor::Run( | ||
| const std::vector<std::string> &fetch_tensors) { | ||
| // init once | ||
| if (run_futures_.size() == 0 && places_.size() > 1) { | ||
| exception_holder_.Clear(); | ||
| StartOffPythonTrainLoop(); | ||
| } | ||
|
|
||
| if (places_.size() == 1) { | ||
| exception_holder_.Clear(); | ||
| } else { | ||
| HandleException(); | ||
| } | ||
|
|
||
| FeedFetchList fetch_data; | ||
| fetch_data.reserve(fetch_tensors.size()); | ||
|
|
||
| try { | ||
| fetch_data = executors_[0]->Run(fetch_tensors); | ||
| } catch (...) { | ||
| exception_holder_.Catch(std::current_exception()); | ||
| } | ||
|
|
||
| HandleException(); | ||
|
|
||
| FeedFetchList ret; | ||
| for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) { | ||
| std::vector<const LoDTensor *> lodtensor_ptrs; | ||
| lodtensor_ptrs.push_back(&fetch_data.at(fetch_idx)); | ||
| ret.emplace_back(); | ||
| ret.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace()); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. num_iteration_per_run_ > 1的情况下,各线程执行速度不一致,merge各个local_scope的结果是否有意义?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个感觉可以去掉其实，反正已经是纯异步了，相当于减少一点做eval的数据量
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 大家步调不一致,参数版本也不一致,确实应该去掉,观察其中一个线程就够了 |
||
| } | ||
| return ret; | ||
| } | ||
|
|
||
| } // namespace details | ||
| } // namespace framework | ||
| } // namespace paddle | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| // Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <string> | ||
| #include <vector> | ||
|
|
||
| #include "ThreadPool.h" | ||
| #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" | ||
|
|
||
| namespace paddle { | ||
| namespace framework { | ||
| namespace details { | ||
|
|
||
// Metadata describing one graph variable, used to (re)create the variable in
// each executor's scope (see NewTempScopeAndInitVars in the .cc file).
struct VarInfo {
  std::string name_;             // variable name in the graph
  proto::VarType::Type type_;    // variable type used for initialization
  bool persistable_;             // persistable vars go in the parent scope
};
|
|
||
// Drives one ThreadedSSAGraphExecutor per place. Executor 0 is run
// synchronously by the caller (the Python thread); the remaining executors
// loop endlessly on a thread pool, making overall execution asynchronous.
class AsyncSSAGraphExecutor : public SSAGraphExecutor {
 public:
  AsyncSSAGraphExecutor(const ExecutionStrategy &strategy,
                        const std::vector<Scope *> &local_scopes,
                        const std::vector<platform::Place> &places,
                        std::vector<ir::Graph *> graphs);
  ~AsyncSSAGraphExecutor() final = default;
  // Graph 0 is exposed as the representative graph.
  const ir::Graph &Graph() const override { return *graphs_[0]; }

  // Runs executor 0 with the given fetch list and returns its results;
  // lazily starts the background loops on the first call (multi-place only).
  FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;

 private:
  // Enqueues an endless Run({}) loop on pool_ for every executor except 0.
  void StartOffPythonTrainLoop();
  // If a background thread stored an exception, waits for all futures and
  // rethrows the exception on the calling thread.
  void HandleException();

 private:
  ExecutionStrategy strategy_;
  std::vector<Scope *> local_scopes_;
  // Non-null only when there are >= 2 places (see the constructor).
  std::unique_ptr<::ThreadPool> pool_{nullptr};
  std::vector<platform::Place> places_;
  // Raw pointers; ownership presumably lives with the caller — TODO confirm.
  std::vector<ir::Graph *> graphs_;

  std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
  ExceptionHolder exception_holder_;
  // Futures of the background train loops (empty until the first Run()).
  std::vector<std::future<void>> run_futures_;
  std::vector<VarInfo> var_infos_;
};
|
|
||
| } // namespace details | ||
| } // namespace framework | ||
| } // namespace paddle |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -90,6 +90,7 @@ struct BuildStrategy { | |
| // num_trainers is 1, so the current fields of build_strategy doesn't tell if | ||
| // it's distributed model. | ||
| bool is_distribution_{false}; | ||
| bool async_mode_{false}; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what is the relationship between async_mode and is_distribution |
||
| int num_trainers_{1}; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can num_trainers > 1 and not is_distribution? |
||
| int trainer_id_{0}; | ||
| std::vector<std::string> trainers_endpoints_; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
local_scopes是从哪里创建带入的?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/framework/parallel_executor.cc#L217