Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/distributed/fleet_executor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ cc_library(
op_registry
executor_gc_helper
gflags
flags
glog
${BRPC_DEPS})

Expand Down
107 changes: 106 additions & 1 deletion paddle/fluid/distributed/fleet_executor/carrier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h"

#include "gflags/gflags.h"

DECLARE_bool(fleetexecutor_debug_mode);

namespace paddle {
namespace distributed {

Expand All @@ -48,6 +52,69 @@ void Carrier::Init(
thread_num_ = 1;
thread_pool_.SetThreadNum(thread_num_);
thread_pool_.Start();

// Start the flush window *before* launching the background sender thread.
// The original wrote `cache_begin_ == now()`, a discarded comparison, so
// cache_begin_ was never assigned here; and because the thread was started
// first, any write would have raced with the thread's read of cache_begin_.
// Assigning before constructing the std::thread makes the value visible to
// loop_to_send_msg() from its first iteration.
// NOTE(review): loop_to_send_msg() never returns and no join()/detach() of
// test_thread_ is visible in this chunk — destroying a joinable std::thread
// calls std::terminate; confirm shutdown is handled elsewhere.
cache_begin_ = std::chrono::steady_clock::now();
test_thread_ = std::thread([this]() { loop_to_send_msg(); });
}

void Carrier::loop_to_send_msg() {
//VLOG(3) << "loop_send_msg loop now";
while(1){
while(1){
int q_size=0;
std::chrono::time_point<std::chrono::steady_clock> c_begin;
{
std::lock_guard<std::mutex> lock(running_mutex_);
q_size = messages_for_test_.size();
c_begin = cache_begin_;
}

auto now = std::chrono::steady_clock::now();
auto delta = std::chrono::duration_cast<std::chrono::milliseconds>(now - c_begin).count();

if(q_size<2 && delta <5000){
//std::time_t now_c = std::chrono::system_clock::to_time_t(now);
//VLOG(3) << "messages_for_test_ q_size:" << q_size
// << ", delta:" << delta << ", will sleep 1000ms" ;//<<", now:" << now_c;
std::this_thread::sleep_for(std::chrono::milliseconds(10));
continue;
}else{
VLOG(3) << "messages_for_test_ q_size:" << q_size
<< ", delta:" << delta << ", will send all msg" ;
break;
}
}

{
std::lock_guard<std::mutex> lock(running_mutex_);
while (!messages_for_test_.empty()) {
auto msg=messages_for_test_.back();
messages_for_test_.pop_back();

int64_t src_id = msg.src_id();
// TODO(liyurui): compatible solution, will be removed completely in the
// future
if (interceptor_id_to_rank_.find(src_id) == interceptor_id_to_rank_.end() &&
src_id == SOURCE_ID) {
src_id = msg.dst_id();
}
int64_t dst_id = msg.dst_id();
int64_t dst_rank = GetRank(dst_id);

VLOG(3) << "Send a cached message from interceptor " << src_id
<< " to interceptor " << dst_id
<< ", which are in different ranks, scope_idx:" << msg.scope_idx();

if(!GlobalVal<MessageBus>::Get()->Send(dst_rank, msg)){
LOG(FATAL) << "send msg error";
}
std::this_thread::sleep_for(std::chrono::milliseconds(2));
}

cache_begin_ = std::chrono::steady_clock::now();
}
}
VLOG(3) << "reset cache_begin_";
}

void Carrier::Init(
Expand Down Expand Up @@ -95,6 +162,9 @@ void Carrier::Init(
thread_pool_.SetThreadNum(thread_num_);
thread_pool_.Start();

// Start the flush window *before* launching the background sender thread.
// The original wrote `cache_begin_ == now()`, a discarded comparison, so
// cache_begin_ was never assigned here; and because the thread was started
// first, any write would have raced with the thread's read of cache_begin_.
// Assigning before constructing the std::thread makes the value visible to
// loop_to_send_msg() from its first iteration.
// NOTE(review): loop_to_send_msg() never returns and no join()/detach() of
// test_thread_ is visible in this chunk — destroying a joinable std::thread
// calls std::terminate; confirm shutdown is handled elsewhere.
cache_begin_ = std::chrono::steady_clock::now();
test_thread_ = std::thread([this]() { loop_to_send_msg(); });

CreateInterceptors();
is_init_ = true;
}
Expand Down Expand Up @@ -230,12 +300,47 @@ bool Carrier::Send(const InterceptorMessage& msg) {
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id << ", which are in the same ranks.";
return EnqueueInterceptorMessage(msg);
} else {
}
/*
else {
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id
<< ", which are in different ranks.";
return GlobalVal<MessageBus>::Get()->Send(dst_rank, msg);
}
*/

if(!FLAGS_fleetexecutor_debug_mode){
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id
<< ", which are in different ranks.";
return GlobalVal<MessageBus>::Get()->Send(dst_rank, msg);
}

if(msg.message_type() != DATA_IS_READY){
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id
<< ", which are in different ranks.";
return GlobalVal<MessageBus>::Get()->Send(dst_rank, msg);
}

{
VLOG(3) << "prepare executor debug";

std::unique_lock<std::mutex> lock(running_mutex_);
if(messages_for_test_.empty()){
cache_begin_ = std::chrono::steady_clock::now();
//std::time_t now_c = std::chrono::system_clock::to_time_t(cache_begin_));
VLOG(3) << "messages_for_test_ empty, reset cache_begin_";
}

VLOG(3) << "Cache message from interceptor " << src_id
<< " to interceptor " << dst_id
<< ", which are in different ranks, scope_idx:" << msg.scope_idx();
messages_for_test_.emplace_back(msg);
}

return true;
}

Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
Expand Down
10 changes: 10 additions & 0 deletions paddle/fluid/distributed/fleet_executor/carrier.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
#include <string>
#include <unordered_map>
#include <vector>
#include <queue>
#include <thread>

#include <chrono>

#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
Expand Down Expand Up @@ -118,6 +122,12 @@ class Carrier final {
int thread_num_;
TaskLoopThreadPool thread_pool_;
std::unordered_set<int64_t> interceptor_ids_;

std::deque<InterceptorMessage> messages_for_test_;
std::thread test_thread_;
std::chrono::time_point<std::chrono::steady_clock> cache_begin_;

void loop_to_send_msg();
};

} // namespace distributed
Expand Down
31 changes: 21 additions & 10 deletions paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ bool ComputeInterceptor::IsInputReady() {
if (!gen_step_to_scope_id_to_finish_flag_.empty()) {
scope_id_to_finish_flag =
gen_step_to_scope_id_to_finish_flag_.begin()->second;
VLOG(3) << "Is Input Ready in gen step " << gen_step_to_scope_id_to_finish_flag_.begin()->first;
}
for (int64_t i = 0; i < node_->max_run_times(); ++i) {
bool flag = true;
Expand All @@ -184,18 +185,26 @@ bool ComputeInterceptor::IsInputReady() {
flag = flag && (ready_size_map.at(i) != 0);
}
if (flag) {
for (auto iter : scope_id_to_finish_flag) {
if (iter.first == i) {
break;
} else if (!iter.second) {
VLOG(3) << "The previous scope is not ready, waiting for the "
"previous scope "
<< iter.first;
return false;
if (scope_id_to_finish_flag.empty()) {
cur_scope_id_ = i;
return true;
} else if (scope_id_to_finish_flag.find(i) != scope_id_to_finish_flag.end()) {
for (auto iter : scope_id_to_finish_flag) {
if (iter.first == i) {
break;
} else if (!iter.second) {
VLOG(3) << "The previous scope is not ready, waiting for the "
"previous scope "
<< iter.first << " in gen_step " << gen_step_to_scope_id_to_finish_flag_.begin()->first;
return false;
}
}
cur_scope_id_ = i;
return true;
} else {
VLOG(3) << "Interceptor " << GetInterceptorId() << " in scope " << i
<< " is larger than gen_step " << gen_step_to_scope_id_to_finish_flag_.begin()->first;
}
cur_scope_id_ = i;
return true;
} else {
VLOG(3) << "Interceptor " << GetInterceptorId() << " in scope " << i
<< "'s upstreams aren't all ready.";
Expand Down Expand Up @@ -346,6 +355,8 @@ void ComputeInterceptor::Run() {

if (!gen_step_to_scope_id_to_finish_flag_.empty()) {
auto iter = gen_step_to_scope_id_to_finish_flag_.begin();
VLOG(3) << "id=" << GetInterceptorId()
<< " ComputeInterceptor running in scope " << cur_scope_id_ << " with gen_step " << iter->first;
auto& scope_id_to_finish_flag = iter->second;
PADDLE_ENFORCE_NE(
scope_id_to_finish_flag.find(cur_scope_id_),
Expand Down
17 changes: 17 additions & 0 deletions paddle/fluid/platform/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1020,3 +1020,20 @@ PADDLE_DEFINE_EXPORTED_bool(
PADDLE_DEFINE_EXPORTED_string(jit_engine_type,
"Predictor",
"Choose default funciton type in JitLayer.");



/**
 * FleetExecutor debug FLAG
 * Name: FLAGS_fleetexecutor_debug_mode
 * Since Version: 2.5
 * Value Range: bool, default=false
 * Example:
 *   FLAGS_fleetexecutor_debug_mode=true puts FleetExecutor in debug mode.
 * Note:
 *   When the flag is true, Carrier::Send (fleet_executor/carrier.cc) does not
 *   forward cross-rank DATA_IS_READY messages immediately. It caches them,
 *   and a background thread (Carrier::loop_to_send_msg) sends the cached
 *   messages later, once at least 2 are cached or 5000 ms have elapsed.
 *   Other message types, and every message when the flag is false, are sent
 *   right away.
 */
PADDLE_DEFINE_EXPORTED_bool(fleetexecutor_debug_mode,
false,
"Enter in FleetExecutor debug mode.");