@@ -57,8 +57,6 @@ void PSGPUTrainer::Initialize(const TrainerDesc& trainer_desc,
5757 trainer_desc.downpour_param ().stat_var_names (i));
5858 }
5959 VLOG (3 ) << " going to initialize pull dense worker" ;
60- pull_dense_worker_ = PullDenseWorker::GetInstance ();
61- pull_dense_worker_->Initialize (trainer_desc);
6260 SetDebug (trainer_desc.debug ());
6361 trainer_desc_ = trainer_desc;
6462 workers_.resize (place_num);
@@ -112,15 +110,21 @@ void PSGPUTrainer::InitTrainerEnv(const ProgramDesc& main_program,
112110 }
113111 }
114112 }
113+ for (auto & var : main_program.Block (0 ).AllVars ()) {
114+ if (var->Persistable ()) {
115+ auto it = std::find (need_merge_var_names_.begin (),
116+ need_merge_var_names_.end (), var->Name ());
117+ if (it == need_merge_var_names_.end ()) {
118+ VLOG (2 ) << " train param: " << var->Name ();
119+ trainable_param_.push_back (var->Name ());
120+ }
121+ }
122+ }
115123 place_ = place;
116124 return ;
117125}
118126
// InitOtherEnv: after this change the function only emits a VLOG(3) message.
// The removed ('-') lines used to hand root_scope_ and one thread scope per
// entry of places_ to pull_dense_worker_; that wiring is dropped here.
// `main_program` is not referenced by any line of the visible body.
119127void PSGPUTrainer::InitOtherEnv (const ProgramDesc& main_program) {
120- pull_dense_worker_->SetRootScope (root_scope_);
121- for (size_t i = 0 ; i < places_.size (); ++i) {
122- pull_dense_worker_->AddThreadScope (workers_[i]->GetThreadScope ());
123- }
124128 VLOG (3 ) << " init other env done." ;
125129}
126130
@@ -141,15 +145,27 @@ Scope* PSGPUTrainer::GetWorkerScope(int thread_id) { return nullptr; }
// MergeToRootScope<T>: accumulates `tensor` into `root_tensor` element by
// element (root[i] += tensor[i]), using two temporaries copied to CPUPlace.
//   root_tensor - destination; read, summed into, then overwritten.
//   tensor      - source whose elements are added; only read here.
// This change replaces every TensorCopy call with TensorCopySync.
// NOTE(review): presumably the Sync variant blocks until the copy finishes so
// the host-side loop below reads complete data - confirm against tensor_util.
141145template <typename T>
142146void PSGPUTrainer::MergeToRootScope (LoDTensor* root_tensor, LoDTensor* tensor) {
143147 LoDTensor tmp_root;
// Stage the current root values in a CPU-place temporary.
144- TensorCopy (*root_tensor, platform::CPUPlace (), &tmp_root);
148+ TensorCopySync (*root_tensor, platform::CPUPlace (), &tmp_root);
145149 T* tmp_root_data = tmp_root.data <T>();
146150 LoDTensor tmp_tensor;
// Stage the values to be added in a second CPU-place temporary.
147- TensorCopy (*tensor, platform::CPUPlace (), &tmp_tensor);
151+ TensorCopySync (*tensor, platform::CPUPlace (), &tmp_tensor);
148152 T* data = tmp_tensor.data <T>();
// The loop bound comes from tmp_tensor only; tmp_root is indexed with the same i.
// NOTE(review): assumes root_tensor holds at least tmp_tensor.numel() elements,
// and that the count fits in `int` - confirm numel()'s return type.
149153 for (int i = 0 ; i < tmp_tensor.numel (); i++) {
150154 tmp_root_data[i] += data[i];
151155 }
// Write the summed values back; the destination place passed is CPUPlace,
// not root_tensor's previous place.
152- TensorCopy (tmp_root, platform::CPUPlace (), root_tensor);
156+ TensorCopySync (tmp_root, platform::CPUPlace (), root_tensor);
157+ }
158+
// MergeDenseParam (added by this change): for every name in trainable_param_,
// copies the tensor found in worker 0's thread scope over the tensor of the
// same name in root_scope_, keeping the root tensor's current place.
// Only workers_[0] is consulted; no other worker's scope is read here.
// NOTE(review): workers_[0] is accessed without checking that workers_ is
// non-empty - confirm callers guarantee at least one worker.
159+ void PSGPUTrainer::MergeDenseParam () {
160+ auto thread_scope = workers_[0 ]->GetThreadScope ();
161+ for (auto & name : trainable_param_) {
162+ VLOG (2 ) << " merge var " << name << " to root scope" ;
// NOTE(review): both FindVar results are dereferenced without a null check;
// confirm every trainable_param_ name exists in both scopes.
163+ Variable* root_var = root_scope_->FindVar (name);
164+ LoDTensor* root_tensor = root_var->GetMutable <LoDTensor>();
165+ Variable* var = thread_scope->FindVar (name);
166+ LoDTensor* tensor = var->GetMutable <LoDTensor>();
// Destination place is taken from root_tensor itself before it is overwritten.
167+ TensorCopySync ((*tensor), root_tensor->place (), root_tensor);
168+ }
153169}
154170
155171void PSGPUTrainer::Finalize () {
@@ -187,7 +203,7 @@ void PSGPUTrainer::Finalize() {
187203 _ForEachDataType_ (MergeCallback);
188204 }
189205 }
190- pull_dense_worker_-> MergeDenseParam ();
206+ MergeDenseParam ();
191207 root_scope_->DropKids ();
192208}
193209} // namespace framework
0 commit comments