@@ -40,8 +40,7 @@ namespace framework {
4040std::shared_ptr<PSGPUWrapper> PSGPUWrapper::s_instance_ = NULL ;
4141bool PSGPUWrapper::is_initialized_ = false ;
4242
43- void PSGPUWrapper::BuildTask (std::shared_ptr<HeterContext> gpu_task,
44- uint64_t table_id, int feature_dim) {
43+ void PSGPUWrapper::BuildTask (std::shared_ptr<HeterContext> gpu_task) {
4544 VLOG (3 ) << " PSGPUWrapper::BuildGPUPSTask begin" ;
4645 platform::Timer timeline;
4746 timeline.Start ();
@@ -68,8 +67,6 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task,
6867 thread_keys_.resize (thread_keys_thread_num_);
6968 for (int i = 0 ; i < thread_keys_thread_num_; i++) {
7069 thread_keys_[i].resize (thread_keys_shard_num_);
71- for (int j = 0 ; j < thread_keys_shard_num_; j++) {
72- }
7370 }
7471 const std::deque<Record>& vec_data = input_channel->GetData ();
7572 size_t total_len = vec_data.size ();
@@ -139,17 +136,16 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task,
139136 local_ptr[i].resize (local_keys[i].size ());
140137 }
141138 timeline.Start ();
142- auto ptl_func = [this , &local_keys, &local_ptr, &table_id,
143- &fleet_ptr](int i) {
139+ auto ptl_func = [this , &local_keys, &local_ptr, &fleet_ptr](int i) {
144140 size_t key_size = local_keys[i].size ();
145141#ifdef PADDLE_WITH_PSLIB
146142 auto tt = fleet_ptr->pslib_ptr_ ->_worker_ptr ->pull_sparse_ptr (
147- reinterpret_cast <char **>(local_ptr[i].data ()), table_id ,
143+ reinterpret_cast <char **>(local_ptr[i].data ()), this -> table_id_ ,
148144 local_keys[i].data (), key_size);
149145#endif
150146#ifdef PADDLE_WITH_PSCORE
151147 auto tt = fleet_ptr->_worker_ptr ->pull_sparse_ptr (
152- reinterpret_cast <char **>(local_ptr[i].data ()), table_id ,
148+ reinterpret_cast <char **>(local_ptr[i].data ()), this -> table_id_ ,
153149 local_keys[i].data (), key_size);
154150#endif
155151 tt.wait ();
@@ -255,7 +251,7 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task,
255251 }
256252 }
257253#endif
258- VLOG (1 ) << " GpuPs build hbmps done" ;
254+ VLOG (3 ) << " GpuPs build hbmps done" ;
259255
260256 device_mutex[dev]->unlock ();
261257 }
@@ -272,11 +268,8 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task,
272268 << " seconds." ;
273269}
274270
275- void PSGPUWrapper::BuildGPUPS ( uint64_t table_id, int feature_dim ) {
271+ void PSGPUWrapper::BuildGPUTask (std::shared_ptr<HeterContext> gpu_task ) {
276272 int device_num = heter_devices_.size ();
277- std::shared_ptr<HeterContext> gpu_task = gpu_task_pool_.Get ();
278- gpu_task->Reset ();
279- BuildTask (gpu_task, table_id, feature_dim);
280273 platform::Timer timeline;
281274 timeline.Start ();
282275
@@ -291,15 +284,21 @@ void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim) {
291284 delete HeterPs_;
292285 HeterPs_ = nullptr ;
293286 }
287+ if (size_max <= 0 ) {
288+ VLOG (1 ) << " Skip build gpu ps cause feasign nums = " << size_max;
289+ return ;
290+ }
294291 std::vector<std::thread> threads (device_num);
295292 HeterPs_ = HeterPsBase::get_instance (size_max, resource_);
296293 HeterPs_->set_nccl_comm_and_size (inner_comms_, inter_comms_, node_size_);
297294 auto build_func = [this , &gpu_task, &feature_keys_count](int i) {
298- std::cout << " building table: " << i << std::endl ;
295+ VLOG ( 3 ) << " building table: " << i;
299296 this ->HeterPs_ ->build_ps (i, gpu_task->device_keys_ [i].data (),
300297 gpu_task->device_values_ [i].data (),
301298 feature_keys_count[i], 500000 , 2 );
302- HeterPs_->show_one_table (i);
299+ if (feature_keys_count[i] > 0 ) {
300+ HeterPs_->show_one_table (i);
301+ }
303302 };
304303 for (size_t i = 0 ; i < threads.size (); i++) {
305304 threads[i] = std::thread (build_func, i);
@@ -310,7 +309,109 @@ void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim) {
310309 timeline.Pause ();
311310 VLOG (1 ) << " GpuPs build table total costs: " << timeline.ElapsedSec ()
312311 << " s." ;
313- gpu_task_pool_.Push (gpu_task);
312+ }
313+
314+ void PSGPUWrapper::LoadIntoMemory (bool is_shuffle) {
315+ platform::Timer timer;
316+ VLOG (3 ) << " Begin LoadIntoMemory(), dataset[" << dataset_ << " ]" ;
317+ timer.Start ();
318+ dataset_->LoadIntoMemory ();
319+ timer.Pause ();
320+ VLOG (0 ) << " LoadIntoMemory cost: " << timer.ElapsedSec () << " s" ;
321+
322+ // local shuffle
323+ if (is_shuffle) {
324+ dataset_->LocalShuffle ();
325+ }
326+
327+ std::shared_ptr<HeterContext> gpu_task = gpu_task_pool_.Get ();
328+ gpu_task->Reset ();
329+ data_ready_channel_->Put (gpu_task);
330+ VLOG (3 ) << " End LoadIntoMemory(), dataset[" << dataset_ << " ]" ;
331+ }
332+
333+ void PSGPUWrapper::start_build_thread () {
334+ running_ = true ;
335+ VLOG (3 ) << " start build CPU&GPU ps thread." ;
336+ build_cpu_threads_ = std::thread ([this ] { build_cpu_thread (); });
337+ build_gpu_threads_ = std::thread ([this ] { build_gpu_thread (); });
338+ }
339+
340+ void PSGPUWrapper::build_cpu_thread () {
341+ while (running_) {
342+ std::shared_ptr<HeterContext> gpu_task = nullptr ;
343+ if (!data_ready_channel_->Get (gpu_task)) {
344+ continue ;
345+ }
346+ VLOG (3 ) << " thread BuildTask start." ;
347+ platform::Timer timer;
348+ timer.Start ();
349+ // build cpu ps data process
350+ BuildTask (gpu_task);
351+ timer.Pause ();
352+ VLOG (1 ) << " thread BuildTask end, cost time: " << timer.ElapsedSec () << " s" ;
353+ buildcpu_ready_channel_->Put (gpu_task);
354+ }
355+ VLOG (3 ) << " build cpu thread end" ;
356+ }
357+
358+ void PSGPUWrapper::build_gpu_thread () {
359+ while (running_) {
360+ std::shared_ptr<HeterContext> gpu_task = nullptr ;
361+ if (!gpu_free_channel_->Get (gpu_task)) {
362+ continue ;
363+ }
364+ if (!buildcpu_ready_channel_->Get (gpu_task)) {
365+ continue ;
366+ }
367+ VLOG (3 ) << " thread BuildGPUTask start." ;
368+ platform::Timer timer;
369+ timer.Start ();
370+ BuildGPUTask (gpu_task);
371+ timer.Pause ();
372+ VLOG (1 ) << " thread BuildGPUTask end, cost time: " << timer.ElapsedSec ()
373+ << " s" ;
374+
375+ gpu_task_pool_.Push (gpu_task);
376+ train_ready_channel_->Put (gpu_task);
377+ }
378+ VLOG (3 ) << " build gpu thread end" ;
379+ }
380+
381+ void PSGPUWrapper::BeginPass () {
382+ platform::Timer timer;
383+ timer.Start ();
384+ if (current_task_) {
385+ PADDLE_THROW (
386+ platform::errors::Fatal (" [BeginPass] current task is not ended." ));
387+ }
388+ // load+build done
389+ if (!train_ready_channel_->Get (current_task_)) {
390+ PADDLE_THROW (platform::errors::Fatal (" train_ready_channel_ failed." ));
391+ }
392+ timer.Pause ();
393+ VLOG (1 ) << " BeginPass end, cost time: " << timer.ElapsedSec () << " s" ;
394+ }
395+
396+ void PSGPUWrapper::EndPass () {
397+ if (!current_task_) {
398+ PADDLE_THROW (
399+ platform::errors::Fatal (" [EndPass] current task has been ended." ));
400+ }
401+ platform::Timer timer;
402+ timer.Start ();
403+ size_t keysize_max = 0 ;
404+ // in case of feasign_num = 0, skip dump_to_cpu
405+ for (size_t i = 0 ; i < heter_devices_.size (); i++) {
406+ keysize_max = std::max (keysize_max, current_task_->device_keys_ [i].size ());
407+ }
408+ if (keysize_max != 0 ) {
409+ HeterPs_->end_pass ();
410+ }
411+ current_task_ = nullptr ;
412+ gpu_free_channel_->Put (current_task_);
413+ timer.Pause ();
414+ VLOG (1 ) << " EndPass end, cost time: " << timer.ElapsedSec () << " s" ;
314415}
315416
316417void PSGPUWrapper::PullSparse (const paddle::platform::Place& place,
0 commit comments