Skip to content
16 changes: 15 additions & 1 deletion paddle/operators/detail/grpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
SendProcessor* s = new SendProcessor(ch);
s->Prepare(var_h, time_out);
s->response_call_back_ = NULL;

auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
});
Expand Down Expand Up @@ -97,6 +96,21 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
return true;
}

bool RPCClient::AsyncBatchBarrier(const std::string& ep, int64_t time_out) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BatchBarrier是个名词,看这个实现应该是AsyncSendBatchBarrier,或者这个实现直接改成一个同步的调用会比较方便。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, 同时支持同步和异步两种请求的Server可能会比较复杂,现在的方案可以复用AsyncSendVariable接口来发送barrier signal,代码会简洁很多,如果后续有强需求支持同步的请求,再实现同步的接口?

const auto ch = GetChannel(ep);

BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out);

sendrecv::VariableMessage req;
req.set_varname(BATCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
req_count_++;

return true;
}

bool RPCClient::Wait() {
if (req_count_ <= 0) {
return true;
Expand Down
23 changes: 23 additions & 0 deletions paddle/operators/detail/grpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,15 @@ class ClientBase {
context_->set_deadline(deadline);
}

// Resets the client context and arms a fresh deadline `time_out`
// milliseconds from now. Overload for calls that carry no variable
// handle (e.g. batch-barrier signals).
virtual void Prepare(int64_t time_out) {
  context_.reset(new grpc::ClientContext());
  context_->set_deadline(std::chrono::system_clock::now() +
                         std::chrono::milliseconds(time_out));
}

virtual void Process() = 0;

std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
Expand Down Expand Up @@ -117,6 +126,17 @@ class GetProcessor : public ClientBase {
RequestGetCallBack response_call_back_ = ProcGetResponse;
};

class BatchBarrierProcessor : public ClientBase {
public:
explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: ClientBase(ch) {}

virtual ~BatchBarrierProcessor() {}

virtual void Process() {}
sendrecv::VoidMessage reply_;
};

class RPCClient {
public:
bool AsyncSendVariable(const std::string& ep,
Expand All @@ -130,6 +150,9 @@ class RPCClient {
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = 600 * 1000);

bool AsyncBatchBarrier(const std::string& ep, int64_t time_out = 600 * 1000);

bool Wait();

private:
Expand Down
14 changes: 8 additions & 6 deletions paddle/operators/detail/grpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ void AsyncGRPCServer::RunSyncUpdate() {

cq_send_ = builder.AddCompletionQueue();
cq_get_ = builder.AddCompletionQueue();
cq_batch_barrier_ = builder.AddCompletionQueue();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cq_batch_barrier_ seems never used.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, sorry I forgot to delete the older code.


server_ = builder.BuildAndStart();
LOG(INFO) << "Server listening on " << address_ << std::endl;

Expand All @@ -141,11 +143,11 @@ void AsyncGRPCServer::RunSyncUpdate() {
std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this);

t_send_.reset(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, false,
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_send_.get(), "cq_send", send_register)));

t_get_.reset(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, true,
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_get_.get(), "cq_get", get_register)));

// wait server
Expand Down Expand Up @@ -174,7 +176,7 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
}
RequestSend* send =
new RequestSend(&service_, cq_send_.get(), &var_recv_queue_);
VLOG(4) << "create RequestSend status:" << send->Status();
VLOG(4) << "Create RequestSend status:" << send->Status();
}

void AsyncGRPCServer::TryToRegisterNewGetOne() {
Expand All @@ -184,11 +186,11 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
}
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
&var_get_queue_);
VLOG(4) << "create Requestget status:" << get->Status();
VLOG(4) << "Create RequestGet status:" << get->Status();
}

// FIXME(typhoonzero): remove wait argument and change cq_name to enum.
void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
// FIXME(typhoonzero): change cq_name to enum.
void AsyncGRPCServer::HandleRequest(grpc::ServerCompletionQueue* cq,
std::string cq_name,
std::function<void()> TryToRegisterNewOne) {
TryToRegisterNewOne();
Expand Down
6 changes: 4 additions & 2 deletions paddle/operators/detail/grpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,12 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {

void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); }

bool IsRecvQueueEmpty() { return this->var_recv_queue_.IsEmpty(); }

void ShutDown();

