Skip to content

Commit 62ec644

Browse files
authored
[psgpu]fix pipe bug:save and pull overlap; test=develop (#37233)
1 parent f29a3c6 commit 62ec644

File tree

2 files changed

+31
-37
lines changed

2 files changed

+31
-37
lines changed

paddle/fluid/framework/fleet/ps_gpu_wrapper.cc

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -490,9 +490,8 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) {
490490

491491
void PSGPUWrapper::start_build_thread() {
492492
running_ = true;
493-
VLOG(3) << "start build CPU&GPU ps thread.";
493+
VLOG(3) << "start build CPU ps thread.";
494494
pre_build_threads_ = std::thread([this] { pre_build_thread(); });
495-
build_threads_ = std::thread([this] { build_thread(); });
496495
}
497496

498497
void PSGPUWrapper::pre_build_thread() {
@@ -515,30 +514,28 @@ void PSGPUWrapper::pre_build_thread() {
515514
VLOG(3) << "build cpu thread end";
516515
}
517516

518-
void PSGPUWrapper::build_thread() {
519-
// build: build_pull + build_gputask
520-
while (running_) {
521-
std::shared_ptr<HeterContext> gpu_task = nullptr;
522-
if (!gpu_free_channel_->Get(gpu_task)) {
523-
continue;
524-
}
525-
if (!buildcpu_ready_channel_->Get(gpu_task)) {
526-
continue;
527-
}
528-
VLOG(3) << "thread BuildGPUTask start.";
529-
platform::Timer timer;
530-
timer.Start();
531-
BuildPull(gpu_task);
532-
timer.Pause();
533-
timer.Start();
534-
BuildGPUTask(gpu_task);
535-
timer.Pause();
536-
VLOG(1) << "thread BuildGPUTask end, cost time: " << timer.ElapsedSec()
537-
<< "s";
538-
539-
train_ready_channel_->Put(gpu_task);
517+
void PSGPUWrapper::build_task() {
518+
// build_task: build_pull + build_gputask
519+
std::shared_ptr<HeterContext> gpu_task = nullptr;
520+
// train end, gpu free
521+
if (!gpu_free_channel_->Get(gpu_task)) {
522+
return;
523+
}
524+
// ins and pre_build end
525+
if (!buildcpu_ready_channel_->Get(gpu_task)) {
526+
return;
540527
}
541-
VLOG(3) << "build gpu thread end";
528+
529+
VLOG(1) << "BuildPull start.";
530+
platform::Timer timer;
531+
timer.Start();
532+
BuildPull(gpu_task);
533+
BuildGPUTask(gpu_task);
534+
timer.Pause();
535+
VLOG(1) << "BuildPull + BuildGPUTask end, cost time: " << timer.ElapsedSec()
536+
<< "s";
537+
538+
current_task_ = gpu_task;
542539
}
543540

544541
void PSGPUWrapper::BeginPass() {
@@ -548,11 +545,15 @@ void PSGPUWrapper::BeginPass() {
548545
PADDLE_THROW(
549546
platform::errors::Fatal("[BeginPass] current task is not ended."));
550547
}
551-
// load+build done
552-
if (!train_ready_channel_->Get(current_task_)) {
553-
PADDLE_THROW(platform::errors::Fatal("train_ready_channel_ failed."));
554-
}
548+
549+
build_task();
555550
timer.Pause();
551+
552+
if (current_task_ == nullptr) {
553+
PADDLE_THROW(platform::errors::Fatal(
554+
"[BeginPass] after build_task, current task is not null."));
555+
}
556+
556557
VLOG(1) << "BeginPass end, cost time: " << timer.ElapsedSec() << "s";
557558
}
558559

paddle/fluid/framework/fleet/ps_gpu_wrapper.h

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class PSGPUWrapper {
9191
void EndPass();
9292
void start_build_thread();
9393
void pre_build_thread();
94-
void build_thread();
94+
void build_task();
9595

9696
void Finalize() {
9797
VLOG(3) << "PSGPUWrapper Begin Finalize.";
@@ -101,7 +101,6 @@ class PSGPUWrapper {
101101
data_ready_channel_->Close();
102102
buildcpu_ready_channel_->Close();
103103
gpu_free_channel_->Close();
104-
train_ready_channel_->Close();
105104
running_ = false;
106105
VLOG(3) << "begin stop pre_build_threads_";
107106
pre_build_threads_.join();
@@ -169,8 +168,6 @@ class PSGPUWrapper {
169168
buildcpu_ready_channel_->SetCapacity(3);
170169
gpu_free_channel_->Open();
171170
gpu_free_channel_->SetCapacity(1);
172-
train_ready_channel_->Open();
173-
train_ready_channel_->SetCapacity(1);
174171

175172
current_task_ = nullptr;
176173
gpu_free_channel_->Put(current_task_);
@@ -306,10 +303,6 @@ class PSGPUWrapper {
306303
paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
307304
gpu_free_channel_ =
308305
paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
309-
std::shared_ptr<
310-
paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
311-
train_ready_channel_ =
312-
paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
313306
std::shared_ptr<HeterContext> current_task_ = nullptr;
314307
std::thread pre_build_threads_;
315308
std::thread build_threads_;

0 commit comments

Comments
 (0)