@@ -490,9 +490,8 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) {
490490
491491void PSGPUWrapper::start_build_thread () {
492492 running_ = true ;
493- VLOG (3 ) << " start build CPU&GPU ps thread." ;
493+ VLOG (3 ) << " start build CPU ps thread." ;
494494 pre_build_threads_ = std::thread ([this ] { pre_build_thread (); });
495- build_threads_ = std::thread ([this ] { build_thread (); });
496495}
497496
498497void PSGPUWrapper::pre_build_thread () {
@@ -515,30 +514,28 @@ void PSGPUWrapper::pre_build_thread() {
515514 VLOG (3 ) << " build cpu thread end" ;
516515}
517516
518- void PSGPUWrapper::build_thread () {
519- // build: build_pull + build_gputask
520- while (running_) {
521- std::shared_ptr<HeterContext> gpu_task = nullptr ;
522- if (!gpu_free_channel_->Get (gpu_task)) {
523- continue ;
524- }
525- if (!buildcpu_ready_channel_->Get (gpu_task)) {
526- continue ;
527- }
528- VLOG (3 ) << " thread BuildGPUTask start." ;
529- platform::Timer timer;
530- timer.Start ();
531- BuildPull (gpu_task);
532- timer.Pause ();
533- timer.Start ();
534- BuildGPUTask (gpu_task);
535- timer.Pause ();
536- VLOG (1 ) << " thread BuildGPUTask end, cost time: " << timer.ElapsedSec ()
537- << " s" ;
538-
539- train_ready_channel_->Put (gpu_task);
517+ void PSGPUWrapper::build_task () {
518+ // build_task: build_pull + build_gputask
519+ std::shared_ptr<HeterContext> gpu_task = nullptr ;
520+ // train end, gpu free
521+ if (!gpu_free_channel_->Get (gpu_task)) {
522+ return ;
523+ }
524+ // ins and pre_build end
525+ if (!buildcpu_ready_channel_->Get (gpu_task)) {
526+ return ;
540527 }
541- VLOG (3 ) << " build gpu thread end" ;
528+
529+ VLOG (1 ) << " BuildPull start." ;
530+ platform::Timer timer;
531+ timer.Start ();
532+ BuildPull (gpu_task);
533+ BuildGPUTask (gpu_task);
534+ timer.Pause ();
535+ VLOG (1 ) << " BuildPull + BuildGPUTask end, cost time: " << timer.ElapsedSec ()
536+ << " s" ;
537+
538+ current_task_ = gpu_task;
542539}
543540
544541void PSGPUWrapper::BeginPass () {
@@ -548,11 +545,15 @@ void PSGPUWrapper::BeginPass() {
548545 PADDLE_THROW (
549546 platform::errors::Fatal (" [BeginPass] current task is not ended." ));
550547 }
551- // load+build done
552- if (!train_ready_channel_->Get (current_task_)) {
553- PADDLE_THROW (platform::errors::Fatal (" train_ready_channel_ failed." ));
554- }
548+
549+ build_task ();
555550 timer.Pause ();
551+
552+ if (current_task_ == nullptr ) {
553+ PADDLE_THROW (platform::errors::Fatal (
554+ " [BeginPass] after build_task, current task is not null." ));
555+ }
556+
556557 VLOG (1 ) << " BeginPass end, cost time: " << timer.ElapsedSec () << " s" ;
557558}
558559
0 commit comments