-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Batch barrier in send/recv op #7847
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
6e1fa48
b346c8c
745ec2b
586a06d
0eb9f80
d8551c0
f917403
6b51936
782d048
840bd1f
e4c0de0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -132,6 +132,8 @@ void AsyncGRPCServer::RunSyncUpdate() { | |
|
|
||
| cq_send_ = builder.AddCompletionQueue(); | ||
| cq_get_ = builder.AddCompletionQueue(); | ||
| cq_batch_barrier_ = builder.AddCompletionQueue(); | ||
|
||
|
|
||
| server_ = builder.BuildAndStart(); | ||
| LOG(INFO) << "Server listening on " << address_ << std::endl; | ||
|
|
||
|
|
@@ -141,11 +143,11 @@ void AsyncGRPCServer::RunSyncUpdate() { | |
| std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this); | ||
|
|
||
| t_send_.reset( | ||
| new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, false, | ||
| new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, | ||
| cq_send_.get(), "cq_send", send_register))); | ||
|
|
||
| t_get_.reset( | ||
| new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, true, | ||
| new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, | ||
| cq_get_.get(), "cq_get", get_register))); | ||
|
|
||
| // wait server | ||
|
|
@@ -174,7 +176,7 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() { | |
| } | ||
| RequestSend* send = | ||
| new RequestSend(&service_, cq_send_.get(), &var_recv_queue_); | ||
| VLOG(4) << "create RequestSend status:" << send->Status(); | ||
| VLOG(4) << "Create RequestSend status:" << send->Status(); | ||
| } | ||
|
|
||
| void AsyncGRPCServer::TryToRegisterNewGetOne() { | ||
|
|
@@ -184,11 +186,11 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { | |
| } | ||
| RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_, | ||
| &var_get_queue_); | ||
| VLOG(4) << "create Requestget status:" << get->Status(); | ||
| VLOG(4) << "Create RequestGet status:" << get->Status(); | ||
| } | ||
|
|
||
| // FIXME(typhoonzero): remove wait argument and change cq_name to enum. | ||
| void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, | ||
| // FIXME(typhoonzero): change cq_name to enum. | ||
| void AsyncGRPCServer::HandleRequest(grpc::ServerCompletionQueue* cq, | ||
| std::string cq_name, | ||
| std::function<void()> TryToRegisterNewOne) { | ||
| TryToRegisterNewOne(); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -54,11 +54,12 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { | |
|
|
||
| void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); } | ||
|
|
||
| bool IsRecvQueueEmpty() { return this->var_recv_queue_.IsEmpty(); } | ||
|
|
||
| void ShutDown(); | ||
|
|
||
| protected: | ||
| void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq, | ||
| std::string cq_name, | ||
| void HandleRequest(grpc::ServerCompletionQueue *cq, std::string cq_name, | ||
| std::function<void()> TryToRegisterNewOne); | ||
| void TryToRegisterNewSendOne(); | ||
| void TryToRegisterNewGetOne(); | ||
|
|
@@ -69,6 +70,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { | |
| volatile bool is_shut_down_ = false; | ||
| std::unique_ptr<grpc::ServerCompletionQueue> cq_send_; | ||
| std::unique_ptr<grpc::ServerCompletionQueue> cq_get_; | ||
| std::unique_ptr<grpc::ServerCompletionQueue> cq_batch_barrier_; | ||
|
||
|
|
||
| sendrecv::SendRecvService::AsyncService service_; | ||
| std::unique_ptr<grpc::Server> server_; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,8 +29,6 @@ limitations under the License. */ | |
| #include "paddle/operators/detail/simple_block_queue.h" | ||
| #include "paddle/string/printf.h" | ||
|
|
||
| #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
|
|
@@ -95,27 +93,33 @@ class RecvOp : public framework::OperatorBase { | |
| auto param_list = Attr<std::vector<std::string>>("ParamList"); | ||
| auto grad_list = Attr<std::vector<std::string>>("GradList"); | ||
| auto fan_in = Attr<int>("Fanin"); | ||
| size_t param_count = param_list.size(); | ||
|
|
||
| auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock); | ||
| auto *program = block->Program(); | ||
| framework::Executor executor(dev_place); | ||
|
|
||
| // TODO(typhoonzero): change this to a while_op for every cluster-batch. | ||
| bool exit_flag = false; | ||
| size_t barrier_size = param_count * fan_in; | ||
| while (!exit_flag) { | ||
| // Get from multiple trainers, we don't care about the order in which | ||
| // the gradients arrives, just add suffix 0~n and merge the gradient. | ||
| rpc_service_->SetCond(0); | ||
| for (size_t i = 0; i < barrier_size; ++i) { | ||
| size_t barrier_size = 0; | ||
| int batch_barrier = 0; | ||
| while (batch_barrier != fan_in || !rpc_service_->IsRecvQueueEmpty()) { | ||
|
||
| const detail::MessageWithName &v = rpc_service_->Get(); | ||
| auto grad_var_name = v.first; | ||
| if (grad_var_name == LISTEN_TERMINATE_MESSAGE) { | ||
| LOG(INFO) << "received terminate message and exit"; | ||
| exit_flag = true; | ||
| break; | ||
| } | ||
| if (grad_var_name == BATCH_BARRIER_MESSAGE) { | ||
| VLOG(3) << "recv batch barrier message"; | ||
| batch_barrier++; | ||
| continue; | ||
| } | ||
| barrier_size++; | ||
|
||
| auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name); | ||
| std::string param_var_name; | ||
| if (it != grad_list.end()) { | ||
|
|
@@ -125,6 +129,7 @@ class RecvOp : public framework::OperatorBase { | |
| } | ||
| VLOG(3) << "received grad: " << grad_var_name | ||
| << " updating param: " << param_var_name; | ||
|
|
||
| if (fan_in > 1) { | ||
| grad_var_name = this->GetGradVarNameForTrainer(grad_var_name); | ||
| } | ||
|
|
@@ -135,6 +140,8 @@ class RecvOp : public framework::OperatorBase { | |
| } | ||
| detail::DeserializeFromMessage(v.second, dev_ctx, var); | ||
| } | ||
| VLOG(3) << "recv " << barrier_size << " parmeters for one barrier."; | ||
| // TODO(Yancey1989): merge SelectedRows variables here | ||
| if (exit_flag) { | ||
| break; | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -41,13 +41,20 @@ class SendOp : public framework::OperatorBase { | |
| platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); | ||
| auto& ctx = *pool.Get(place); | ||
| for (size_t i = 0; i < ins.size(); i++) { | ||
| VLOG(3) << "sending " << ins[i]; | ||
| VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; | ||
| client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]); | ||
| } | ||
| PADDLE_ENFORCE(client_.Wait()); | ||
|
|
||
| std::set<std::string> epset(epmap.begin(), epmap.end()); | ||
|
||
| for (auto& ep : epset) { | ||
| VLOG(3) << "batch barrier, ep: " << ep; | ||
| client_.AsyncBatchBarrier(ep); | ||
| } | ||
| PADDLE_ENFORCE(client_.Wait()); | ||
|
|
||
| for (size_t i = 0; i < outs.size(); i++) { | ||
| VLOG(3) << "getting " << outs[i]; | ||
| VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; | ||
| client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]); | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BatchBarrier 是个名词,看这个实现,函数名应该是 AsyncSendBatchBarrier;或者把这个实现直接改成一个同步的调用会比较方便。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done。同时支持同步和异步两种请求的 Server 可能会比较复杂;现在的方案可以复用 AsyncSendVariable 接口来发送 barrier signal,代码会简洁很多。如果后续有强需求支持同步的请求,再实现同步的接口?