@@ -10,16 +10,18 @@ See the License for the specific language governing permissions and
1010limitations under the License. */
1111
1212#include " paddle/fluid/framework/fleet/gloo_wrapper.h"
13+ #include < thread> // NOLINT
1314#include < vector>
1415#include " paddle/fluid/framework/io/fs.h"
1516#include " paddle/fluid/platform/errors.h"
17+ #include " paddle/fluid/string/string_helper.h"
1618
1719namespace gloo {
1820namespace rendezvous {
1921
2022HdfsStore::HdfsStore (const std::string& path) {
2123 path_ = path;
22- wait_sleep_ms_ = 3000 ;
24+ wait_sleep_ms_ = 10000 ;
2325 wait_timeout_ = std::chrono::seconds (999999999 );
2426 retry_times_ = 100 ;
2527}
@@ -35,49 +37,86 @@ void HdfsStore::set(const std::string& key, const std::vector<char>& data) {
3537 }
3638 int err_no = 0 ;
3739 for (int i = 1 ; i <= retry_times_; ++i) {
40+ err_no = 0 ;
3841 std::shared_ptr<FILE> fp =
3942 paddle::framework::fs_open_write (tmp, &err_no, " " );
40- if (err_no != 0 ) {
41- VLOG (0 ) << " fs_open_write failed, retry times " << i << " err no "
42- << err_no;
43- fp.reset ();
44- sleep (wait_sleep_ms_ / 1000 );
45- continue ;
46- }
4743 size_t write_count = fwrite_unlocked (data.data (), 1 , data.size (), fp.get ());
4844 if (write_count != data.size ()) {
4945 VLOG (0 ) << " fwrite_unlocked failed, retry times " << i << " write_count "
5046 << write_count << " data.size() " << data.size ();
51- fp.reset ();
52- sleep (2 );
53- continue ;
47+ err_no = -1 ;
5448 }
5549 fp.reset ();
56- break ;
50+ if (err_no != 0 ) {
51+ VLOG (0 ) << " fs_open_write failed, retry times " << i << " err no "
52+ << err_no;
53+ sleep (wait_sleep_ms_ / 1000 );
54+ paddle::framework::fs_remove (tmp);
55+ if (i == retry_times_) {
56+ VLOG (0 ) << " fs_open_write failed, retry times reaches limit" ;
57+ PADDLE_THROW (platform::errors::PreconditionNotMet (
58+ " fs_open_write failed, retry times reaches"
59+ " limit " ,
60+ retry_times_));
61+ }
62+ } else {
63+ break ;
64+ }
5765 }
5866 paddle::framework::fs_mv (tmp, path);
5967#endif
6068}
6169
70+ #ifdef PADDLE_WITH_GLOO
/**
 * Calls `func` repeatedly until it reports success or the attempt budget is
 * used up.
 *
 * @param func              Operation to attempt; a return value of 0 means
 *                          success, anything else means "try again".
 * @param max_try_time      Maximum number of times `func` is invoked.
 * @param retry_interval_ms Pause between two consecutive attempts, in
 *                          milliseconds.
 * @return 0 as soon as one attempt succeeds; -1 when every attempt failed
 *         (which includes max_try_time == 0, where `func` is never called).
 */
int retry_do_func(std::function<int(void)> func, uint32_t max_try_time,
                  uint32_t retry_interval_ms) {
  for (uint32_t attempt = 0; attempt < max_try_time; ++attempt) {
    if (func() == 0) {
      return 0;
    }
    // Pause only when another attempt follows: sleeping after the final
    // failure would just delay the error report. std::this_thread::sleep_for
    // is used instead of usleep so the back-off does not depend on a
    // platform macro and is not limited to sub-second intervals.
    if (attempt + 1 < max_try_time) {
      std::this_thread::sleep_for(
          std::chrono::milliseconds(retry_interval_ms));
    }
  }
  return -1;
}
83+ #endif
84+
6285std::vector<char > HdfsStore::get (const std::string& key) {
6386 auto path = ObjectPath (key);
6487 std::vector<char > result;
6588#ifdef PADDLE_WITH_GLOO
6689 // block until key is set
6790 wait ({key});
68- bool is_exists = paddle::framework::fs_exists (path);
91+ int ret = retry_do_func (
92+ [&path]() { return paddle::framework::fs_exists (path) ? 0 : -1 ; }, 5 ,
93+ wait_sleep_ms_);
94+ bool is_exists = (ret == 0 );
6995 PADDLE_ENFORCE_EQ (is_exists, true ,
7096 paddle::platform::errors::NotFound (
7197 " HdfsStore::get, path not exists: " + path));
72- int err_no = 0 ;
73- std::shared_ptr<FILE> fp = paddle::framework::fs_open_read (path, &err_no, " " );
74- char buffer = ' \0 ' ;
75- size_t read_count = 0 ;
76- while (fread (&buffer, 1 , 1 , fp.get ()) == 1 ) {
77- ++read_count;
78- result.push_back (buffer);
79- }
80- VLOG (3 ) << " HdfsStore::get read_count " << read_count;
98+
99+ int read_status = retry_do_func (
100+ [&path, &result]() {
101+ result.clear ();
102+ int err_no = 0 ;
103+ {
104+ std::shared_ptr<FILE> fp =
105+ paddle::framework::fs_open_read (path, &err_no, " " );
106+ char buffer = ' \0 ' ;
107+ size_t read_count = 0 ;
108+ while (fread (&buffer, 1 , 1 , fp.get ()) == 1 ) {
109+ ++read_count;
110+ result.push_back (buffer);
111+ }
112+ VLOG (3 ) << " HdfsStore::get read_count " << read_count;
113+ }
114+ return err_no;
115+ },
116+ 5 , wait_sleep_ms_);
117+ PADDLE_ENFORCE_EQ (read_status, 0 ,
118+ paddle::platform::errors::Fatal (
119+ " HdfsStore::get, path read faied: " + path));
81120#endif
82121 return result;
83122}
@@ -92,22 +131,33 @@ void HdfsStore::wait(const std::vector<std::string>& keys,
92131 const std::chrono::milliseconds&) { // NOLINT
93132#ifdef PADDLE_WITH_GLOO
94133 auto start = std::chrono::steady_clock::now ();
95- while (!Check (keys)) {
134+ std::vector<bool > check_key_status (keys.size (), false );
135+ while (!Check (keys, &check_key_status)) {
96136 auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
97137 std::chrono::steady_clock::now () - start);
98138 if (wait_timeout_ != gloo::kNoTimeout && elapsed > wait_timeout_) {
99- PADDLE_ENFORCE_EQ (0 , 1 , paddle::platform::errors::ExecutionTimeout (
100- " HdfsStore::wait, Wait timeout for key(s): " +
101- ::gloo::MakeString (keys)));
139+ int32_t last_check_rank = -1 ;
140+ for (size_t i = 0 ; i < check_key_status.size (); ++i) {
141+ if (!check_key_status[i]) {
142+ last_check_rank = i;
143+ break ;
144+ }
145+ }
146+ PADDLE_THROW (platform::errors::ExecutionTimeout (
147+ " TIMEOUT self_rank = %d pair_rank = %d" , self_rank_,
148+ last_check_rank));
102149 }
103150 std::this_thread::sleep_for (std::chrono::milliseconds (wait_sleep_ms_));
104151 }
105152#endif
106153}
107154
155+ void HdfsStore::SetTimeoutSeconds (int timeout_seconds) {
156+ wait_timeout_ = std::chrono::seconds (timeout_seconds);
157+ }
158+
108159std::string HdfsStore::EncodeName (const std::string& name) {
109- thread_local std::hash<std::string> hash_func;
110- return std::to_string (hash_func (name));
160+ return ::paddle::string::erase_spaces (name);
111161}
112162
113163std::string HdfsStore::TmpPath (const std::string& name) {
@@ -118,50 +168,124 @@ std::string HdfsStore::ObjectPath(const std::string& name) {
118168 return path_ + " /" + EncodeName (name);
119169}
120170
121- bool HdfsStore::Check (const std::vector<std::string>& keys) {
171+ bool HdfsStore::Check (const std::vector<std::string>& keys,
172+ std::vector<bool >* keys_check_status) {
122173#ifdef PADDLE_WITH_GLOO
174+ bool ret = true ;
123175 std::vector<std::string> paths;
124176 for (const auto & key : keys) {
125177 paths.push_back (ObjectPath (key));
126178 }
127- for (const auto & path : paths) {
179+ for (size_t i = 0 ; i < paths.size (); ++i) {
180+ if ((*keys_check_status)[i]) {
181+ continue ;
182+ }
183+ const auto & path = paths[i];
128184 bool is_exists = paddle::framework::fs_exists (path);
129185 VLOG (3 ) << " HdfsStore::Check " << is_exists << " path " << path;
130186 if (!is_exists) {
131- return false ;
187+ ret = false ;
132188 }
189+ (*keys_check_status)[i] = is_exists;
133190 }
191+ return ret;
192+ #else
193+ VLOG (0 ) << " HdfsStore::Check does nothing when no gloo" ;
134194#endif
135195 return true ;
136196}
137197
198+ #ifdef PADDLE_WITH_GLOO
199+ void ParallelConnectContext::connectFullMesh (
200+ Store& store, std::shared_ptr<transport::Device>& dev) {
201+ std::vector<char > allBytes;
202+ // Create pairs
203+ auto transportContext = dev->createContext (rank, size);
204+ transportContext->setTimeout (getTimeout ());
205+ for (int i = 0 ; i < size; i++) {
206+ if (i == rank) {
207+ continue ;
208+ }
209+ auto & pair = transportContext->createPair (i);
210+ auto addrBytes = pair->address ().bytes ();
211+ allBytes.insert (allBytes.end (), addrBytes.begin (), addrBytes.end ());
212+ }
213+ std::ostringstream storeKey;
214+ storeKey << rank;
215+ store.set (storeKey.str (), allBytes);
216+
217+ std::vector<std::shared_ptr<std::thread>> connect_threads (thread_num_);
218+ // Connect every pair
219+ for (uint32_t i = 0 ; i < connect_threads.size (); ++i) {
220+ connect_threads[i].reset (new std::thread (
221+ [&store, &transportContext, this ](size_t thread_idx,
222+ size_t thread_num) -> void {
223+ for (int i = thread_idx; i < size; i += thread_num) {
224+ if (i == rank) {
225+ continue ;
226+ }
227+ // Wait for address of other side of this pair to become available
228+ std::string key = std::to_string (i);
229+ store.wait ({key}, getTimeout ());
230+ // Connect to other side of this pair
231+ auto allAddrs = store.get (key);
232+ auto addr = extractAddress (allAddrs, i);
233+ transportContext->getPair (i)->connect (addr);
234+ }
235+ },
236+ i, connect_threads.size ()));
237+ }
238+ for (uint32_t i = 0 ; i < connect_threads.size (); ++i) {
239+ connect_threads[i]->join ();
240+ }
241+ device_ = dev;
242+ transportContext_ = std::move (transportContext);
243+ }
244+ #endif
138245} // namespace rendezvous
139246} // namespace gloo
140247
141248namespace paddle {
142249namespace framework {
143250
144- void GlooWrapper::Init (int rank, int size, const std::string& path,
145- const std::string& fs_name, const std::string& fs_ugi,
146- const std::string& iface, const std::string& prefix) {
251+ void GlooWrapper::Init () {
147252 if (is_initialized_) {
148253 return ;
149254 }
150- rank_ = rank;
151- size_ = size;
152- std::string cmd = std::string (" ${HADOOP_HOME}/bin/hadoop fs" );
153- cmd += " -D fs.default.name=" + fs_name;
154- cmd += " -D hadoop.job.ugi=" + fs_ugi;
155- paddle::framework::hdfs_set_command (cmd);
156255#ifdef PADDLE_WITH_GLOO
157256 gloo::transport::tcp::attr attr;
158- attr.iface = iface;
159- auto file_store = gloo::rendezvous::HdfsStore (path);
160- auto prefix_store = gloo::rendezvous::PrefixStore (prefix, file_store);
257+ attr.iface = iface_;
258+ std::shared_ptr<gloo::rendezvous::HdfsStore> file_store = nullptr ;
259+ std::shared_ptr<gloo::rendezvous::HTTPStore> http_store = nullptr ;
260+ auto context =
261+ std::make_shared<gloo::rendezvous::ParallelConnectContext>(rank_, size_);
262+ context->setTimeout (run_timeout_);
161263 auto dev = gloo::transport::tcp::CreateDevice (attr);
162- auto context = std::make_shared<gloo::rendezvous::Context>(rank, size);
163- context->setTimeout (file_store.wait_timeout_ );
164- context->connectFullMesh (prefix_store, dev);
264+ switch (store_type_) {
265+ case GlooStoreType::HDFS: {
266+ std::string cmd = std::string (" ${HADOOP_HOME}/bin/hadoop fs" );
267+ cmd += " -D fs.default.name=" + hdfs_name_;
268+ cmd += " -D hadoop.job.ugi=" + hdfs_ugi_;
269+ paddle::framework::hdfs_set_command (cmd);
270+ file_store = std::make_shared<gloo::rendezvous::HdfsStore>(hdfs_path_);
271+ file_store->SetTimeoutSeconds (init_timeout_.count ());
272+ auto prefix_store =
273+ std::make_shared<gloo::rendezvous::PrefixStore>(prefix_, *file_store);
274+ context->connectFullMesh (*prefix_store, dev);
275+ break ;
276+ }
277+ case GlooStoreType::HTTP: {
278+ http_store = std::make_shared<gloo::rendezvous::HTTPStore>(
279+ http_ip_, http_port_, prefix_ + " _" + http_scope_, rank_);
280+ http_store->SetTimeoutSeconds (init_timeout_.count ());
281+ context->connectFullMesh (*http_store, dev);
282+ http_store->Finalize ();
283+ break ;
284+ }
285+ default :
286+ LOG (ERROR) << " unknown store type " << store_type_;
287+ exit (-1 );
288+ }
165289 context_ = std::move (context);
166290#endif
167291 is_initialized_ = true ;
@@ -170,6 +294,9 @@ void GlooWrapper::Init(int rank, int size, const std::string& path,
// Explicit instantiations of GlooWrapper::AllReduce for the int64_t, float
// and double element types.
170294template std::vector<int64_t > GlooWrapper::AllReduce<int64_t >(
171295 std::vector<int64_t >& sendbuf, // NOLINT
172296 const std::string& mode);
297+ template std::vector<float > GlooWrapper::AllReduce<float >(
298+ std::vector<float >& sendbuf, // NOLINT
299+ const std::string& mode);
173300template std::vector<double > GlooWrapper::AllReduce<double >(
174301 std::vector<double >& sendbuf, // NOLINT
175302 const std::string& mode);
@@ -180,6 +307,8 @@ template std::vector<int64_t> GlooWrapper::AllGather<int64_t>(
180307 int64_t & input); // NOLINT
// Explicit instantiations of GlooWrapper::AllGather for the uint64_t, float
// and double input types.
181308template std::vector<uint64_t > GlooWrapper::AllGather<uint64_t >(
182309 uint64_t & input); // NOLINT
310+ template std::vector<float > GlooWrapper::AllGather<float >(
311+ float & input); // NOLINT
183312template std::vector<double > GlooWrapper::AllGather<double >(
184313 double & input); // NOLINT
185314
0 commit comments