protected:
void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq,
std::string cq_name,
void HandleRequest(grpc::ServerCompletionQueue *cq, std::string cq_name,
std::function<void()> TryToRegisterNewOne);
void TryToRegisterNewSendOne();
void TryToRegisterNewGetOne();
Expand All @@ -69,6 +70,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
volatile bool is_shut_down_ = false;
std::unique_ptr<grpc::ServerCompletionQueue> cq_send_;
std::unique_ptr<grpc::ServerCompletionQueue> cq_get_;
std::unique_ptr<grpc::ServerCompletionQueue> cq_batch_barrier_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cq_batch_barrier_ seems never used.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


sendrecv::SendRecvService::AsyncService service_;
std::unique_ptr<grpc::Server> server_;
Expand Down
3 changes: 3 additions & 0 deletions paddle/operators/detail/sendrecvop_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ namespace paddle {
namespace operators {
namespace detail {

#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"

void SerializeToMessage(const std::string& name, const framework::Variable* var,
const platform::DeviceContext& ctx,
sendrecv::VariableMessage* msg);
Expand Down
5 changes: 5 additions & 0 deletions paddle/operators/detail/simple_block_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ class SimpleBlockQueue {
this->queue_.pop_back();
return rc;
}

// Returns true when no messages are queued. Thread-safe: holds the
// queue mutex for the duration of the check.
bool IsEmpty() {
  std::lock_guard<std::mutex> guard(this->mutex_);
  return this->queue_.empty();
}
};

} // namespace detail
Expand Down
17 changes: 12 additions & 5 deletions paddle/operators/recv_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ limitations under the License. */
#include "paddle/operators/detail/simple_block_queue.h"
#include "paddle/string/printf.h"

#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"

namespace paddle {
namespace operators {

Expand Down Expand Up @@ -95,27 +93,33 @@ class RecvOp : public framework::OperatorBase {
auto param_list = Attr<std::vector<std::string>>("ParamList");
auto grad_list = Attr<std::vector<std::string>>("GradList");
auto fan_in = Attr<int>("Fanin");
size_t param_count = param_list.size();

auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *program = block->Program();
framework::Executor executor(dev_place);

// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false;
size_t barrier_size = param_count * fan_in;
while (!exit_flag) {
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(0);
for (size_t i = 0; i < barrier_size; ++i) {
size_t barrier_size = 0;
int batch_barrier = 0;
while (batch_barrier != fan_in || !rpc_service_->IsRecvQueueEmpty()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why !rpc_service_->IsRecvQueueEmpty() is needed. rpc_service_->Get() is a blocking call which will wait until a new message arrives.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, it's not used :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I deleted `!rpc_service_->IsRecvQueueEmpty()`, because the send op sends the barrier signal last; if RecvOp receives a barrier signal, it must be the last message from that trainer.

const detail::MessageWithName &v = rpc_service_->Get();
auto grad_var_name = v.first;
if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
LOG(INFO) << "received terminate message and exit";
exit_flag = true;
break;
}
if (grad_var_name == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "recv batch barrier message";
batch_barrier++;
continue;
}
barrier_size++;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

barrier_size is used only for printing log, can remove or rename it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
std::string param_var_name;
if (it != grad_list.end()) {
Expand All @@ -125,6 +129,7 @@ class RecvOp : public framework::OperatorBase {
}
VLOG(3) << "received grad: " << grad_var_name
<< " updating param: " << param_var_name;

if (fan_in > 1) {
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
}
Expand All @@ -135,6 +140,8 @@ class RecvOp : public framework::OperatorBase {
}
detail::DeserializeFromMessage(v.second, dev_ctx, var);
}
VLOG(3) << "recv " << barrier_size << " parmeters for one barrier.";
// TODO(Yancey1989): merge SelectedRows variables here
if (exit_flag) {
break;
}
Expand Down
11 changes: 9 additions & 2 deletions paddle/operators/send_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,20 @@ class SendOp : public framework::OperatorBase {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
for (size_t i = 0; i < ins.size(); i++) {
VLOG(3) << "sending " << ins[i];
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
}
PADDLE_ENFORCE(client_.Wait());

std::set<std::string> epset(epmap.begin(), epmap.end());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use endpoints attribute is the same thing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

for (auto& ep : epset) {
VLOG(3) << "batch barrier, ep: " << ep;
client_.AsyncBatchBarrier(ep);
}
PADDLE_ENFORCE(client_.Wait());

for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i];
VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}

Expand Down