Skip to content

Commit 80a8a1e

Browse files
committed
optimize overlap between steps
1 parent 92c2dcb commit 80a8a1e

11 files changed

Lines changed: 291 additions & 126 deletions

File tree

paddle/fluid/distributed/fleet_executor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ cc_library(
5151
op_registry
5252
executor_gc_helper
5353
gflags
54+
flags
5455
glog
5556
${BRPC_DEPS})
5657

paddle/fluid/distributed/fleet_executor/carrier.cc

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <algorithm>
1818
#include <vector>
1919

20+
#include "gflags/gflags.h"
2021
#include "paddle/fluid/distributed/fleet_executor/global.h"
2122
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
2223
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
@@ -28,6 +29,8 @@
2829
#include "paddle/fluid/framework/variable.h"
2930
#include "paddle/fluid/framework/variable_helper.h"
3031

32+
DECLARE_bool(fleetexecutor_debug_mode);
33+
3134
namespace paddle {
3235
namespace distributed {
3336

@@ -48,6 +51,72 @@ void Carrier::Init(
4851
thread_num_ = 1;
4952
thread_pool_.SetThreadNum(thread_num_);
5053
thread_pool_.Start();
54+
55+
if (FLAGS_fleetexecutor_debug_mode) {
56+
test_thread_ = std::thread([this]() { loop_to_send_msg(); });
57+
cache_begin_ == std::chrono::steady_clock::now();
58+
}
59+
}
60+
61+
void Carrier::loop_to_send_msg() {
62+
// VLOG(3) << "loop_send_msg loop now";
63+
while (1) {
64+
while (1) {
65+
int q_size = 0;
66+
std::chrono::time_point<std::chrono::steady_clock> c_begin;
67+
{
68+
std::lock_guard<std::mutex> lock(running_mutex_);
69+
q_size = messages_for_test_.size();
70+
c_begin = cache_begin_;
71+
}
72+
73+
auto now = std::chrono::steady_clock::now();
74+
auto delta =
75+
std::chrono::duration_cast<std::chrono::milliseconds>(now - c_begin)
76+
.count();
77+
78+
if (q_size < 2 && delta < 5000) {
79+
std::this_thread::sleep_for(std::chrono::milliseconds(10));
80+
continue;
81+
} else {
82+
VLOG(3) << "messages_for_test_ q_size:" << q_size << ", delta:" << delta
83+
<< ", will send all msg";
84+
break;
85+
}
86+
}
87+
88+
{
89+
std::lock_guard<std::mutex> lock(running_mutex_);
90+
while (!messages_for_test_.empty()) {
91+
auto msg = messages_for_test_.back();
92+
messages_for_test_.pop_back();
93+
94+
int64_t src_id = msg.src_id();
95+
// TODO(liyurui): compatible solution, will be removed completely in the
96+
// future
97+
if (interceptor_id_to_rank_.find(src_id) ==
98+
interceptor_id_to_rank_.end() &&
99+
src_id == SOURCE_ID) {
100+
src_id = msg.dst_id();
101+
}
102+
int64_t dst_id = msg.dst_id();
103+
int64_t dst_rank = GetRank(dst_id);
104+
105+
VLOG(3) << "Send a cached message from interceptor " << src_id
106+
<< " to interceptor " << dst_id
107+
<< ", which are in different ranks, scope_idx:"
108+
<< msg.scope_idx();
109+
110+
if (!GlobalVal<MessageBus>::Get()->Send(dst_rank, msg)) {
111+
LOG(FATAL) << "send msg error";
112+
}
113+
std::this_thread::sleep_for(std::chrono::milliseconds(2));
114+
}
115+
116+
cache_begin_ = std::chrono::steady_clock::now();
117+
}
118+
}
119+
VLOG(3) << "reset cache_begin_";
51120
}
52121

53122
void Carrier::Init(
@@ -95,6 +164,11 @@ void Carrier::Init(
95164
thread_pool_.SetThreadNum(thread_num_);
96165
thread_pool_.Start();
97166

167+
if (FLAGS_fleetexecutor_debug_mode) {
168+
test_thread_ = std::thread([this]() { loop_to_send_msg(); });
169+
cache_begin_ == std::chrono::steady_clock::now();
170+
}
171+
98172
CreateInterceptors();
99173
is_init_ = true;
100174
}
@@ -230,12 +304,39 @@ bool Carrier::Send(const InterceptorMessage& msg) {
230304
VLOG(3) << "Send a message from interceptor " << src_id
231305
<< " to interceptor " << dst_id << ", which are in the same ranks.";
232306
return EnqueueInterceptorMessage(msg);
233-
} else {
307+
}
308+
if (!FLAGS_fleetexecutor_debug_mode) {
309+
VLOG(3) << "Send a message from interceptor " << src_id
310+
<< " to interceptor " << dst_id
311+
<< ", which are in different ranks.";
312+
return GlobalVal<MessageBus>::Get()->Send(dst_rank, msg);
313+
}
314+
315+
if (msg.message_type() != DATA_IS_READY) {
234316
VLOG(3) << "Send a message from interceptor " << src_id
235317
<< " to interceptor " << dst_id
236318
<< ", which are in different ranks.";
237319
return GlobalVal<MessageBus>::Get()->Send(dst_rank, msg);
238320
}
321+
322+
{
323+
VLOG(3) << "prepare executor debug";
324+
325+
std::unique_lock<std::mutex> lock(running_mutex_);
326+
if (messages_for_test_.empty()) {
327+
cache_begin_ = std::chrono::steady_clock::now();
328+
// std::time_t now_c =
329+
// std::chrono::system_clock::to_time_t(cache_begin_));
330+
VLOG(3) << "messages_for_test_ empty, reset cache_begin_";
331+
}
332+
333+
VLOG(3) << "Cache message from interceptor " << src_id << " to interceptor "
334+
<< dst_id
335+
<< ", which are in different ranks, scope_idx:" << msg.scope_idx();
336+
messages_for_test_.emplace_back(msg);
337+
}
338+
339+
return true;
239340
}
240341

241342
Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,

paddle/fluid/distributed/fleet_executor/carrier.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@
1414

1515
#pragma once
1616

17+
#include <chrono>
1718
#include <condition_variable>
1819
#include <memory>
1920
#include <mutex>
21+
#include <queue>
2022
#include <set>
2123
#include <string>
24+
#include <thread>
2225
#include <unordered_map>
2326
#include <vector>
2427

@@ -118,6 +121,12 @@ class Carrier final {
118121
int thread_num_;
119122
TaskLoopThreadPool thread_pool_;
120123
std::unordered_set<int64_t> interceptor_ids_;
124+
125+
std::deque<InterceptorMessage> messages_for_test_;
126+
std::thread test_thread_;
127+
std::chrono::time_point<std::chrono::steady_clock> cache_begin_;
128+
129+
void loop_to_send_msg();
121130
};
122131

123132
} // namespace distributed

0 commit comments

Comments
 (0)