From 4734f0af58a928745f0d46f8ee7a8d4cbe2503dd Mon Sep 17 00:00:00 2001 From: Thunderbrook Date: Fri, 10 Apr 2020 13:53:19 +0800 Subject: [PATCH 1/8] heter cpu trainer --- paddle/fluid/framework/CMakeLists.txt | 4 +- paddle/fluid/framework/device_worker.h | 306 +++++ .../fluid/framework/device_worker_factory.cc | 1 + paddle/fluid/framework/dist_multi_trainer.cc | 13 +- paddle/fluid/framework/fleet/fleet_wrapper.cc | 216 +++ paddle/fluid/framework/fleet/fleet_wrapper.h | 18 + paddle/fluid/framework/hetercpu_worker.cc | 1212 +++++++++++++++++ paddle/fluid/framework/multi_trainer.cc | 1 + paddle/fluid/framework/trainer.h | 1 + python/paddle/fluid/device_worker.py | 2 +- 10 files changed, 1770 insertions(+), 4 deletions(-) create mode 100644 paddle/fluid/framework/hetercpu_worker.cc diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index d6cce2540188aa..bb12944c2e15f8 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -185,7 +185,7 @@ cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc o if(WITH_DISTRIBUTE) cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc - data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc downpour_worker_opt.cc + data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc downpour_worker.cc downpour_worker_opt.cc pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper box_wrapper lodtensor_printer lod_rank_table feed_fetch_method sendrecvop_rpc communicator collective_helper ${GLOB_DISTRIBUTE_DEPS} @@ -195,7 +195,7 @@ set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_CO else() cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc 
dataset_factory.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc - data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc downpour_worker_opt.cc + data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc downpour_worker.cc downpour_worker_opt.cc pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context scope framework_proto data_feed_proto trainer_desc_proto glog lod_rank_table fs shell fleet_wrapper box_wrapper lodtensor_printer feed_fetch_method diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index f75d7593fe9a50..62a532abb5a074 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -131,6 +131,9 @@ class DeviceWorker { virtual void SetDataFeed(DataFeed* data_feed); virtual void SetNeedDump(bool need_dump_field) {} virtual void SetChannelWriter(ChannelObject* queue) {} + virtual void SetWorkerNum(int num) {}; + virtual void CacheProgram(const ProgramDesc &main_program) {}; + virtual void Schedule(int taskid) {}; virtual void SetPlace(const paddle::platform::Place& place) { place_ = place; } @@ -294,6 +297,309 @@ class DownpourWorkerOpt : public DownpourWorker { uint64_t async_tid_ = 0; }; +enum HeterTaskState { + PULL_SPARSE, + OP_RUN, + XPU, + PUSH_GRAD, + DONE +}; + +class HeterTask { +public: + void Update() { + if (state_ == PULL_SPARSE) { + state_ = OP_RUN; + } + else if (state_ == OP_RUN) { + //state_ = XPU; + //state_ = PUSH_GRAD; + state_ = PUSH_GRAD; + } + else if (state_ == XPU) { + state_ = PUSH_GRAD; + } + else if (state_ == PUSH_GRAD) { + state_ = DONE; + } + } + void Show() { + std::cout << "features size " << features_.size() << std::endl; + for (size_t i = 0; i < features_.size(); ++i) { + std::cout << "features[" << i << "] size " << features_[i].size() << std::endl; + } + } + void PackTask(Scope* scope, int taskid, DataFeed* reader, int cur_batch, const 
ProgramDesc& program); + + Scope* scope_{nullptr}; + int taskid_; + int cur_batch_; + HeterTaskState state_; + // cache + std::map> features_; + std::map> feature_labels_; + std::map>> feature_values_; + std::map>> feature_grads_; + std::map> sparse_push_keys_; + double total_time; + double read_time; + double pack_time; + double pull_sparse_local_time; +}; + +template +class HeterObjectPool { +public: + std::shared_ptr Get() { + std::lock_guard lock(mutex_); + if (pool_.empty()) { + return std::make_shared(); + } + else { + auto ret = pool_.back(); + pool_.pop_back(); + return ret; + } + } + void Push(std::shared_ptr data) { + std::lock_guard lock(mutex_); + pool_.push_back(std::move(data)); + } + int Size() { + std::lock_guard lock(mutex_); + return pool_.size(); + } +private: + std::vector> pool_; + std::mutex mutex_; +}; + + +template +struct HeterNode { + K key; + T value; + HeterNode *prev; + HeterNode *next; +}; + +template +class HeterList { +public: + HeterList() + : head_(new HeterNode) + , tail_(new HeterNode) { + head_->prev = NULL; + head_->next = tail_; + tail_->prev = head_; + tail_->next = NULL; + size = 0; + cap_ = 1e9; + } + + ~HeterList() { + delete head_; + delete tail_; + } + + void SetCap(int num) { + cap_ = num; + } + + bool TryPut(K& key, T& value) { + std::unique_lock lock(mutex_); + cond_.wait(lock, [this] { return size < cap_; }); + if (task_map_.find(key) != task_map_.end()) { + //std::cout << "try put key=" << key << " false" << std::endl; + return false; + } + else { + HeterNode* node = new HeterNode; + node->key = key; + node->value = value; + map_[node->key] = node; + attach(node); + //std::cout << "try put key=" << key << " true" << std::endl; + return true; + } + } + + bool Put(K& key, T& value) { + std::unique_lock lock(mutex_); + cond_.wait(lock, [this] { return size < cap_; }); + HeterNode* node = new HeterNode; + //std::cout << "put key=" << key << " true" << std::endl; + node->key = key; + node->value = value; + 
map_[node->key] = node; + attach(node); + return true; + } + + T TryGet(const K &key) { + std::lock_guard lock(mutex_); + auto iter = map_.find(key); + if (iter != map_.end()) { + //std::cout << "try get key=" << key << " true" << std::endl; + HeterNode* node = iter->second; + detach(node); + cond_.notify_one(); + T ret = std::move(node->value); + map_.erase(key); + delete node; + return ret; + } + task_map_.insert(key); + //std::cout << "try get key=" << key << " false" << std::endl; + return nullptr; + } + + T Get(const K &key) { + std::lock_guard lock(mutex_); + auto iter = map_.find(key); + if (iter != map_.end()) { + //std::cout << "get key=" << key << " true" << std::endl; + HeterNode* node = iter->second; + detach(node); + cond_.notify_one(); + T ret = std::move(node->value); + map_.erase(key); + delete node; + return ret; + } + //std::cout << "get key=" << key << " false" << std::endl; + return nullptr; + } + + T Get() { + std::lock_guard lock(mutex_); + HeterNode* node = head_->next; + if (node == tail_) { + //std::cout << "get2 false" << std::endl; + return nullptr; + } + else { + detach(node); + cond_.notify_one(); + T ret = std::move(node->value); + map_.erase(node->key); + //std::cout << "get2 key=" << node->key << " true" << std::endl; + delete node; + return ret; + } + } + + bool Empty() { + std::lock_guard lock(mutex_); + return head_->next == tail_; + } + + int Size() { + std::lock_guard lock(mutex_); + return size; + } + +private: + void detach(HeterNode *node) { + node->prev->next = node->next; + node->next->prev = node->prev; + size--; + } + + void attach(HeterNode *node) { + node->prev = head_; + node->next = head_->next; + head_->next->prev = node; + head_->next = node; + size++; + } + +private: + HeterNode *head_; + HeterNode *tail_; + std::unordered_map*> map_; + std::unordered_set task_map_; + std::mutex mutex_; + std::condition_variable cond_; + int cap_; + int size; +}; + +class HeterCpuWorker : public HogwildWorker { + public: + 
HeterCpuWorker() {} + virtual ~HeterCpuWorker() {} + virtual void Initialize(const TrainerDesc& desc); + virtual void TrainFiles(); + virtual void TrainFilesWithProfiler(); + virtual void SetNeedDump(bool need_dump_field); + virtual void SetChannelWriter(ChannelObject* queue); + virtual void SetWorkerNum(int num) { worker_num_ = num; } + virtual void CreateThreadParam(const ProgramDesc &main_program); + virtual void Schedule(int taskid); + virtual void JumpContext(std::shared_ptr task); + virtual void CacheProgram(const ProgramDesc &main_program) { + new(&program_) ProgramDesc(main_program); + } + + protected: + std::shared_ptr fleet_ptr_; + std::shared_ptr pull_dense_worker_; + void FillSparseValue(std::shared_ptr task, size_t table_id); + void PushGradients(); + void CollectLabelInfo(std::shared_ptr task, size_t table_id); + void AdjustInsWeight(std::shared_ptr task); + void DumpParam(); + void CopySparseTable(); + void CopyDenseTable(); + void CopyDenseVars(); + + private: + int worker_num_; + ProgramDesc program_; + HeterObjectPool object_pool_; + HeterList> run_queue_; + HeterList> wait_queue_; + bool need_dump_param_; + std::vector dump_param_; + bool need_to_push_dense_; + bool need_dump_field_; + bool dump_slot_; + bool need_to_push_sparse_; + std::vector dump_fields_; + ChannelWriter writer_; + DownpourWorkerParameter param_; + float scale_datanorm_; + // just save the value in param_ for easy access + std::map label_var_name_; + std::map> sparse_key_names_; + std::map> sparse_value_names_; + std::map> sparse_grad_names_; + std::map> dense_value_names_; + std::map> dense_grad_names_; + platform::Place root_place_; + // actually pushed feasign of each table + std::map> sparse_push_keys_; + + // skipped ops + std::vector skip_ops_; + + std::vector<::std::future> push_sparse_status_; + std::vector<::std::future> push_dense_status_; + + // adjust ins weight + AdjustInsWeightConfig adjust_ins_weight_config_; + std::vector nid_show_; + // check nan and inf 
during training + std::vector check_nan_var_names_; + // copy table + CopyTableConfig copy_table_config_; + std::map table_dependency_; + std::vector> copy_sparse_tables_; + std::vector> copy_dense_tables_; + std::unordered_map> feasign_set_; +}; + #if defined(PADDLE_WITH_NCCL) using ScopeQueue = operators::reader::BlockingQueue; diff --git a/paddle/fluid/framework/device_worker_factory.cc b/paddle/fluid/framework/device_worker_factory.cc index 80e4000c9dc686..cbf306d66216bf 100644 --- a/paddle/fluid/framework/device_worker_factory.cc +++ b/paddle/fluid/framework/device_worker_factory.cc @@ -62,6 +62,7 @@ std::shared_ptr DeviceWorkerFactory::CreateDeviceWorker( REGISTER_DEVICE_WORKER_CLASS(HogwildWorker); REGISTER_DEVICE_WORKER_CLASS(DownpourWorker); REGISTER_DEVICE_WORKER_CLASS(DownpourWorkerOpt); +REGISTER_DEVICE_WORKER_CLASS(HeterCpuWorker); #if defined(PADDLE_WITH_NCCL) REGISTER_DEVICE_WORKER_CLASS(SectionWorker); #endif diff --git a/paddle/fluid/framework/dist_multi_trainer.cc b/paddle/fluid/framework/dist_multi_trainer.cc index 9fe28bddd1f04a..bced867561546c 100644 --- a/paddle/fluid/framework/dist_multi_trainer.cc +++ b/paddle/fluid/framework/dist_multi_trainer.cc @@ -46,7 +46,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc, dump_file_num_ = trainer_desc.dump_file_num(); const std::vector readers = dataset->GetReaders(); - + RegisterHeterCallback(); thread_num_ = readers.size(); workers_.resize(thread_num_); for (int i = 0; i < trainer_desc.downpour_param().stat_var_names_size(); @@ -62,6 +62,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc, workers_[i]->SetDataFeed(readers[i]); workers_[i]->Initialize(trainer_desc); workers_[i]->SetNeedDump(need_dump_field_); + workers_[i]->SetWorkerNum(thread_num_); } VLOG(3) << "going to initialize pull dense worker"; @@ -71,6 +72,15 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc, SetDebug(trainer_desc.debug()); } +void 
DistMultiTrainer::RegisterHeterCallback() { + auto fleet_ptr = FleetWrapper::GetInstance(); + fleet_ptr->RegisterHeterCallback( + [this](int worker, int taskid) { + workers_[worker]->Schedule(taskid); + } + ); +} + void DistMultiTrainer::DumpWork(int tid) { #ifdef _LINUX int err_no = 0; @@ -132,6 +142,7 @@ void DistMultiTrainer::InitTrainerEnv(const ProgramDesc &main_program, workers_[i]->SetRootScope(root_scope_); workers_[i]->CreateDeviceResource(main_program); // Program workers_[i]->BindingDataFeedMemory(); + workers_[i]->CacheProgram(main_program); } // Scope* -> thread id, it will be used in push_dense op for (int i = 0; i < thread_num_; ++i) { diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index 207ce748c1b468..c70588b18e0f9c 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -154,6 +154,222 @@ void FleetWrapper::CreateClient2ClientConnection() { #endif } +void FleetWrapper::HeterPullSparseVars( + int workerid, + std::shared_ptr task, const uint64_t table_id, + const std::vector& var_names, + int fea_value_dim, + const std::vector& var_emb_names) { +#ifdef PADDLE_WITH_PSLIB + std::vector<::std::future> pull_sparse_status; + pull_sparse_status.resize(0); + auto& scope = *(task->scope_); + auto& fea_keys = (task->features_)[table_id]; + auto& fea_values = (task->feature_values_)[table_id]; + fea_keys.clear(); + for (size_t var_index = 0; var_index < var_names.size(); ++var_index) { + const std::string& name = var_names[var_index]; + Variable* var = scope.FindVar(name); + if (var == nullptr) { + continue; + } + LoDTensor* tensor = var->GetMutable(); + CHECK(tensor != nullptr) << "tensor of var " << name << " is null"; + int64_t* ids = tensor->data(); + size_t len = tensor->numel(); + + // skip slots which do not have embedding + const std::string& emb_name = var_emb_names[var_index]; + Variable* emb_var = scope.FindVar(emb_name); + if 
(emb_var == nullptr) { + continue; + } + + for (auto i = 0u; i < len; ++i) { + if (ids[i] == 0u) { + continue; + } + fea_keys.push_back(static_cast(ids[i])); + } + } + fea_values.resize(fea_keys.size() + 1); + for (auto& t : fea_values) { + t.resize(fea_value_dim); + } + std::vector pull_result_ptr; + for (auto& t : fea_values) { + pull_result_ptr.push_back(t.data()); + } + auto status = pslib_ptr_->_worker_ptr->heter_pull_sparse(workerid, + pull_result_ptr.data(), table_id, fea_keys.data(), fea_keys.size(), task->taskid_); + pull_sparse_status.push_back(std::move(status)); + //for (auto& t : pull_sparse_status) { + // t.wait(); + // auto status = t.get(); + // if (status != 0) { + // LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]"; + // sleep(sleep_seconds_before_fail_exit_); + // exit(-1); + // } + //} +#endif +} + +void FleetWrapper::HeterPushSparseVars( + std::shared_ptr task, const uint64_t table_id, + const std::vector& sparse_key_names, + const std::vector& sparse_grad_names, const int emb_dim, + std::vector<::std::future>* push_sparse_status, + const bool use_cvm, const bool dump_slot, + const bool no_cvm) { + + auto& scope = *(task->scope_); + int batch_size = task->cur_batch_; + int offset = 2; + int slot_offset = 0; + int grad_dim = emb_dim; + int show_index = 0; + int click_index = 1; + auto& fea_keys = (task->features_)[table_id]; + auto& fea_labels = (task->feature_labels_)[table_id]; + auto& push_values = (task->feature_grads_)[table_id]; + auto& sparse_push_keys = (task->sparse_push_keys_)[table_id]; + + if (use_cvm) { + offset = 0; + grad_dim = emb_dim - 2; + } + if (no_cvm) { + offset = 0; + grad_dim = emb_dim; + } + if (dump_slot) { + slot_offset = 1; + show_index = 1; + click_index = 2; + } + CHECK_GE(grad_dim, 0); + + sparse_push_keys.clear(); + sparse_push_keys.reserve(fea_keys.size() + 1); + push_values.resize(fea_keys.size() + 1); + for (auto& t : push_values) { + t.resize(emb_dim + offset + slot_offset); + } + uint64_t 
fea_idx = 0u; + for (size_t i = 0; + i < sparse_key_names.size() && i < sparse_grad_names.size(); ++i) { + Variable* var = scope.FindVar(sparse_key_names[i]); + if (var == nullptr) { + continue; + } + LoDTensor* tensor = var->GetMutable(); + if (tensor == nullptr) { + LOG(ERROR) << "tensor of var[" << sparse_key_names[i] << "] is null"; + exit(-1); + } + size_t len = tensor->numel(); + int64_t* ids = tensor->data(); + int slot = 0; + if (dump_slot) { + slot = boost::lexical_cast(sparse_key_names[i]); + } + Variable* g_var = scope.FindVar(sparse_grad_names[i]); + if (g_var == nullptr) { + continue; + } + LoDTensor* g_tensor = g_var->GetMutable(); + if (g_tensor == nullptr) { + LOG(ERROR) << "tensor of var[" << sparse_key_names[i] << "] is null"; + exit(-1); + } + float* g = g_tensor->data(); + + if (scale_sparse_gradient_with_batch_size_ && grad_dim > 0) { + int dim = emb_dim + offset; + Eigen::Map< + Eigen::Matrix> + g_mat(g, g_tensor->numel() / dim, dim); + g_mat.rightCols(grad_dim) *= batch_size; + } + for (auto id_idx = 0u; id_idx < len; ++id_idx) { + if (ids[id_idx] == 0) { + g += emb_dim; + continue; + } + sparse_push_keys.push_back(ids[id_idx]); + CHECK(fea_idx < push_values.size()); + + if (use_cvm || no_cvm) { + memcpy(push_values[fea_idx].data() + offset + slot_offset, g, + sizeof(float) * emb_dim); + } else { + CHECK(fea_idx < fea_labels.size()); + memcpy(push_values[fea_idx].data() + offset + slot_offset, g, + sizeof(float) * emb_dim); + push_values[fea_idx][show_index] = 1.0f; + push_values[fea_idx][click_index] = + static_cast(fea_labels[fea_idx]); + } + if (dump_slot) { + push_values[fea_idx][0] = static_cast(slot); + } + g += emb_dim; + fea_idx++; + } + } + // slots whose embedding has been stop gradient or + // not involved in forward-backward + uint64_t no_grad_fea_num = 0u; + for (size_t i = sparse_grad_names.size(); i < sparse_key_names.size(); ++i) { + Variable* var = scope.FindVar(sparse_key_names[i]); + if (var == nullptr) { + continue; + } + 
LoDTensor* tensor = var->GetMutable(); + if (tensor == nullptr) { + LOG(ERROR) << "tensor of var[" << sparse_key_names[i] << "] is null"; + exit(-1); + } + size_t len = tensor->numel(); + int64_t* ids = tensor->data(); + for (auto id_idx = 0u; id_idx < len; ++id_idx) { + if (ids[id_idx] == 0) { + continue; + } + ++no_grad_fea_num; + } + } + CHECK(fea_idx + no_grad_fea_num == fea_keys.size()) + << "fea_idx: " << fea_idx << " no_grad_fea_num: " << no_grad_fea_num + << " features size: " << fea_keys.size(); + CHECK(fea_idx == sparse_push_keys.size()); + if (fea_idx == 0) { + return; + } + std::vector push_g_vec; + for (auto i = 0u; i < sparse_push_keys.size(); ++i) { + push_g_vec.push_back(push_values[i].data()); + } + auto status = pslib_ptr_->_worker_ptr->push_sparse( + table_id, sparse_push_keys.data(), (const float**)push_g_vec.data(), + sparse_push_keys.size()); + push_sparse_status->push_back(std::move(status)); +} + +int FleetWrapper::RegisterHeterCallback(HeterCallBackFunc handler) { +#ifdef PADDLE_WITH_PSLIB + VLOG(3) << "calling FleetWrapper::RegisterHeterCallback"; + VLOG(3) << "pslib_ptr_=" << pslib_ptr_; + VLOG(3) << "_worker_ptr=" << pslib_ptr_->_worker_ptr; + return pslib_ptr_->_worker_ptr->registe_heter_callback(handler); +#else + VLOG(0) << "FleetWrapper::RegisterHeterCallback" + << " does nothing when no pslib"; +#endif + return 0; +} + void FleetWrapper::PullSparseToLocal(const uint64_t table_id, int fea_value_dim) { #ifdef PADDLE_WITH_PSLIB diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index afc97e01eaebd8..2d8ef11b479317 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -30,6 +30,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/device_worker.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN @@ -80,6 +81,23 @@ class FleetWrapper { pull_local_thread_num_ = thread_num; } + void HeterPullSparseVars(int workerid, std::shared_ptr task, const uint64_t table_id, + const std::vector& var_names, + int fea_dim, + const std::vector& var_emb_names); + + void HeterPushSparseVars( + std::shared_ptr task, const uint64_t table_id, + const std::vector& sparse_key_names, + const std::vector& sparse_grad_names, const int emb_dim, + std::vector<::std::future>* push_sparse_status, + const bool use_cvm, const bool dump_slot, + const bool no_cvm); + + typedef std::function HeterCallBackFunc; + + int RegisterHeterCallback(HeterCallBackFunc handler); + // Pull sparse variables from server in sync mode // Param: scope, table_id, var_names, fea_keys, fea_dim, var_emb_names // Param: fea_values diff --git a/paddle/fluid/framework/hetercpu_worker.cc b/paddle/fluid/framework/hetercpu_worker.cc new file mode 100644 index 00000000000000..fec2bccb38f611 --- /dev/null +++ b/paddle/fluid/framework/hetercpu_worker.cc @@ -0,0 +1,1212 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/framework/device_worker.h" +#include "paddle/fluid/framework/device_worker_factory.h" +#include "paddle/fluid/framework/fleet/fleet_wrapper.h" +#include "paddle/fluid/platform/cpu_helper.h" +#include "paddle/fluid/string/string_helper.h" + +#if defined _WIN32 || defined __APPLE__ +#else +#define _LINUX +#endif + +namespace paddle { +namespace framework { + +void HeterTask::PackTask(Scope* thread_scope, int taskid, DataFeed* reader, int cur_batch, const ProgramDesc& program) { + total_time = 0; + read_time = 0; + pack_time = 0; + pull_sparse_local_time = 0; + taskid_ = taskid; + auto &block = program.Block(0); + if (!scope_) { + scope_ = &(thread_scope->NewScope()); + for (auto &var : block.AllVars()) { + if (!var->Persistable()) { + auto *ptr = scope_->Var(var->Name()); + InitializeVariable(ptr, var->GetType()); + } + } + } + state_ = PULL_SPARSE; + cur_batch_ = cur_batch; + auto& use_slots = reader->GetUseSlotAlias(); + for (size_t i = 0; i < use_slots.size(); ++i) { + Variable* thread_var = thread_scope->FindVar(use_slots[i]); + LoDTensor* thread_tensor = thread_var->GetMutable(); + Variable* task_var = scope_->FindVar(use_slots[i]); + LoDTensor* task_tensor = task_var->GetMutable(); + TensorCopy(*thread_tensor, platform::CPUPlace(), task_tensor); + auto& tensor_lod = thread_tensor->lod()[0]; + LoD thread_lod{tensor_lod}; + task_tensor->set_lod(thread_lod); + } + +} + +void HeterCpuWorker::Schedule(int taskid) { + //std::cout << "wxx schedule " << taskid << std::endl; + auto task = wait_queue_.TryGet(taskid); + if (task) { + run_queue_.Put(task->taskid_, task); + } +} + +void HeterCpuWorker::JumpContext(std::shared_ptr task) { + //std::cout << "wxx jump context " << task->taskid_ << std::endl; + if (!(wait_queue_.TryPut(task->taskid_, task))) { + run_queue_.Put(task->taskid_, task); + } +} + +void HeterCpuWorker::Initialize(const TrainerDesc& desc) { + param_ = desc.downpour_param(); + for (int i = 0; i < param_.sparse_table_size(); 
++i) { + uint64_t table_id = + static_cast(param_.sparse_table(i).table_id()); + TableParameter table = param_.sparse_table(i); + sparse_key_names_[table_id].resize(table.sparse_key_name_size()); + for (int j = 0; j < table.sparse_key_name_size(); ++j) { + sparse_key_names_[table_id][j] = table.sparse_key_name(j); + } + sparse_value_names_[table_id].resize(table.sparse_value_name_size()); + for (int j = 0; j < table.sparse_value_name_size(); ++j) { + sparse_value_names_[table_id][j] = table.sparse_value_name(j); + } + sparse_grad_names_[table_id].resize(table.sparse_grad_name_size()); + for (int j = 0; j < table.sparse_grad_name_size(); ++j) { + sparse_grad_names_[table_id][j] = table.sparse_grad_name(j); + } + label_var_name_[table_id] = table.label_var_name(); + sparse_push_keys_[table_id] = std::vector(); + } + + for (int i = 0; i < param_.dense_table_size(); ++i) { + uint64_t table_id = static_cast(param_.dense_table(i).table_id()); + auto table = param_.dense_table(i); + dense_value_names_[table_id].resize(table.dense_value_name_size()); + for (int j = 0; j < table.dense_value_name_size(); ++j) { + dense_value_names_[table_id][j] = table.dense_value_name(j); + } + dense_grad_names_[table_id].resize(table.dense_grad_name_size()); + for (int j = 0; j < table.dense_grad_name_size(); ++j) { + dense_grad_names_[table_id][j] = table.dense_grad_name(j); + } + } + + skip_ops_.resize(param_.skip_ops_size()); + for (int i = 0; i < param_.skip_ops_size(); ++i) { + skip_ops_[i] = param_.skip_ops(i); + } + for (int i = 0; i < param_.stat_var_names_size(); ++i) { + stat_var_name_map_[param_.stat_var_names(i)] = 1; + } + + need_to_push_sparse_ = param_.push_sparse(); + need_to_push_dense_ = param_.push_dense(); + + fleet_ptr_ = FleetWrapper::GetInstance(); + fetch_config_ = desc.fetch_config(); + use_cvm_ = desc.use_cvm(); + // for sparse value accessor, embedding only + no_cvm_ = desc.no_cvm(); + scale_datanorm_ = desc.scale_datanorm(); + dump_slot_ = desc.dump_slot(); + 
dump_fields_.resize(desc.dump_fields_size()); + for (int i = 0; i < desc.dump_fields_size(); ++i) { + dump_fields_[i] = desc.dump_fields(i); + } + adjust_ins_weight_config_ = desc.adjust_ins_weight_config(); + need_dump_param_ = false; + dump_param_.resize(desc.dump_param_size()); + for (int i = 0; i < desc.dump_param_size(); ++i) { + dump_param_[i] = desc.dump_param(i); + } + if (desc.dump_param_size() != 0) { + need_dump_param_ = true; + } + for (int i = 0; i < desc.check_nan_var_names_size(); ++i) { + check_nan_var_names_.push_back(desc.check_nan_var_names(i)); + } + copy_table_config_ = desc.copy_table_config(); + for (int i = 0; i < copy_table_config_.src_sparse_tables_size(); ++i) { + uint64_t src_table = copy_table_config_.src_sparse_tables(i); + uint64_t dest_table = copy_table_config_.dest_sparse_tables(i); + VLOG(3) << "copy_sparse_tables_ push back " << src_table << "->" + << dest_table; + copy_sparse_tables_.push_back(std::make_pair(src_table, dest_table)); + } + for (int i = 0; i < copy_table_config_.src_dense_tables_size(); ++i) { + uint64_t src_table = copy_table_config_.src_dense_tables(i); + uint64_t dest_table = copy_table_config_.dest_dense_tables(i); + VLOG(3) << "copy_dense_tables_ push back " << src_table << "->" + << dest_table; + copy_dense_tables_.push_back(std::make_pair(src_table, dest_table)); + } + for (auto& m : copy_table_config_.table_denpendency_map()) { + if (sparse_key_names_.find(m.key()) != sparse_key_names_.end()) { + // currently only support one dependency + for (auto& value : m.values()) { + table_dependency_[m.key()] = value; + } + } + } +} + +void HeterCpuWorker::SetChannelWriter(ChannelObject* queue) { + writer_.Reset(queue); +} + +void HeterCpuWorker::SetNeedDump(bool need_dump_field) { + need_dump_field_ = need_dump_field; +} + +//template +//std::string PrintLodTensorType(LoDTensor* tensor, int64_t start, int64_t end) { +// auto count = tensor->numel(); +// if (start < 0 || end > count) { +// VLOG(3) << "access 
violation"; +// return "access violation"; +// } +// std::ostringstream os; +// for (int64_t i = start; i < end; i++) { +// os << ":" << tensor->data()[i]; +// } +// return os.str(); +//} +// +//std::string PrintLodTensorIntType(LoDTensor* tensor, int64_t start, +// int64_t end) { +// auto count = tensor->numel(); +// if (start < 0 || end > count) { +// VLOG(3) << "access violation"; +// return "access violation"; +// } +// std::ostringstream os; +// for (int64_t i = start; i < end; i++) { +// os << ":" << static_cast(tensor->data()[i]); +// } +// return os.str(); +//} +// +//std::string PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end) { +// std::string out_val; +// if (tensor->type() == proto::VarType::FP32) { +// out_val = PrintLodTensorType(tensor, start, end); +// } else if (tensor->type() == proto::VarType::INT64) { +// out_val = PrintLodTensorIntType(tensor, start, end); +// } else if (tensor->type() == proto::VarType::FP64) { +// out_val = PrintLodTensorType(tensor, start, end); +// } else { +// out_val = "unsupported type"; +// } +// return out_val; +//} +// +//std::pair GetTensorBound(LoDTensor* tensor, int index) { +// auto& dims = tensor->dims(); +// if (tensor->lod().size() != 0) { +// auto& lod = tensor->lod()[0]; +// return {lod[index] * dims[1], lod[index + 1] * dims[1]}; +// } else { +// return {index * dims[1], (index + 1) * dims[1]}; +// } +//} +// +//bool CheckValidOutput(LoDTensor* tensor, size_t batch_size) { +// auto& dims = tensor->dims(); +// if (dims.size() != 2) return false; +// if (tensor->lod().size() != 0) { +// auto& lod = tensor->lod()[0]; +// if (lod.size() != batch_size + 1) { +// return false; +// } +// } else { +// if (dims[0] != static_cast(batch_size)) { +// return false; +// } +// } +// return true; +//} + +void HeterCpuWorker::DumpParam() { +// std::string os; +// for (auto& param : dump_param_) { +// os.clear(); +// os = param; +// Variable* var = thread_scope_->FindVar(param); +// if (var == nullptr) { +// 
continue; +// } +// LoDTensor* tensor = var->GetMutable(); +// int64_t len = tensor->numel(); +// os += PrintLodTensor(tensor, 0, len); +// writer_ << os; +// } +} + +void HeterCpuWorker::CollectLabelInfo(std::shared_ptr task, size_t table_idx) { + if (no_cvm_) { + return; + } + uint64_t table_id = static_cast( + param_.program_config(0).pull_sparse_table_id(table_idx)); + + TableParameter table; + for (auto i : param_.sparse_table()) { + if (i.table_id() == table_id) { + table = i; + break; + } + } + auto& feature = (task->features_)[table_id]; + auto& feature_label = (task->feature_labels_)[table_id]; + Scope* scope = task->scope_; + feature_label.resize(feature.size()); + Variable* var = scope->FindVar(label_var_name_[table_id]); + LoDTensor* tensor = var->GetMutable(); + int64_t* label_ptr = tensor->data(); + + size_t global_index = 0; + for (size_t i = 0; i < sparse_key_names_[table_id].size(); ++i) { + VLOG(3) << "sparse_key_names_[" << i + << "]: " << sparse_key_names_[table_id][i]; + Variable* fea_var = scope->FindVar(sparse_key_names_[table_id][i]); + if (fea_var == nullptr) { + continue; + } + LoDTensor* tensor = fea_var->GetMutable(); + CHECK(tensor != nullptr) << "tensor of var " + << sparse_key_names_[table_id][i] << " is null"; + + // skip slots which do not have embedding + Variable* emb_var = + scope->FindVar(sparse_value_names_[table_id][i]); + if (emb_var == nullptr) { + continue; + } + int64_t* ids = tensor->data(); + size_t fea_idx = 0; + // tensor->lod()[0].size() == batch_size + 1 + for (auto lod_idx = 1u; lod_idx < tensor->lod()[0].size(); ++lod_idx) { + for (; fea_idx < tensor->lod()[0][lod_idx]; ++fea_idx) { + // should be skipped feasign defined in protobuf + if (ids[fea_idx] == 0u) { + continue; + } + feature_label[global_index++] = + static_cast(label_ptr[lod_idx - 1]); + } + } + } + CHECK(global_index == feature.size()) + << "expect fea info size:" << feature.size() << " real:" << global_index; +} + +void 
HeterCpuWorker::FillSparseValue(std::shared_ptr task, size_t table_idx) { + uint64_t table_id = static_cast( + param_.program_config(0).pull_sparse_table_id(table_idx)); + + TableParameter table; + for (auto i : param_.sparse_table()) { + if (i.table_id() == table_id) { + table = i; + break; + } + } + + auto& fea_value = (task->feature_values_)[table_id]; + Scope* scope= task->scope_; + auto fea_idx = 0u; + + std::vector init_value(table.fea_dim()); + for (size_t i = 0; i < sparse_key_names_[table_id].size(); ++i) { + std::string slot_name = sparse_key_names_[table_id][i]; + std::string emb_slot_name = sparse_value_names_[table_id][i]; + Variable* var = scope->FindVar(slot_name); + if (var == nullptr) { + continue; + } + LoDTensor* tensor = var->GetMutable(); + CHECK(tensor != nullptr) << "tensor of var " << slot_name << " is null"; + int64_t* ids = tensor->data(); + int len = tensor->numel(); + Variable* var_emb = scope->FindVar(emb_slot_name); + if (var_emb == nullptr) { + continue; + } + LoDTensor* tensor_emb = var_emb->GetMutable(); + float* ptr = tensor_emb->mutable_data({len, table.emb_dim()}, + place_); + //memset(ptr, 0, sizeof(float) * len * table.emb_dim()); + auto& tensor_lod = tensor->lod()[0]; + LoD data_lod{tensor_lod}; + tensor_emb->set_lod(data_lod); + + bool is_nid = (adjust_ins_weight_config_.need_adjust() && + adjust_ins_weight_config_.nid_slot() == emb_slot_name); + if (is_nid) { + nid_show_.clear(); + } + int nid_ins_index = 0; + + for (int index = 0; index < len; ++index) { + if (use_cvm_ || no_cvm_) { + if (ids[index] == 0u) { + memcpy(ptr + table.emb_dim() * index, init_value.data(), + sizeof(float) * table.emb_dim()); + if (is_nid) { + nid_show_.push_back(-1); + ++nid_ins_index; + } + continue; + } + memcpy(ptr + table.emb_dim() * index, fea_value[fea_idx].data(), + sizeof(float) * table.emb_dim()); + if (is_nid && + static_cast(index) == tensor->lod()[0][nid_ins_index]) { + nid_show_.push_back(fea_value[fea_idx][0]); + ++nid_ins_index; + } 
+ fea_idx++; + } else { + if (ids[index] == 0u) { + memcpy(ptr + table.emb_dim() * index, init_value.data() + 2, + sizeof(float) * table.emb_dim()); + if (is_nid) { + nid_show_.push_back(-1); + ++nid_ins_index; + } + continue; + } + memcpy(ptr + table.emb_dim() * index, fea_value[fea_idx].data() + 2, + sizeof(float) * table.emb_dim()); + if (is_nid && + static_cast(index) == tensor->lod()[0][nid_ins_index]) { + nid_show_.push_back(fea_value[fea_idx][0]); + ++nid_ins_index; + } + fea_idx++; + } + } + } +} + +void HeterCpuWorker::AdjustInsWeight(std::shared_ptr task) { +#ifdef _LINUX + // check var and tensor not null + Scope* scope = task->scope_; + if (!adjust_ins_weight_config_.need_adjust()) { + VLOG(0) << "need_adjust=false, skip adjust ins weight"; + return; + } + Variable* nid_var = + scope->FindVar(adjust_ins_weight_config_.nid_slot()); + if (nid_var == nullptr) { + VLOG(0) << "nid slot var " << adjust_ins_weight_config_.nid_slot() + << " is nullptr, skip adjust ins weight"; + return; + } + LoDTensor* nid_tensor = nid_var->GetMutable(); + if (nid_tensor == nullptr) { + VLOG(0) << "tensor of nid slot var " << adjust_ins_weight_config_.nid_slot() + << " is nullptr, skip adjust ins weight"; + return; + } + Variable* ins_weight_var = + scope->FindVar(adjust_ins_weight_config_.ins_weight_slot()); + if (ins_weight_var == nullptr) { + VLOG(0) << "ins weight var " << adjust_ins_weight_config_.ins_weight_slot() + << " is nullptr, skip adjust ins weight"; + return; + } + LoDTensor* ins_weight_tensor = ins_weight_var->GetMutable(); + if (ins_weight_tensor == nullptr) { + VLOG(0) << "tensor of ins weight tensor " + << adjust_ins_weight_config_.ins_weight_slot() + << " is nullptr, skip adjust ins weight"; + return; + } + + float* ins_weights = ins_weight_tensor->data(); + size_t len = ins_weight_tensor->numel(); // len = batch size + // here we assume nid_show slot only has one feasign in each instance + CHECK(len == nid_show_.size()) << "ins_weight size should be equal 
to " + << "nid_show size, " << len << " vs " + << nid_show_.size(); + float nid_adjw_threshold = adjust_ins_weight_config_.nid_adjw_threshold(); + float nid_adjw_ratio = adjust_ins_weight_config_.nid_adjw_ratio(); + int64_t nid_adjw_num = 0; + double nid_adjw_weight = 0.0; + size_t ins_index = 0; + for (size_t i = 0; i < len; ++i) { + float nid_show = nid_show_[i]; + VLOG(3) << "nid_show " << nid_show; + if (nid_show < 0) { + VLOG(3) << "nid_show < 0, continue"; + continue; + } + float ins_weight = 1.0; + if (nid_show >= 0 && nid_show < nid_adjw_threshold) { + ins_weight = log(M_E + + (nid_adjw_threshold - nid_show) / nid_adjw_threshold * + nid_adjw_ratio); + // count nid adjw insnum and weight + ++nid_adjw_num; + nid_adjw_weight += ins_weight; + // choose large ins weight + VLOG(3) << "ins weight new " << ins_weight << ", ins weight origin " + << ins_weights[ins_index]; + if (ins_weight > ins_weights[ins_index]) { + VLOG(3) << "ins " << ins_index << " weight changes to " << ins_weight; + ins_weights[ins_index] = ins_weight; + } + ++ins_index; + } + } + VLOG(3) << "nid adjw info: total_adjw_num: " << nid_adjw_num + << ", avg_adjw_weight: " << nid_adjw_weight; +#endif +} + +void HeterCpuWorker::CopySparseTable() { + for (size_t i = 0; i < copy_sparse_tables_.size(); ++i) { + int64_t src_table = copy_sparse_tables_[i].first; + int64_t dest_table = copy_sparse_tables_[i].second; + int32_t feanum = 0; + if (src_table == dest_table) { + continue; + } else if (!copy_table_config_.sparse_copy_by_feasign()) { + if (feasign_set_.find(src_table) == feasign_set_.end()) { + continue; + } else if (feasign_set_[src_table].size() == 0) { + continue; + } + feanum = fleet_ptr_->CopyTable(src_table, dest_table); + } else { + std::vector fea_vec(feasign_set_[src_table].begin(), + feasign_set_[src_table].end()); + feanum = fleet_ptr_->CopyTableByFeasign(src_table, dest_table, fea_vec); + fea_vec.clear(); + std::vector().swap(fea_vec); + } + VLOG(3) << "copy feasign from table " << 
src_table << " to table " + << dest_table << ", feasign num=" << feanum; + feasign_set_[src_table].clear(); + std::unordered_set().swap(feasign_set_[src_table]); + } + feasign_set_.clear(); +} + +void HeterCpuWorker::CopyDenseTable() { + if (thread_id_ != 0) { + return; + } + thread_local std::vector> pull_dense_status; + for (size_t i = 0; i < copy_dense_tables_.size(); ++i) { + uint64_t src_table = copy_dense_tables_[i].first; + uint64_t dest_table = copy_dense_tables_[i].second; + if (src_table == dest_table) { + continue; + } + int32_t dim = fleet_ptr_->CopyTable(src_table, dest_table); + VLOG(3) << "copy param from table " << src_table << " to table " + << dest_table << ", dim=" << dim; + if (copy_table_config_.dense_pull_after_copy()) { + VLOG(3) << "dense pull after copy, table=" << dest_table; + pull_dense_status.resize(0); + //fleet_ptr_->PullDenseVarsAsync(*root_scope_, dest_table, + // dense_value_names_[dest_table], + // &pull_dense_status); + for (auto& t : pull_dense_status) { + t.wait(); + auto status = t.get(); + if (status != 0) { + LOG(WARNING) << "pull dense after copy table failed," + << " table=" << dest_table; + } + } + } + } +} + +void HeterCpuWorker::CreateThreadParam(const ProgramDesc& program) { + #ifdef PADDLE_WITH_CUDA + auto dev_id = boost::get(place_).device; + platform::CUDADeviceGuard guard(dev_id); + auto &block = program.Block(0); + for (auto& var : block.AllVars()) { + if (var->Persistable()) { + auto name = var->Name(); + Variable* root_var = root_scope_->FindVar(name); + LoDTensor* root_tensor = root_var->GetMutable(); + auto *ptr = thread_scope_->Var(name); + InitializeVariable(ptr, proto::VarType::LOD_TENSOR); + LoDTensor* thread_tensor = ptr->GetMutable(); + +#define MemcpyCallback(cpp_type, proto_type) \ + do { \ + if (root_tensor->type() == proto_type) { \ + MemCpy(thread_tensor, root_tensor, place_, copy_stream_); \ + } \ + } while (0) + _ForEachDataType_(MemcpyCallback); + + } + } + 
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event_, copy_stream_)); + cudaEventSynchronize(event_); + #endif +} + +#ifdef PADDLE_WITH_CUDA +template +void HeterCpuWorker::MemCpy(LoDTensor *thread_tensor, LoDTensor *root_tensor, + const paddle::platform::Place& thread_place, + cudaStream_t stream) { + T* thread_ptr = thread_tensor->mutable_data(root_tensor->dims(), thread_place); + T* root_ptr = root_tensor->data(); + if (platform::is_cpu_place(root_tensor->place())) { + memory::Copy( + boost::get(thread_place), + thread_ptr, + platform::CPUPlace(), + root_ptr, sizeof(T) * root_tensor->numel(), stream); + } + else { + memory::Copy( + boost::get(thread_place), + thread_ptr, + boost::get(root_tensor->place()), + root_ptr, sizeof(T) * root_tensor->numel(), stream); + } +} +#endif + +void HeterCpuWorker::CopyDenseVars() { + if (thread_id_ != 0) { + return; + } + for (int i = 0; i < copy_table_config_.src_var_list_size(); ++i) { + auto& src_var_name = copy_table_config_.src_var_list(i); + auto& dest_var_name = copy_table_config_.dest_var_list(i); + if (src_var_name == dest_var_name) { + continue; + } + VLOG(3) << "copy dense var from " << src_var_name << " to " + << dest_var_name; + Variable* src_var = thread_scope_->FindVar(src_var_name); + CHECK(src_var != nullptr) << src_var_name << " not found"; // NOLINT + LoDTensor* src_tensor = src_var->GetMutable(); + CHECK(src_tensor != nullptr) << src_var_name + << " tensor is null"; // NOLINT + float* src_data = src_tensor->data(); + + Variable* dest_var = thread_scope_->FindVar(dest_var_name); + CHECK(dest_var != nullptr) << dest_var_name << " not found"; // NOLINT + LoDTensor* dest_tensor = dest_var->GetMutable(); + CHECK(dest_tensor != nullptr) << dest_var_name + << " tensor is null"; // NOLINT + float* dest_data = dest_tensor->data(); + + CHECK(src_tensor->numel() == dest_tensor->numel()) + << "tensor numel not equal," << src_tensor->numel() << " vs " + << dest_tensor->numel(); + for (int i = 0; i < src_tensor->numel(); 
i++) { + dest_data[i] = src_data[i]; + } + } +} + +void HeterCpuWorker::TrainFilesWithProfiler() { + VLOG(3) << "Begin to train files with profiler"; + platform::SetNumThreads(1); + device_reader_->Start(); + + std::vector op_total_time; + std::vector op_name; + for (auto& op : ops_) { + bool need_skip = false; + for (auto t = 0u; t < skip_ops_.size(); ++t) { + if (op->Type().find(skip_ops_[t]) != std::string::npos) { + need_skip = true; + break; + } + } + if (!need_skip) { + op_name.push_back(op->Type()); + } + } + + VLOG(3) << "op name size: " << op_name.size(); + op_total_time.resize(op_name.size()); + for (size_t i = 0; i < op_total_time.size(); ++i) { + op_total_time[i] = 0.0; + } + platform::Timer timeline; + double total_time = 0.0; + double read_time = 0.0; + //double pull_sparse_time = 0.0; + double collect_label_time = 0.0; + double fill_sparse_time = 0.0; + double push_sparse_time = 0.0; + double push_dense_time = 0.0; + double pack_time = 0.0; + double pull_sparse_local_time = 0.0; + + int batch_cnt = 0; + int done_cnt = 0; + int cur_batch; + uint64_t total_inst = 0; + wait_queue_.SetCap(3); + while (1) { + + std::shared_ptr task; + task = run_queue_.Get(); + if (!task) { + double tmp_read_time; + timeline.Start(); + cur_batch = device_reader_->Next(); + timeline.Pause(); + tmp_read_time = timeline.ElapsedSec(); + if (cur_batch <= 0) { + if (batch_cnt == done_cnt) { + break; + } + else { + continue; + } + } + batch_cnt += 1; + int taskid = batch_cnt * worker_num_ + thread_id_; + timeline.Start(); + task = object_pool_.Get(); + task->PackTask(thread_scope_, taskid, device_reader_, cur_batch, program_); + timeline.Pause(); + task->read_time = tmp_read_time; + task->pack_time = timeline.ElapsedSec(); + task->total_time = tmp_read_time + task->pack_time; + } + for (;;) { + // pull sparse here + if (task->state_ == PULL_SPARSE) { + timeline.Start(); + for (int i = 0; i < param_.program_config(0).pull_sparse_table_id_size(); + ++i) { + uint64_t tid = 
static_cast( + param_.program_config(0).pull_sparse_table_id(i)); + TableParameter table; + for (auto j : param_.sparse_table()) { + if (j.table_id() == tid) { + table = j; + break; + } + } + fleet_ptr_->HeterPullSparseVars(thread_id_, + task, tid, sparse_key_names_[tid], + table.fea_dim(), sparse_value_names_[tid]); + } + task->Update(); + JumpContext(task); + timeline.Pause(); + task->pull_sparse_local_time += timeline.ElapsedSec(); + task->total_time += timeline.ElapsedSec(); + break; + } + else if (task->state_ == OP_RUN) { + total_time += task->total_time; + read_time += task->read_time; + pack_time += task->pack_time; + pull_sparse_local_time += task->pull_sparse_local_time; + for (int i = 0; i < param_.program_config(0).pull_sparse_table_id_size(); + ++i) { + uint64_t tid = static_cast( + param_.program_config(0).pull_sparse_table_id(i)); + timeline.Start(); + CollectLabelInfo(task, i); + timeline.Pause(); + collect_label_time += timeline.ElapsedSec(); + total_time += timeline.ElapsedSec(); + timeline.Start(); + FillSparseValue(task, i); + timeline.Pause(); + fill_sparse_time += timeline.ElapsedSec(); + total_time += timeline.ElapsedSec(); + + auto nid_iter = std::find(sparse_value_names_[tid].begin(), + sparse_value_names_[tid].end(), + adjust_ins_weight_config_.nid_slot()); + if (nid_iter != sparse_value_names_[tid].end()) { + AdjustInsWeight(task); + } + } + + VLOG(3) << "fill sparse value for all sparse table done."; + // do computation here + int run_op_idx = 0; + for (auto& op : ops_) { + bool need_skip = false; + for (auto t = 0u; t < skip_ops_.size(); ++t) { + if (op->Type().find(skip_ops_[t]) != std::string::npos) { + need_skip = true; + break; + } + } + if (!need_skip) { + timeline.Start(); + op->Run(*(task->scope_), place_); + timeline.Pause(); + op_total_time[run_op_idx++] += timeline.ElapsedSec(); + total_time += timeline.ElapsedSec(); + } + } + // check inf and nan + for (std::string& var_name : check_nan_var_names_) { + Variable* var = 
(task->scope_)->FindVar(var_name); + if (var == nullptr) { + continue; + } + LoDTensor* tensor = var->GetMutable(); + if (tensor == nullptr) { + continue; + } + PADDLE_ENFORCE_EQ(framework::TensorContainsInf(*tensor), false, + "Tensor %s contains Inf", var_name); + PADDLE_ENFORCE_EQ(framework::TensorContainsNAN(*tensor), false, + "Tensor %s contains NAN", var_name); + } + task->Update(); + } + else if (task->state_ == PUSH_GRAD) { + if (need_to_push_sparse_) { + // push gradients here + for (int i = 0; i < param_.program_config(0).push_sparse_table_id_size(); + ++i) { + uint64_t tid = static_cast( + param_.program_config(0).push_sparse_table_id(i)); + TableParameter table; + for (auto i : param_.sparse_table()) { + if (i.table_id() == tid) { + table = i; + break; + } + } + timeline.Start(); + fleet_ptr_->HeterPushSparseVars( + task, tid, + sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(), + &push_sparse_status_, use_cvm_, + dump_slot_, no_cvm_); + timeline.Pause(); + push_sparse_time += timeline.ElapsedSec(); + total_time += timeline.ElapsedSec(); + } + } + if (need_to_push_dense_) { + timeline.Start(); + for (int i = 0; i < param_.program_config(0).push_dense_table_id_size(); + ++i) { + uint64_t tid = static_cast( + param_.program_config(0).push_dense_table_id(i)); + fleet_ptr_->PushDenseVarsAsync( + *(task->scope_), tid, dense_grad_names_[tid], &push_sparse_status_, + scale_datanorm_, task->cur_batch_); + } + timeline.Pause(); + push_dense_time += timeline.ElapsedSec(); + total_time += timeline.ElapsedSec(); + VLOG(3) << "push dense gradient done."; + + // the following code should be more precise and clean + // TODO(guru4elephant) + int32_t tmp_push_dense_wait_times = -1; + static uint32_t push_dense_wait_times = + static_cast(tmp_push_dense_wait_times); + + if (push_dense_status_.size() >= push_dense_wait_times) { + for (auto& t : push_dense_status_) { + t.wait(); + } + push_dense_status_.resize(0); + } + + if (tmp_push_dense_wait_times == -1) { 
+ push_dense_status_.resize(0); + } + } + + if (need_to_push_sparse_) { + VLOG(3) << "push sparse gradient done."; + int32_t tmp_push_sparse_wait_times = -1; + static uint32_t push_sparse_wait_times = + static_cast(tmp_push_sparse_wait_times); + if (push_sparse_status_.size() >= push_sparse_wait_times) { + for (auto& t : push_sparse_status_) { + t.wait(); + } + push_sparse_status_.resize(0); + } + + if (tmp_push_sparse_wait_times == -1) { + push_sparse_status_.resize(0); + } + } + + if (need_to_push_dense_) { + for (int i = 0; i < param_.program_config(0).push_dense_table_id_size(); + ++i) { + uint64_t tid = static_cast( + param_.program_config(0).push_dense_table_id(i)); + pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid); + } + } + + //thread_scope_->DropKids(); + task->Update(); + } + else if (task->state_ == DONE) { + PrintFetchVars(); + ++done_cnt; + total_inst += task->cur_batch_; + object_pool_.Push(task); + //++batch_cnt; + if (thread_id_ == 0) { + // should be configured here + if (done_cnt > 0 && done_cnt % 100 == 0) { + double op_sum_time = 0; + std::unordered_map op_to_time; + for (size_t i = 0; i < op_total_time.size(); ++i) { + fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i, + op_name[i].c_str(), op_total_time[i] / done_cnt); + if (op_to_time.find(op_name[i]) == op_to_time.end()) { + op_to_time[op_name[i]] = 0.0; + } + op_to_time[op_name[i]] += op_total_time[i]; + op_sum_time += op_total_time[i]; + } + for (auto& i : op_to_time) { + fprintf(stderr, "op [%s] run total time: [%f]ms\n", i.first.c_str(), + i.second / done_cnt); + } + fprintf(stderr, "op run total time: %fs\n", op_sum_time / done_cnt); + fprintf(stderr, "pack task time: %fs\n", pack_time / done_cnt); + fprintf(stderr, "train total time: %fs\n", total_time / done_cnt); + fprintf(stderr, "pull sparse local time: %fs\n", + pull_sparse_local_time / done_cnt); + fprintf(stderr, "fill sparse time: %fs\n", + fill_sparse_time / done_cnt); + fprintf(stderr, "push sparse time: 
%fs\n", + push_sparse_time / done_cnt); + fprintf(stderr, "push dense time: %fs\n", push_dense_time / done_cnt); + fprintf(stderr, "collect label time: %fs\n", + collect_label_time / done_cnt); + fprintf(stderr, "mean read time: %fs\n", read_time / done_cnt); + fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100); + fprintf(stderr, "op run percent: %f\n", op_sum_time / total_time * 100); + fprintf(stderr, "pack task percent: %f\n", pack_time / total_time * 100); + fprintf(stderr, "pull sparse local time percent: %f\n", + pull_sparse_local_time / total_time * 100); + fprintf(stderr, "collect label time percent: %f\n", + collect_label_time / total_time * 100); + fprintf(stderr, "fill sparse time percent: %f\n", + fill_sparse_time / total_time * 100); + fprintf(stderr, "push sparse time percent: %f\n", + push_sparse_time / total_time * 100); + fprintf(stderr, "push dense time percent: %f\n", + push_dense_time / total_time * 100); + fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time); + } + } + break; + } + } + } + if (copy_table_config_.need_copy()) { + CopySparseTable(); + CopyDenseTable(); + CopyDenseVars(); + } +} + +void HeterCpuWorker::TrainFiles() { + VLOG(3) << "Begin to train files"; + platform::SetNumThreads(1); + device_reader_->Start(); + int batch_cnt = 0; + int done_cnt = 0; + int cur_batch; + wait_queue_.SetCap(3); + //while ((cur_batch = device_reader_->Next()) > 0) { + while (1) { + //if (copy_table_config_.need_copy()) { + // if (copy_table_config_.sparse_copy_by_feasign()) { + // for (size_t i = 0; i < copy_sparse_tables_.size(); ++i) { + // uint64_t tid = copy_sparse_tables_[i].first; + // feasign_set_[tid].insert(sparse_push_keys_[tid].begin(), + // sparse_push_keys_[tid].end()); + // } + // } + // if (batch_cnt % copy_table_config_.batch_num() == 0) { + // CopySparseTable(); + // CopyDenseTable(); + // CopyDenseVars(); + // } + //} + + std::shared_ptr task; + //std::cout << "wait_queue size:" << wait_queue_.Size() << " 
run_queue size:" << run_queue_.Size() << std::endl; + //std::cout << "object pool size: " << object_pool_.Size() << std::endl; + + // while (wait_queue_.Size() > 10) { + // std::cout << "sleep 10ms" << std::endl; + // usleep(10000); + // } + task = run_queue_.Get(); + //std::cout << "wxx begin " << std::endl; + if (!task) { + //std::cout << "wxx new pack " << std::endl; + cur_batch = device_reader_->Next(); + //std::cout << "wxx " << cur_batch << " " << wait_queue_.Empty() << std::endl; + if (cur_batch <= 0) { + if (batch_cnt == done_cnt) { + //std::cout << "wxx pass done " << std::endl; + break; + } + else { + continue; + } + } + batch_cnt += 1; + int taskid = batch_cnt * worker_num_ + thread_id_; + //std::cout << "taskid " << taskid << " " << batch_cnt << " " << worker_num_ << " " << thread_id_ << std::endl; + task = object_pool_.Get(); + task->PackTask(thread_scope_, taskid, device_reader_, cur_batch, program_); + } + //task->Show(); + for (;;) { + // pull sparse here + if (task->state_ == PULL_SPARSE) { + //std::cout << "wxx pull sparse taskid = " << task->taskid_ << std::endl; + for (int i = 0; i < param_.program_config(0).pull_sparse_table_id_size(); + ++i) { + uint64_t tid = static_cast( + param_.program_config(0).pull_sparse_table_id(i)); + TableParameter table; + for (auto j : param_.sparse_table()) { + if (j.table_id() == tid) { + table = j; + break; + } + } + fleet_ptr_->HeterPullSparseVars(thread_id_, + task, tid, sparse_key_names_[tid], + table.fea_dim(), sparse_value_names_[tid]); + } + task->Update(); + JumpContext(task); + break; + } + else if (task->state_ == OP_RUN) { + //std::cout << "wxx oprun taskid = " << task->taskid_ << std::endl; + for (int i = 0; i < param_.program_config(0).pull_sparse_table_id_size(); + ++i) { + uint64_t tid = static_cast( + param_.program_config(0).pull_sparse_table_id(i)); + CollectLabelInfo(task, i); + FillSparseValue(task, i); + auto nid_iter = std::find(sparse_value_names_[tid].begin(), + 
sparse_value_names_[tid].end(), + adjust_ins_weight_config_.nid_slot()); + if (nid_iter != sparse_value_names_[tid].end()) { + AdjustInsWeight(task); + } + } + + VLOG(3) << "fill sparse value for all sparse table done."; + // do computation here + for (auto& op : ops_) { + bool need_skip = false; + for (auto t = 0u; t < skip_ops_.size(); ++t) { + if (op->Type().find(skip_ops_[t]) != std::string::npos) { + need_skip = true; + break; + } + } + if (!need_skip) { + op->Run(*(task->scope_), place_); + } + } + // check inf and nan + for (std::string& var_name : check_nan_var_names_) { + Variable* var = (task->scope_)->FindVar(var_name); + if (var == nullptr) { + continue; + } + LoDTensor* tensor = var->GetMutable(); + if (tensor == nullptr) { + continue; + } + PADDLE_ENFORCE_EQ(framework::TensorContainsInf(*tensor), false, + "Tensor %s contains Inf", var_name); + PADDLE_ENFORCE_EQ(framework::TensorContainsNAN(*tensor), false, + "Tensor %s contains NAN", var_name); + } + task->Update(); + } + else if (task->state_ == PUSH_GRAD) { + //std::cout << "wxx push grad taskid = " << task->taskid_ << std::endl; + if (need_to_push_sparse_) { + // push gradients here + for (int i = 0; i < param_.program_config(0).push_sparse_table_id_size(); + ++i) { + uint64_t tid = static_cast( + param_.program_config(0).push_sparse_table_id(i)); + TableParameter table; + for (auto i : param_.sparse_table()) { + if (i.table_id() == tid) { + table = i; + break; + } + } + fleet_ptr_->HeterPushSparseVars( + task, tid, + sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(), + &push_sparse_status_, use_cvm_, + dump_slot_, no_cvm_); + } + } + if (need_to_push_dense_) { + for (int i = 0; i < param_.program_config(0).push_dense_table_id_size(); + ++i) { + uint64_t tid = static_cast( + param_.program_config(0).push_dense_table_id(i)); + fleet_ptr_->PushDenseVarsAsync( + *(task->scope_), tid, dense_grad_names_[tid], &push_sparse_status_, + scale_datanorm_, task->cur_batch_); + } + VLOG(3) << 
"push dense gradient done."; + + // the following code should be more precise and clean + // TODO(guru4elephant) + int32_t tmp_push_dense_wait_times = -1; + static uint32_t push_dense_wait_times = + static_cast(tmp_push_dense_wait_times); + + if (push_dense_status_.size() >= push_dense_wait_times) { + for (auto& t : push_dense_status_) { + t.wait(); + } + push_dense_status_.resize(0); + } + + if (tmp_push_dense_wait_times == -1) { + push_dense_status_.resize(0); + } + } + + if (need_to_push_sparse_) { + VLOG(3) << "push sparse gradient done."; + int32_t tmp_push_sparse_wait_times = -1; + static uint32_t push_sparse_wait_times = + static_cast(tmp_push_sparse_wait_times); + if (push_sparse_status_.size() >= push_sparse_wait_times) { + for (auto& t : push_sparse_status_) { + t.wait(); + } + push_sparse_status_.resize(0); + } + + if (tmp_push_sparse_wait_times == -1) { + push_sparse_status_.resize(0); + } + } + + if (need_to_push_dense_) { + for (int i = 0; i < param_.program_config(0).push_dense_table_id_size(); + ++i) { + uint64_t tid = static_cast( + param_.program_config(0).push_dense_table_id(i)); + pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid); + } + } + //if (need_dump_field_) { + // size_t batch_size = device_reader_->GetCurBatchSize(); + // std::vector ars(batch_size); + // for (auto& ar : ars) { + // ar.clear(); + // } + // auto& ins_id_vec = device_reader_->GetInsIdVec(); + // auto& ins_content_vec = device_reader_->GetInsContentVec(); + // for (size_t i = 0; i < ins_id_vec.size(); i++) { + // ars[i] += ins_id_vec[i]; + // ars[i] = ars[i] + "\t" + ins_content_vec[i]; + // } + // for (auto& field : dump_fields_) { + // Variable* var = thread_scope_->FindVar(field); + // if (var == nullptr) { + // continue; + // } + // LoDTensor* tensor = var->GetMutable(); + // if (!CheckValidOutput(tensor, batch_size)) { + // continue; + // } + // for (size_t i = 0; i < batch_size; ++i) { + // auto output_dim = tensor->dims()[1]; + // std::string output_dimstr = 
+ // boost::lexical_cast(output_dim); + // ars[i] = ars[i] + "\t" + field + ":" + output_dimstr; + // auto bound = GetTensorBound(tensor, i); + // ars[i] += PrintLodTensor(tensor, bound.first, bound.second); + // } + // } + // // #pragma omp parallel for + // for (size_t i = 0; i < ars.size(); i++) { + // if (ars[i].length() == 0) { + // continue; + // } + // writer_ << ars[i]; + // } + // if (need_dump_param_ && thread_id_ == 0) { + // DumpParam(); + // } + //} + + //thread_scope_->DropKids(); + task->Update(); + } + else if (task->state_ == DONE) { + //std::cout << "wxx done taskid = " << task->taskid_ << std::endl; + object_pool_.Push(task); + PrintFetchVars(); + ++done_cnt; + //++batch_cnt; + break; + } + } + } + if (need_dump_field_) { + // writer_.Flush(); + } + if (copy_table_config_.need_copy()) { + CopySparseTable(); + CopyDenseTable(); + CopyDenseVars(); + } +} + +} // end namespace framework +} // end namespace paddle diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index 0faf96195403fa..41f3a192770b9a 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -140,6 +140,7 @@ void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program, workers_[i]->SetRootScope(root_scope_); workers_[i]->CreateDeviceResource(main_program); // Program workers_[i]->BindingDataFeedMemory(); + workers_[i]->CacheProgram(main_program); } } diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index e22d659a367df8..17f6ca77f43c96 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -111,6 +111,7 @@ class DistMultiTrainer : public MultiTrainer { virtual void InitDumpEnv(); virtual Scope* GetWorkerScope(int thread_id); virtual void DumpWork(int tid); + virtual void RegisterHeterCallback(); protected: std::shared_ptr pull_dense_worker_; diff --git a/python/paddle/fluid/device_worker.py b/python/paddle/fluid/device_worker.py 
index 59aaea0fb490a3..8f91092ff5bc00 100644 --- a/python/paddle/fluid/device_worker.py +++ b/python/paddle/fluid/device_worker.py @@ -216,7 +216,7 @@ def _gen_worker_desc(self, trainer_desc): dense_table_set.add(i) break - trainer_desc.device_worker_name = "DownpourWorker" + trainer_desc.device_worker_name = "HeterCpuWorker" pull_thread = trainer_desc.pull_dense_param pull_thread.device_num = trainer_desc.thread_num if opt_info.get("program_id_to_worker") is None: From 9d5ee2f6b3be30fca00c9558ff51e1dbc1ed91de Mon Sep 17 00:00:00 2001 From: Thunderbrook Date: Wed, 15 Apr 2020 19:50:44 +0800 Subject: [PATCH 2/8] xpu trainer --- paddle/fluid/framework/CMakeLists.txt | 14 +- paddle/fluid/framework/device_worker.h | 75 ++- paddle/fluid/framework/dist_multi_trainer.cc | 13 +- paddle/fluid/framework/downpour_worker.cc | 6 +- paddle/fluid/framework/fleet/CMakeLists.txt | 2 + paddle/fluid/framework/fleet/fleet_wrapper.cc | 83 ++- paddle/fluid/framework/fleet/fleet_wrapper.h | 21 +- paddle/fluid/framework/fleet/gloo_wrapper.cc | 79 ++- paddle/fluid/framework/fleet/heter_wrapper.cc | 298 +++++++++++ paddle/fluid/framework/fleet/heter_wrapper.h | 112 ++++ paddle/fluid/framework/heter_service.h | 68 +++ paddle/fluid/framework/heter_service.proto | 71 +++ paddle/fluid/framework/hetercpu_worker.cc | 398 ++++++++------- paddle/fluid/framework/heterxpu_trainer.cc | 482 ++++++++++++++++++ paddle/fluid/framework/pull_dense_worker.cc | 53 +- paddle/fluid/framework/trainer.h | 79 +++ paddle/fluid/framework/trainer_desc.proto | 1 + paddle/fluid/framework/trainer_factory.cc | 3 + paddle/fluid/pybind/CMakeLists.txt | 3 +- paddle/fluid/pybind/heter_wrapper_py.cc | 48 ++ paddle/fluid/pybind/heter_wrapper_py.h | 28 + paddle/fluid/pybind/pybind.cc | 2 + python/paddle/fluid/device_worker.py | 2 +- python/paddle/fluid/executor.py | 54 ++ .../fluid/incubate/fleet/base/fleet_base.py | 10 + .../fluid/incubate/fleet/base/role_maker.py | 143 +++++- .../fleet/parameter_server/pslib/__init__.py | 92 
+++- .../pslib/optimizer_factory.py | 12 +- python/paddle/fluid/trainer_desc.py | 32 +- python/paddle/fluid/trainer_factory.py | 4 +- 30 files changed, 2012 insertions(+), 276 deletions(-) create mode 100644 paddle/fluid/framework/fleet/heter_wrapper.cc create mode 100644 paddle/fluid/framework/fleet/heter_wrapper.h create mode 100644 paddle/fluid/framework/heter_service.h create mode 100644 paddle/fluid/framework/heter_service.proto create mode 100644 paddle/fluid/framework/heterxpu_trainer.cc create mode 100644 paddle/fluid/pybind/heter_wrapper_py.cc create mode 100644 paddle/fluid/pybind/heter_wrapper_py.h diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index bb12944c2e15f8..21c841b47b38a8 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -27,6 +27,7 @@ add_subdirectory(fleet) add_subdirectory(io) #ddim lib proto_library(framework_proto SRCS framework.proto) +proto_library(heter_service_proto SRCS heter_service.proto) proto_library(data_feed_proto SRCS data_feed.proto) proto_library(trainer_desc_proto SRCS trainer_desc.proto DEPS framework_proto data_feed_proto) @@ -185,21 +186,24 @@ cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc o if(WITH_DISTRIBUTE) cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc + heterxpu_trainer.cc data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc downpour_worker.cc downpour_worker_opt.cc pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry - device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper box_wrapper lodtensor_printer + device_context scope framework_proto trainer_desc_proto glog fs shell + fleet_wrapper heter_wrapper box_wrapper lodtensor_printer lod_rank_table feed_fetch_method sendrecvop_rpc communicator 
collective_helper ${GLOB_DISTRIBUTE_DEPS} - graph_to_program_pass variable_helper data_feed_proto timer) + graph_to_program_pass variable_helper data_feed_proto heter_service_proto timer ) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) else() cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc + heterxpu_trainer.cc data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc downpour_worker.cc downpour_worker_opt.cc pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry - device_context scope framework_proto data_feed_proto trainer_desc_proto glog - lod_rank_table fs shell fleet_wrapper box_wrapper lodtensor_printer feed_fetch_method - graph_to_program_pass variable_helper timer) + device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog + lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper lodtensor_printer feed_fetch_method + graph_to_program_pass variable_helper timer pslib_brpc) cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op) endif() diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index 62a532abb5a074..3b5f57b9ade1ce 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -37,6 +37,7 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/port.h" #include "paddle/fluid/platform/timer.h" +#include "paddle/fluid/framework/heter_service.h" #if defined(PADDLE_WITH_NCCL) #include "paddle/fluid/platform/nccl_helper.h" @@ -51,6 +52,8 @@ bool CheckValidOutput(LoDTensor* tensor, size_t batch_size); class FleetWrapper; +class HeterWrapper; + #define SEC_LOG \ VLOG(3) << "[s" << section_id_ << "p" << pipeline_id_ << "t" << thread_id_ \ << "]: " @@ -59,6 +62,19 @@ class PullDenseWorker { public: virtual ~PullDenseWorker() {} virtual void Initialize(const TrainerDesc& param); + #ifdef PADDLE_WITH_CUDA + void AddStream(const cudaStream_t stream) { + copy_streams_.push_back(stream); + } + + void AddPlace(const paddle::platform::Place place) { + places_.push_back(place); + } + + void AddThreadScope(Scope* scope) { + thread_scopes_.push_back(scope); + } + #endif int Start(); void Stop(); void SetRootScope(Scope* scope) { root_scope_ = scope; } @@ -66,6 +82,7 @@ class PullDenseWorker { void ResetThreadVersion(uint64_t table_id); void Wait(std::vector<::std::future>* status_vec); void PullDense(bool force_update = false); + void CreatePinVar(); int GetThreadIdByScope(const Scope* scope); void SetThreadIdByScope(const Scope* scope, int tid); static std::shared_ptr GetInstance() { @@ -109,6 +126,12 @@ class PullDenseWorker { std::mutex mutex_for_mean_scale_; float total_batch_num_ = 0; std::unordered_map scope_to_thread_id_; + + #ifdef PADDLE_WITH_CUDA + std::vector copy_streams_; + std::vector places_; + std::vector thread_scopes_; + #endif }; // should incorporate different type of device @@ -141,6 +164,7 @@ class DeviceWorker { device_reader_->SetPlace(place); } virtual Scope* GetThreadScope() { return thread_scope_; } + virtual void GetXpuOpIndex() {} protected: Scope* root_scope_ = nullptr; @@ -301,6 +325,7 @@ enum HeterTaskState { PULL_SPARSE, OP_RUN, XPU, + OP_RUN_END, PUSH_GRAD, DONE }; @@ -312,17 +337,32 @@ class HeterTask { state_ = 
OP_RUN; } else if (state_ == OP_RUN) { - //state_ = XPU; + state_ = XPU; + //state_ = PUSH_GRAD; //state_ = PUSH_GRAD; - state_ = PUSH_GRAD; } else if (state_ == XPU) { + state_ = OP_RUN_END; + } + else if (state_ == OP_RUN_END) { state_ = PUSH_GRAD; } else if (state_ == PUSH_GRAD) { state_ = DONE; } } + void Reset() { + total_time = 0; + read_time = 0; + pack_time = 0; + pull_sparse_local_time = 0; + op_all_time = 0; + xpu_op_time = 0; + cpu_op_time = 0; + collect_label_time = 0; + fill_sparse_time = 0; + push_sparse_time = 0; + } void Show() { std::cout << "features size " << features_.size() << std::endl; for (size_t i = 0; i < features_.size(); ++i) { @@ -341,18 +381,30 @@ class HeterTask { std::map>> feature_values_; std::map>> feature_grads_; std::map> sparse_push_keys_; - double total_time; - double read_time; - double pack_time; - double pull_sparse_local_time; + double total_time{0}; + double read_time{0}; + double pack_time{0}; + double pull_sparse_local_time{0}; + double op_all_time{0}; + double xpu_op_time{0}; + double cpu_op_time{0}; + double collect_label_time{0}; + double fill_sparse_time{0}; + double push_sparse_time{0}; }; template class HeterObjectPool { public: + HeterObjectPool() {} + virtual ~HeterObjectPool() {}; std::shared_ptr Get() { std::lock_guard lock(mutex_); if (pool_.empty()) { + num_ += 1; + #ifdef PADDLE_WITH_CUDA + VLOG(0) << "pool construct size: " << num_; + #endif return std::make_shared(); } else { @@ -369,9 +421,13 @@ class HeterObjectPool { std::lock_guard lock(mutex_); return pool_.size(); } + std::shared_ptr& GetElement(int i) { + return pool_[i]; + } private: std::vector> pool_; std::mutex mutex_; + int num_{0}; }; @@ -535,15 +591,16 @@ class HeterCpuWorker : public HogwildWorker { virtual void SetNeedDump(bool need_dump_field); virtual void SetChannelWriter(ChannelObject* queue); virtual void SetWorkerNum(int num) { worker_num_ = num; } - virtual void CreateThreadParam(const ProgramDesc &main_program); virtual void 
Schedule(int taskid); virtual void JumpContext(std::shared_ptr task); virtual void CacheProgram(const ProgramDesc &main_program) { new(&program_) ProgramDesc(main_program); } + virtual void GetXpuOpIndex(); protected: std::shared_ptr fleet_ptr_; + std::shared_ptr heter_ptr_; std::shared_ptr pull_dense_worker_; void FillSparseValue(std::shared_ptr task, size_t table_id); void PushGradients(); @@ -555,7 +612,11 @@ class HeterCpuWorker : public HogwildWorker { void CopyDenseVars(); private: + //std::string recv_var; + int mpi_rank_; int worker_num_; + int xpu_begin_op_index_; + int xpu_end_op_index_; ProgramDesc program_; HeterObjectPool object_pool_; HeterList> run_queue_; diff --git a/paddle/fluid/framework/dist_multi_trainer.cc b/paddle/fluid/framework/dist_multi_trainer.cc index bced867561546c..a1e7d5b65df0b2 100644 --- a/paddle/fluid/framework/dist_multi_trainer.cc +++ b/paddle/fluid/framework/dist_multi_trainer.cc @@ -76,7 +76,7 @@ void DistMultiTrainer::RegisterHeterCallback() { auto fleet_ptr = FleetWrapper::GetInstance(); fleet_ptr->RegisterHeterCallback( [this](int worker, int taskid) { - workers_[worker]->Schedule(taskid); + //workers_[worker]->Schedule(taskid); } ); } @@ -156,7 +156,10 @@ void DistMultiTrainer::InitOtherEnv(const ProgramDesc &main_program) { InitDumpEnv(); } pull_dense_worker_->SetRootScope(root_scope_); - pull_dense_worker_->Start(); + //pull_dense_worker_->Start(); + for (int i = 0; i < thread_num_; ++i) { + workers_[i]->GetXpuOpIndex(); + } VLOG(3) << "init other env done."; } @@ -180,6 +183,10 @@ void DistMultiTrainer::Finalize() { for (auto &th : threads_) { th.join(); } + //if (mpi_rank_ == 0) { + // auto heter_ptr_ = HeterWrapper::GetInstance(); + // heter_ptr_->EndPass(root_scope_); + //} for (size_t i = 0; i < need_merge_var_names_.size(); i++) { Variable *root_var = root_scope_->FindVar(need_merge_var_names_[i]); if (root_var == nullptr) { @@ -214,7 +221,7 @@ void DistMultiTrainer::Finalize() { if (need_dump_field_) { 
FinalizeDumpEnv(); } - pull_dense_worker_->Stop(); + //pull_dense_worker_->Stop(); root_scope_->DropKids(); // flush local client push queue diff --git a/paddle/fluid/framework/downpour_worker.cc b/paddle/fluid/framework/downpour_worker.cc index b1a1b73a66e72d..4965e13b851a01 100644 --- a/paddle/fluid/framework/downpour_worker.cc +++ b/paddle/fluid/framework/downpour_worker.cc @@ -413,9 +413,9 @@ void DownpourWorker::CopyDenseTable() { if (copy_table_config_.dense_pull_after_copy()) { VLOG(3) << "dense pull after copy, table=" << dest_table; pull_dense_status.resize(0); - fleet_ptr_->PullDenseVarsAsync(*root_scope_, dest_table, - dense_value_names_[dest_table], - &pull_dense_status); + //fleet_ptr_->PullDenseVarsAsync(*root_scope_, dest_table, + // dense_value_names_[dest_table], + // &pull_dense_status); for (auto& t : pull_dense_status) { t.wait(); auto status = t.get(); diff --git a/paddle/fluid/framework/fleet/CMakeLists.txt b/paddle/fluid/framework/fleet/CMakeLists.txt index 6922f92c8f7a3a..118a4dde9a065a 100644 --- a/paddle/fluid/framework/fleet/CMakeLists.txt +++ b/paddle/fluid/framework/fleet/CMakeLists.txt @@ -18,5 +18,7 @@ if(WITH_GLOO) else() cc_library(gloo_wrapper SRCS gloo_wrapper.cc DEPS framework_proto variable_helper scope) endif(WITH_GLOO) + +cc_library(heter_wrapper SRCS heter_wrapper.cc) cc_test(test_fleet SRCS test_fleet.cc DEPS fleet_wrapper gloo_wrapper fs shell) diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index c70588b18e0f9c..a04380ee554695 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -203,15 +203,15 @@ void FleetWrapper::HeterPullSparseVars( auto status = pslib_ptr_->_worker_ptr->heter_pull_sparse(workerid, pull_result_ptr.data(), table_id, fea_keys.data(), fea_keys.size(), task->taskid_); pull_sparse_status.push_back(std::move(status)); - //for (auto& t : pull_sparse_status) { - // t.wait(); - // auto status 
= t.get(); - // if (status != 0) { - // LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]"; - // sleep(sleep_seconds_before_fail_exit_); - // exit(-1); - // } - //} + for (auto& t : pull_sparse_status) { + t.wait(); + auto status = t.get(); + if (status != 0) { + LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]"; + sleep(sleep_seconds_before_fail_exit_); + exit(-1); + } + } #endif } @@ -637,13 +637,18 @@ void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim, void FleetWrapper::PullDenseVarsAsync( const Scope& scope, const uint64_t tid, const std::vector& var_names, - std::vector<::std::future>* pull_dense_status) { + std::vector<::std::future>* pull_dense_status, + bool in_cpu) { #ifdef PADDLE_WITH_PSLIB auto& regions = _regions[tid]; regions.clear(); regions.resize(var_names.size()); for (auto i = 0u; i < var_names.size(); ++i) { - Variable* var = scope.FindVar(var_names[i]); + std::string varname = var_names[i]; + if (!in_cpu) { + varname = var_names[i] + "pin"; + } + Variable* var = scope.FindVar(varname); LoDTensor* tensor = var->GetMutable(); float* w = tensor->data(); paddle::ps::Region reg(w, tensor->numel()); @@ -701,6 +706,62 @@ void FleetWrapper::PushDenseVarsSync( Scope* scope, const uint64_t table_id, const std::vector& var_names) {} +#ifdef PADDLE_WITH_CUDA +void FleetWrapper::PushDenseVarsAsync( + const Scope& scope, const uint64_t table_id, + const std::vector& var_names, + std::vector<::std::future>* push_sparse_status, + float scale_datanorm, int batch_size, + const paddle::platform::Place& place, + cudaStream_t stream, + cudaEvent_t event) { +#ifdef PADDLE_WITH_PSLIB + std::vector regions; + for (auto& t : var_names) { + Variable* var = scope.FindVar(t); + LoDTensor* tensor = var->GetMutable(); + int count = tensor->numel(); + float* g_data = tensor->data(); + + Variable *pin_var = scope.FindVar(t + "pin"); + LoDTensor* pin_tensor = pin_var->GetMutable(); + float *pin_g = 
pin_tensor->mutable_data(tensor->dims(), platform::CUDAPinnedPlace()); + memory::Copy( + platform::CUDAPinnedPlace(), + pin_g, + boost::get(place), + g_data, sizeof(float) * count, stream); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, stream)); + cudaEventSynchronize(event); + + float* g = pin_g; + if (scale_datanorm >= 0) { + if (t.find(".batch_size@GRAD") != std::string::npos || + t.find(".batch_sum@GRAD") != std::string::npos) { + Eigen::Map mat(g, 1, count); + float scale = 1.0 / batch_size; + mat *= scale; + } else if (t.find(".batch_square_sum@GRAD") != std::string::npos) { + VLOG(3) << "epsilon: " << scale_datanorm; + for (int i = 0; i < count; ++i) { + g[i] = (g[i] - batch_size * scale_datanorm) / batch_size + + batch_size * scale_datanorm; + } + } + } + paddle::ps::Region reg(g, count); + regions.emplace_back(std::move(reg)); + } + + auto status = pslib_ptr_->_worker_ptr->push_dense(regions.data(), + regions.size(), table_id); + if (push_sparse_status) { + push_sparse_status->push_back(std::move(status)); + } +#endif +} + +#endif void FleetWrapper::PushDenseVarsAsync( const Scope& scope, const uint64_t table_id, const std::vector& var_names, diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index 2d8ef11b479317..14e194ce489e31 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -30,8 +30,9 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/framework/device_worker.h" #include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/heter_service.h" +#include "paddle/fluid/framework/device_worker.h" #include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN @@ -136,7 +137,8 @@ class FleetWrapper { void PullDenseVarsAsync( const Scope& scope, const uint64_t table_id, const std::vector& var_names, - std::vector<::std::future>* pull_dense_status); + std::vector<::std::future>* pull_dense_status, + bool in_cpu); // push dense parameters(not gradients) to server in sync mode void PushDenseParamSync(const Scope& scope, const uint64_t table_id, @@ -145,12 +147,21 @@ class FleetWrapper { // Push dense variables to server in async mode // Param: scope, table_id, var_names, scale_datanorm, batch_size // Param: push_sparse_status + #ifdef PADDLE_WITH_CUDA + void PushDenseVarsAsync( + const Scope& scope, const uint64_t table_id, + const std::vector& var_names, + std::vector<::std::future>* push_sparse_status, + float scale_datanorm, int batch_size, + const paddle::platform::Place& place, + cudaStream_t stream, + cudaEvent_t event); + #endif void PushDenseVarsAsync( const Scope& scope, const uint64_t table_id, const std::vector& var_names, std::vector<::std::future>* push_sparse_status, float scale_datanorm, int batch_size); - // push dense variables to server in sync mode void PushDenseVarsSync(Scope* scope, const uint64_t table_id, const std::vector& var_names); @@ -292,7 +303,7 @@ class FleetWrapper { #ifdef PADDLE_WITH_PSLIB static std::shared_ptr pslib_ptr_; #endif - + private: static std::shared_ptr s_instance_; #ifdef PADDLE_WITH_PSLIB @@ -301,7 +312,7 @@ class FleetWrapper { size_t GetAbsoluteSum(size_t start, size_t end, size_t level, const framework::LoD& lod); - + protected: static bool is_initialized_; bool 
scale_sparse_gradient_with_batch_size_; diff --git a/paddle/fluid/framework/fleet/gloo_wrapper.cc b/paddle/fluid/framework/fleet/gloo_wrapper.cc index c599432ff190ae..aea63245cfb4a1 100644 --- a/paddle/fluid/framework/fleet/gloo_wrapper.cc +++ b/paddle/fluid/framework/fleet/gloo_wrapper.cc @@ -35,49 +35,86 @@ void HdfsStore::set(const std::string& key, const std::vector& data) { } int err_no = 0; for (int i = 1; i <= retry_times_; ++i) { + err_no = 0; std::shared_ptr fp = paddle::framework::fs_open_write(tmp, &err_no, ""); - if (err_no != 0) { - VLOG(0) << "fs_open_write failed, retry times " << i << " err no " - << err_no; - fp.reset(); - sleep(wait_sleep_ms_ / 1000); - continue; - } size_t write_count = fwrite_unlocked(data.data(), 1, data.size(), fp.get()); if (write_count != data.size()) { VLOG(0) << "fwrite_unlocked failed, retry times " << i << " write_count " << write_count << " data.size() " << data.size(); - fp.reset(); - sleep(2); - continue; + err_no = -1; } fp.reset(); - break; + if (err_no != 0) { + VLOG(0) << "fs_open_write failed, retry times " << i << " err no " + << err_no; + sleep(wait_sleep_ms_ / 1000); + paddle::framework::fs_remove(tmp); + if (i == retry_times_) { + VLOG(0) << "fs_open_write failed, retry times reaches limit"; + //PADDLE_THROW(platform::errors::PreconditionNotMet( + // "fs_open_write failed, retry times reaches" + // " limit ", + // retry_times_)); + } + } else { + break; + } } paddle::framework::fs_mv(tmp, path); #endif } +#ifdef PADDLE_WITH_GLOO +int retry_do_func(std::function func, uint32_t max_try_time, + uint32_t retry_interval_ms) { + for (uint32_t i = 0; i < max_try_time; ++i) { + if (func() == 0) { + return 0; + } +#ifdef _LINUX + usleep(retry_interval_ms * 1000); +#endif + } + return -1; +} +#endif + std::vector HdfsStore::get(const std::string& key) { auto path = ObjectPath(key); std::vector result; #ifdef PADDLE_WITH_GLOO // block until key is set wait({key}); - bool is_exists = paddle::framework::fs_exists(path); + 
int ret = retry_do_func( + [&path]() { return paddle::framework::fs_exists(path) ? 0 : -1; }, 5, + wait_sleep_ms_); + bool is_exists = (ret == 0); PADDLE_ENFORCE_EQ(is_exists, true, paddle::platform::errors::NotFound( "HdfsStore::get, path not exists: " + path)); - int err_no = 0; - std::shared_ptr fp = paddle::framework::fs_open_read(path, &err_no, ""); - char buffer = '\0'; - size_t read_count = 0; - while (fread(&buffer, 1, 1, fp.get()) == 1) { - ++read_count; - result.push_back(buffer); - } - VLOG(3) << "HdfsStore::get read_count " << read_count; + + int read_status = retry_do_func( + [&path, &result]() { + result.clear(); + int err_no = 0; + { + std::shared_ptr fp = + paddle::framework::fs_open_read(path, &err_no, ""); + char buffer = '\0'; + size_t read_count = 0; + while (fread(&buffer, 1, 1, fp.get()) == 1) { + ++read_count; + result.push_back(buffer); + } + VLOG(3) << "HdfsStore::get read_count " << read_count; + } + return err_no; + }, + 5, wait_sleep_ms_); + PADDLE_ENFORCE_EQ(read_status, 0, + paddle::platform::errors::Fatal( + "HdfsStore::get, path read faied: " + path)); #endif return result; } diff --git a/paddle/fluid/framework/fleet/heter_wrapper.cc b/paddle/fluid/framework/fleet/heter_wrapper.cc new file mode 100644 index 00000000000000..9636201b5ca548 --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_wrapper.cc @@ -0,0 +1,298 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/fleet/heter_wrapper.h" +#include +#include +#include "paddle/fluid/framework/channel.h" +#include "paddle/fluid/framework/data_feed.h" +#include "paddle/fluid/framework/io/fs.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/platform/timer.h" + +namespace paddle { +namespace framework { + +std::shared_ptr HeterWrapper::s_instance_ = NULL; +bool HeterWrapper::is_initialized_ = false; + +void HeterWrapper::CreateClient2XpuConnection() { + brpc::ChannelOptions options; + options.protocol = "baidu_std"; + options.connection_type = "single"; + + options.timeout_ms = 2000000; + + xpu_channels_.resize(xpu_list_.size()); + for (size_t i = 0; i < xpu_list_.size(); ++i) { + VLOG(3) << "channel init: " << xpu_list_[i]; + xpu_channels_[i].reset(new brpc::Channel()); + if (xpu_channels_[i]->Init(xpu_list_[i].c_str(), "", &options) != 0) { + VLOG(0) << "server channel init fail"; + } + } +} + +void HeterWrapper::RegisterServiceHandler(int cmd, HeterServiceHandler func) { + service_.RegisterServiceHandler(cmd, func); +} + +void HeterWrapper::SetXpuList(const std::vector& xpu_list) { +#ifdef PADDLE_WITH_PSLIB + VLOG(3) << "Going to set xpu list"; + for (auto& x : xpu_list) { + 
xpu_list_.push_back(x); + VLOG(3) << "set xpu list: " << x << " size: " << xpu_list_.size(); + } +#endif +} + +void HeterWrapper::StartXpuService(const std::string& ip, uint32_t port) { + std::string ip_port = ip + ":" + std::to_string(port); + VLOG(3) << "xpu server starts at " << ip_port; + + server_.AddService(&service_, brpc::SERVER_DOESNT_OWN_SERVICE); + brpc::ServerOptions options; + + if (server_.Start(ip_port.c_str(), &options) != 0) { + VLOG(0) << "xpu server start fail"; + } +} + +//void HeterWrapper::SerializeToReq(const std::string& varname, Scope* scope, HeterRequest& request) { +// auto* req_var = request.mutable_vars(); + +void HeterWrapper::SerializeToReq(const std::string& varname, Scope* scope, VariableMessage* req_var) { + Variable* var = scope->FindVar(varname); + if (var == nullptr) { + return; + } + LoDTensor* tensor = var->GetMutable(); + req_var->set_varname(varname); + req_var->set_type(LOD_TENSOR); + req_var->set_data_type(static_cast(tensor->type())); + + for (auto& dim : framework::vectorize(tensor->dims())) { + req_var->add_dims(dim); + } + const framework::LoD lod = tensor->lod(); + if (lod.size() > 0) { + req_var->set_lod_level(lod.size()); + for (auto& each : lod) { + VariableMessage::LodData* lod_inner = req_var->add_lod(); + for (auto& d : each) { + lod_inner->add_lod_data(d); + } + } + } + + auto* req_data = req_var->mutable_data(); + req_data->clear(); + req_data->resize(tensor->numel() * SizeOfType(tensor->type())); + char* data_ptr = const_cast(req_data->data()); + + if (platform::is_cpu_place(tensor->place())) { + memcpy(data_ptr, tensor->data(), tensor->numel() * SizeOfType(tensor->type())); + } + #ifdef PADDLE_WITH_CUDA + else { + memory::Copy( + platform::CPUPlace(), + data_ptr, + boost::get(tensor->place()), + tensor->data(), + tensor->numel() * SizeOfType(tensor->type()), nullptr); + } + #endif +} + +//void HeterWrapper::DeSerializeToTensor(Scope* scope, const HeterRequest* request) { +void 
HeterWrapper::DeSerializeToTensor(Scope* scope, const VariableMessage& req_var, platform::Place place) { + //const VariableMessage& req_var = request->vars(); + auto* var = scope->FindVar(req_var.varname()); + auto* tensor = var->GetMutable(); + + std::vector vec_dim; + for (auto& x : req_var.dims()) { + vec_dim.push_back(x); + } + tensor->Resize(make_ddim(vec_dim)); + + LoD lod; + for (int i = 0; i < req_var.lod_level(); ++i) { + framework::Vector v; + for (int j = 0; j < req_var.lod(i).lod_data_size(); ++j) { + v.push_back(req_var.lod(i).lod_data(j)); + } + lod.push_back(v); + } + tensor->set_lod(lod); + + void* tensor_data = + tensor->mutable_data(place, ToVarType(req_var.data_type())); + + #ifdef PADDLE_WITH_CUDA + memory::Copy( + boost::get(place), + tensor_data, + platform::CPUPlace(), + req_var.data().data(), + tensor->numel() * SizeOfType(tensor->type()), nullptr); + #else + memcpy(tensor_data, req_var.data().data(), tensor->numel() * SizeOfType(tensor->type())); + #endif +} + +framework::proto::VarType::Type HeterWrapper::ToVarType( + VariableMessage::Type type) { + switch (type) { + case VariableMessage::FP32: + return framework::proto::VarType::FP32; // NOLINT + case VariableMessage::FP64: + return framework::proto::VarType::FP64; // NOLINT + case VariableMessage::INT32: + return framework::proto::VarType::INT32; // NOLINT + case VariableMessage::INT64: + return framework::proto::VarType::INT64; // NOLINT + case VariableMessage::BOOL: + return framework::proto::VarType::BOOL; // NOLINT + default: + PADDLE_THROW("Not support type %d", type); + } +} + +void HeterWrapper::StopXpuService(int num) { + HeterRequest request; + HeterResponse response; + brpc::Controller cntl; + request.set_cmd(2); + + //for (size_t i = 0; i < xpu_channels_.size(); ++i) { + HeterService_Stub stub(xpu_channels_[num].get()); + stub.service(&cntl, &request, &response, NULL); + if (cntl.Failed()) { + VLOG(0) << "call stop xpu service fail: " << cntl.ErrorText(); + } + else { + 
VLOG(3) << "call stop xpu service success"; + } + //} +} + +void HeterWrapper::EndPass(Scope* scope, int num) { + HeterRequest request; + HeterResponse response; + brpc::Controller cntl; + request.set_cmd(1); + + //for (size_t i = 0; i < xpu_channels_.size(); ++i) { + HeterService_Stub stub(xpu_channels_[num].get()); + stub.service(&cntl, &request, &response, NULL); + if (cntl.Failed()) { + VLOG(0) << "call end pass fail: " << cntl.ErrorText(); + } + else { + VLOG(3) << "call end pass success"; + for (int j = 0; j < response.vars_size(); ++j) { + DeSerializeToTensor(scope, response.vars(j), platform::CPUPlace()); + } + } + //} +} + +void HeterWrapper::CallRemoteXpu(std::shared_ptr task, HeterCpuWorker* worker, int mpi_rank) { + HeterRequest request; + request.set_cmd(0); + request.set_cur_batch(task->cur_batch_); + + OnHeterRpcDone *done = new OnHeterRpcDone([this, task, worker] (void* done) { + auto* closure = (OnHeterRpcDone*)done; + if (closure->cntl.Failed()) { + VLOG(0) << "call xpu fail: " << closure->cntl.ErrorText(); + } + else { + VLOG(3) << "call xpu success"; + } + //DeSerializeToTensor(task->scope_, closure->response.vars(), platform::CPUPlace()); + for (int i = 0; i < closure->response.vars_size(); ++i) { + DeSerializeToTensor(task->scope_, closure->response.vars(i), platform::CPUPlace()); + } + + worker->Schedule(task->taskid_); + }); + + std::vector varnames = {"click", "12345"}; + //varnames.push_back(send_var); + //if (send_var == "_generated_var_412") { + varnames.push_back("filter_by_instag_0.tmp_0"); + varnames.push_back("filter_by_instag_2.tmp_0"); + varnames.push_back("filter_by_instag_0.tmp_1"); + varnames.push_back("concat_1.tmp_0"); + //} + for (auto& varname : varnames) { + auto* req_var = request.add_vars(); + SerializeToReq(varname, task->scope_, req_var); + } + + int num = mpi_rank % xpu_channels_.size(); + HeterService_Stub stub(xpu_channels_[num].get()); + //stub.service(&cntl, &request, &response, 
brpc::NewCallback(&HeterWrapper::RpcCallBack, response, cntl, worker, task)); + stub.service(&done->cntl, &request, &done->response, done); + +} + +void HeterWrapper::CallRemoteXpuSync(std::shared_ptr task, HeterCpuWorker* worker) { + HeterRequest request; + HeterResponse response; + brpc::Controller cntl; + request.set_cmd(0); + request.set_cur_batch(task->cur_batch_); + + std::vector varnames = {"concat_1.tmp_0", "click", "12345"}; + for (auto& varname : varnames) { + auto* req_var = request.add_vars(); + SerializeToReq(varname, task->scope_, req_var); + } + + HeterService_Stub stub(xpu_channels_[0].get()); + stub.service(&cntl, &request, &response, NULL); + if (cntl.Failed()) { + VLOG(0) << "call xpu fail: " << cntl.ErrorText(); + } + else { + VLOG(3) << "call xpu success"; + for (int i = 0; i < response.vars_size(); ++i) { + DeSerializeToTensor(task->scope_, response.vars(i), platform::CPUPlace()); + } + } + +} + +} // end namespace framework +} // end namespace paddle diff --git a/paddle/fluid/framework/fleet/heter_wrapper.h b/paddle/fluid/framework/fleet/heter_wrapper.h new file mode 100644 index 00000000000000..1db3e18b4b8198 --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_wrapper.h @@ -0,0 +1,112 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/heter_service.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/device_worker.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN + +namespace paddle { +namespace framework { + +typedef std::function HeterRpcCallbackFunc; + +class OnHeterRpcDone: public google::protobuf::Closure { + public: + OnHeterRpcDone(HeterRpcCallbackFunc func) : handler_(func) {} + virtual ~OnHeterRpcDone() {} + void Run() { + std::unique_ptr self_guard(this); + handler_(this); + } + + HeterRpcCallbackFunc handler_; + HeterResponse response; + brpc::Controller cntl; +}; + +class HeterWrapper { + public: + virtual ~HeterWrapper() { + server_.Stop(1000); + server_.Join(); + } + + HeterWrapper() { + } + + static void HeterRpcCallBack(HeterResponse* response, brpc::Controller* cntl, HeterCpuWorker* worker, std::shared_ptr task); + + void CreateClient2XpuConnection(); + + void RegisterServiceHandler(int cmd, HeterServiceHandler func); + + void StartXpuService(const std::string& ip, uint32_t port); + + void CallRemoteXpu(std::shared_ptr task, HeterCpuWorker* worker, int mpi_rank); + + void CallRemoteXpuSync(std::shared_ptr task, HeterCpuWorker* worker); + + void StopXpuService(int num); + + void EndPass(Scope* scope, int num); + + void SerializeToReq(const std::string& varname, Scope* scope, VariableMessage* req_var); + + framework::proto::VarType::Type ToVarType(VariableMessage::Type type); + + void DeSerializeToTensor(Scope* scope, const VariableMessage& req_var, platform::Place place); + + // HeterWrapper singleton + static std::shared_ptr GetInstance() { + if (NULL == s_instance_) { + s_instance_.reset(new paddle::framework::HeterWrapper()); + } + return s_instance_; + } + + std::vector& GetXpuList() { + 
return xpu_list_; + } + + void SetXpuList(const std::vector& xpu_list); + + private: + static std::shared_ptr s_instance_; + + protected: + std::vector> xpu_channels_; + brpc::Server server_; + HeterXpuService service_; + + static bool is_initialized_; + DISABLE_COPY_AND_ASSIGN(HeterWrapper); + std::vector xpu_list_; +}; + +} // end namespace framework +} // end namespace paddle diff --git a/paddle/fluid/framework/heter_service.h b/paddle/fluid/framework/heter_service.h new file mode 100644 index 00000000000000..e0e2c7d27ec0d1 --- /dev/null +++ b/paddle/fluid/framework/heter_service.h @@ -0,0 +1,68 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include +#include +#include // NOLINT +#include +#include // NOLINT +#include +#include "brpc/channel.h" +#include "brpc/controller.h" +#include "brpc/server.h" +#include "paddle/fluid/framework/heter_service.pb.h" + +namespace paddle { +namespace framework { + +typedef std::function HeterServiceHandler; + +class HeterXpuService : public HeterService { +public: + HeterXpuService() {} + virtual ~HeterXpuService() {} + + void service(::google::protobuf::RpcController* controller, + const HeterRequest* request, HeterResponse* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard done_guard(done); + int ret = 0; + int cmd = request->cmd(); + auto itr = handler_map_.find(cmd); + if (itr == handler_map_.end()) { + + } + else { + ret = itr->second(request, response); + } + //response->set_err_code(0); + //response->set_err_msg(""); + if (ret != 0) { + //response->set_err_code(-1); + //response->set_err_msg("xpu service error"); + } + } + + void RegisterServiceHandler(int cmd, HeterServiceHandler func) { + VLOG(0) << "register heter service"; + handler_map_[cmd] = func; + } +private: + std::unordered_map handler_map_; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/heter_service.proto b/paddle/fluid/framework/heter_service.proto new file mode 100644 index 00000000000000..519490a679c856 --- /dev/null +++ b/paddle/fluid/framework/heter_service.proto @@ -0,0 +1,71 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ +syntax = "proto2"; +package paddle.framework; +option cc_generic_services = true; + +// It can be: LoDTensor, SelectedRows or NCCL_ID +enum VarType { + LOD_TENSOR = 0; + SELECTED_ROWS = 1; + NCCL_ID = 2; +} + +// VariableMessage is serialized paddle variable message. +// NOTICE(gongwb): don't modify this proto if you are not +// familiar with how we serialize in sendrecvop_utils.h +// and deserialize it in variable_response.h. +message VariableMessage { + enum Type { + // Pod Types + BOOL = 0; + INT16 = 1; + INT32 = 2; + INT64 = 3; + FP16 = 4; + FP32 = 5; + FP64 = 6; + } + + message LodData { repeated int64 lod_data = 1; } + optional string varname = 1; + // TODO(Yancey1989): reference framework::proto::VarDesc::VarType + optional VarType type = 2; + // bool persistable is not needed for sending. + // tensor info: + optional Type data_type = 3; + repeated int64 dims = 4; + + // lod details: + optional int64 lod_level = 5; + repeated LodData lod = 6; + // selected_rows height, aka. original dim0 + optional int64 slr_height = 7; + // tensor data + optional bytes data = 8; +} +message HeterRequest { + required int32 cmd = 1; + optional int32 cur_batch = 2; + repeated VariableMessage vars = 3; +}; + +message HeterResponse { + //optional VariableMessage vars = 1; + repeated VariableMessage vars = 1; +}; + +service HeterService { + rpc service(HeterRequest) returns (HeterResponse); +}; diff --git a/paddle/fluid/framework/hetercpu_worker.cc b/paddle/fluid/framework/hetercpu_worker.cc index fec2bccb38f611..5b373b6a75ab6a 100644 --- a/paddle/fluid/framework/hetercpu_worker.cc +++ b/paddle/fluid/framework/hetercpu_worker.cc @@ -15,6 +15,7 @@ limitations under the License.
*/ #include "paddle/fluid/framework/device_worker.h" #include "paddle/fluid/framework/device_worker_factory.h" #include "paddle/fluid/framework/fleet/fleet_wrapper.h" +#include "paddle/fluid/framework/fleet/heter_wrapper.h" #include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/string/string_helper.h" @@ -27,10 +28,10 @@ namespace paddle { namespace framework { void HeterTask::PackTask(Scope* thread_scope, int taskid, DataFeed* reader, int cur_batch, const ProgramDesc& program) { - total_time = 0; - read_time = 0; - pack_time = 0; - pull_sparse_local_time = 0; + //total_time = 0; + //read_time = 0; + //pack_time = 0; + //pull_sparse_local_time = 0; taskid_ = taskid; auto &block = program.Block(0); if (!scope_) { @@ -58,8 +59,76 @@ void HeterTask::PackTask(Scope* thread_scope, int taskid, DataFeed* reader, int } +void HeterCpuWorker::GetXpuOpIndex() { + xpu_begin_op_index_ = xpu_end_op_index_ = -1; + for (size_t i = 0; i < ops_.size(); ++i) { + //if (!first && ops_[i]->Type() == "mul") { + // first = 1; + // xpu_begin_op_index_ = i; + // auto& in_map = ops_[i]->Inputs(); + // + // + // auto it = in_map.find("X"); + // if (it != in_map.end()) { + // for (auto& x : it->second) { + // send_var_ = x; + // } + // } + // + //} + //if (ops_[i]->Type() == "mul_grad") { + // xpu_end_op_index_ = i; + // //auto& out_map = ops_[i]->Outputs(); + // //auto it = out_map.find("X@GRAD"); + // //if (it != out_map.end()) { + // // for (auto& x : it->second) { + // // recv_var = x; + // // } + // //} + //} + auto& out_map = ops_[i]->Outputs(); + + { + auto it = out_map.find("Out"); + if (it != out_map.end()) { + for (auto& x : it->second) { + if (x == "concat_1.tmp_0") { + xpu_begin_op_index_ = i + 1; + } + } + } + } + + { + auto it = out_map.find("X@GRAD"); + if (it != out_map.end()) { + for (auto& x : it->second) { + if (x == "concat_1.tmp_0@GRAD") { + xpu_end_op_index_ = i; + } + } + } + } + + { + auto it = out_map.find("Out"); + if (it != out_map.end()) { + for (auto& 
x : it->second) { + if (x == "concat_1.tmp_0@GRAD") { + xpu_end_op_index_ = i; + } + } + } + } + } + if (xpu_end_op_index_ == -1) { + xpu_end_op_index_ = ops_.size() - 1; + } + VLOG(0) << "xpu begin: " << xpu_begin_op_index_ << " xpu end: " << xpu_end_op_index_; +} + void HeterCpuWorker::Schedule(int taskid) { - //std::cout << "wxx schedule " << taskid << std::endl; + VLOG(3) << "schedule " << taskid; auto task = wait_queue_.TryGet(taskid); if (task) { run_queue_.Put(task->taskid_, task); @@ -67,7 +136,7 @@ void HeterCpuWorker::Schedule(int taskid) { } void HeterCpuWorker::JumpContext(std::shared_ptr task) { - //std::cout << "wxx jump context " << task->taskid_ << std::endl; + VLOG(3) << "jump context " << task->taskid_; if (!(wait_queue_.TryPut(task->taskid_, task))) { run_queue_.Put(task->taskid_, task); } @@ -75,6 +144,7 @@ void HeterCpuWorker::JumpContext(std::shared_ptr task) { void HeterCpuWorker::Initialize(const TrainerDesc& desc) { param_ = desc.downpour_param(); + mpi_rank_ = desc.mpi_rank(); for (int i = 0; i < param_.sparse_table_size(); ++i) { uint64_t table_id = static_cast(param_.sparse_table(i).table_id()); @@ -120,6 +190,7 @@ void HeterCpuWorker::Initialize(const TrainerDesc& desc) { need_to_push_dense_ = param_.push_dense(); fleet_ptr_ = FleetWrapper::GetInstance(); + heter_ptr_ = HeterWrapper::GetInstance(); fetch_config_ = desc.fetch_config(); use_cvm_ = desc.use_cvm(); // for sparse value accessor, embedding only @@ -544,58 +615,6 @@ void HeterCpuWorker::CopyDenseTable() { } } -void HeterCpuWorker::CreateThreadParam(const ProgramDesc& program) { - #ifdef PADDLE_WITH_CUDA - auto dev_id = boost::get(place_).device; - platform::CUDADeviceGuard guard(dev_id); - auto &block = program.Block(0); - for (auto& var : block.AllVars()) { - if (var->Persistable()) { - auto name = var->Name(); - Variable* root_var = root_scope_->FindVar(name); - LoDTensor* root_tensor = root_var->GetMutable(); - auto *ptr = thread_scope_->Var(name); - InitializeVariable(ptr, 
proto::VarType::LOD_TENSOR); - LoDTensor* thread_tensor = ptr->GetMutable(); - -#define MemcpyCallback(cpp_type, proto_type) \ - do { \ - if (root_tensor->type() == proto_type) { \ - MemCpy(thread_tensor, root_tensor, place_, copy_stream_); \ - } \ - } while (0) - _ForEachDataType_(MemcpyCallback); - - } - } - PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event_, copy_stream_)); - cudaEventSynchronize(event_); - #endif -} - -#ifdef PADDLE_WITH_CUDA -template -void HeterCpuWorker::MemCpy(LoDTensor *thread_tensor, LoDTensor *root_tensor, - const paddle::platform::Place& thread_place, - cudaStream_t stream) { - T* thread_ptr = thread_tensor->mutable_data(root_tensor->dims(), thread_place); - T* root_ptr = root_tensor->data(); - if (platform::is_cpu_place(root_tensor->place())) { - memory::Copy( - boost::get(thread_place), - thread_ptr, - platform::CPUPlace(), - root_ptr, sizeof(T) * root_tensor->numel(), stream); - } - else { - memory::Copy( - boost::get(thread_place), - thread_ptr, - boost::get(root_tensor->place()), - root_ptr, sizeof(T) * root_tensor->numel(), stream); - } -} -#endif void HeterCpuWorker::CopyDenseVars() { if (thread_id_ != 0) { @@ -660,13 +679,14 @@ void HeterCpuWorker::TrainFilesWithProfiler() { platform::Timer timeline; double total_time = 0.0; double read_time = 0.0; - //double pull_sparse_time = 0.0; - double collect_label_time = 0.0; - double fill_sparse_time = 0.0; - double push_sparse_time = 0.0; - double push_dense_time = 0.0; double pack_time = 0.0; double pull_sparse_local_time = 0.0; + double op_all_time = 0; + double xpu_op_time = 0; + double cpu_op_time = 0; + double collect_label_time = 0; + double fill_sparse_time = 0; + double push_sparse_time = 0; int batch_cnt = 0; int done_cnt = 0; @@ -695,6 +715,7 @@ void HeterCpuWorker::TrainFilesWithProfiler() { int taskid = batch_cnt * worker_num_ + thread_id_; timeline.Start(); task = object_pool_.Get(); + task->Reset(); task->PackTask(thread_scope_, taskid, device_reader_, cur_batch, 
program_); timeline.Pause(); task->read_time = tmp_read_time; @@ -721,17 +742,16 @@ void HeterCpuWorker::TrainFilesWithProfiler() { table.fea_dim(), sparse_value_names_[tid]); } task->Update(); - JumpContext(task); + //JumpContext(task); timeline.Pause(); task->pull_sparse_local_time += timeline.ElapsedSec(); task->total_time += timeline.ElapsedSec(); - break; } else if (task->state_ == OP_RUN) { - total_time += task->total_time; - read_time += task->read_time; - pack_time += task->pack_time; - pull_sparse_local_time += task->pull_sparse_local_time; + //total_time += task->total_time; + //read_time += task->read_time; + //pack_time += task->pack_time; + //pull_sparse_local_time += task->pull_sparse_local_time; for (int i = 0; i < param_.program_config(0).pull_sparse_table_id_size(); ++i) { uint64_t tid = static_cast( @@ -739,13 +759,13 @@ void HeterCpuWorker::TrainFilesWithProfiler() { timeline.Start(); CollectLabelInfo(task, i); timeline.Pause(); - collect_label_time += timeline.ElapsedSec(); - total_time += timeline.ElapsedSec(); + task->collect_label_time += timeline.ElapsedSec(); + task->total_time += timeline.ElapsedSec(); timeline.Start(); FillSparseValue(task, i); timeline.Pause(); - fill_sparse_time += timeline.ElapsedSec(); - total_time += timeline.ElapsedSec(); + task->fill_sparse_time += timeline.ElapsedSec(); + task->total_time += timeline.ElapsedSec(); auto nid_iter = std::find(sparse_value_names_[tid].begin(), sparse_value_names_[tid].end(), @@ -757,8 +777,45 @@ void HeterCpuWorker::TrainFilesWithProfiler() { VLOG(3) << "fill sparse value for all sparse table done."; // do computation here - int run_op_idx = 0; - for (auto& op : ops_) { + //int run_op_idx = 0; + timeline.Start(); + for (int i = 0; i < xpu_begin_op_index_; ++i) { + auto& op = ops_[i]; + bool need_skip = false; + for (auto t = 0u; t < skip_ops_.size(); ++t) { + if (op->Type().find(skip_ops_[t]) != std::string::npos) { + need_skip = true; + break; + } + } + if (!need_skip) { + 
//timeline.Start(); + op->Run(*(task->scope_), place_); + //timeline.Pause(); + //op_total_time[run_op_idx++] += timeline.ElapsedSec(); + //total_time += timeline.ElapsedSec(); + } + } + task->Update(); + timeline.Pause(); + task->cpu_op_time += timeline.ElapsedSec(); + task->total_time += timeline.ElapsedSec(); + } + else if (task->state_ == XPU) { + timeline.Start(); + VLOG(3) << "call remote xpu taskid = " << task->taskid_; + heter_ptr_->CallRemoteXpu(task, this, mpi_rank_); + task->Update(); + JumpContext(task); + timeline.Pause(); + task->xpu_op_time += timeline.ElapsedSec(); + task->total_time += timeline.ElapsedSec(); + break; + } + else if (task->state_ == OP_RUN_END) { + timeline.Start(); + for (size_t i = xpu_end_op_index_ + 1; i < ops_.size(); ++i) { + auto& op = ops_[i]; bool need_skip = false; for (auto t = 0u; t < skip_ops_.size(); ++t) { if (op->Type().find(skip_ops_[t]) != std::string::npos) { @@ -767,11 +824,7 @@ void HeterCpuWorker::TrainFilesWithProfiler() { } } if (!need_skip) { - timeline.Start(); op->Run(*(task->scope_), place_); - timeline.Pause(); - op_total_time[run_op_idx++] += timeline.ElapsedSec(); - total_time += timeline.ElapsedSec(); } } // check inf and nan @@ -790,6 +843,9 @@ void HeterCpuWorker::TrainFilesWithProfiler() { "Tensor %s contains NAN", var_name); } task->Update(); + timeline.Pause(); + task->cpu_op_time += timeline.ElapsedSec(); + task->total_time += timeline.ElapsedSec(); } else if (task->state_ == PUSH_GRAD) { if (need_to_push_sparse_) { @@ -812,40 +868,8 @@ void HeterCpuWorker::TrainFilesWithProfiler() { &push_sparse_status_, use_cvm_, dump_slot_, no_cvm_); timeline.Pause(); - push_sparse_time += timeline.ElapsedSec(); - total_time += timeline.ElapsedSec(); - } - } - if (need_to_push_dense_) { - timeline.Start(); - for (int i = 0; i < param_.program_config(0).push_dense_table_id_size(); - ++i) { - uint64_t tid = static_cast( - param_.program_config(0).push_dense_table_id(i)); - fleet_ptr_->PushDenseVarsAsync( - 
*(task->scope_), tid, dense_grad_names_[tid], &push_sparse_status_, - scale_datanorm_, task->cur_batch_); - } - timeline.Pause(); - push_dense_time += timeline.ElapsedSec(); - total_time += timeline.ElapsedSec(); - VLOG(3) << "push dense gradient done."; - - // the following code should be more precise and clean - // TODO(guru4elephant) - int32_t tmp_push_dense_wait_times = -1; - static uint32_t push_dense_wait_times = - static_cast(tmp_push_dense_wait_times); - - if (push_dense_status_.size() >= push_dense_wait_times) { - for (auto& t : push_dense_status_) { - t.wait(); - } - push_dense_status_.resize(0); - } - - if (tmp_push_dense_wait_times == -1) { - push_dense_status_.resize(0); + task->push_sparse_time += timeline.ElapsedSec(); + task->total_time += timeline.ElapsedSec(); } } @@ -866,14 +890,6 @@ void HeterCpuWorker::TrainFilesWithProfiler() { } } - if (need_to_push_dense_) { - for (int i = 0; i < param_.program_config(0).push_dense_table_id_size(); - ++i) { - uint64_t tid = static_cast( - param_.program_config(0).push_dense_table_id(i)); - pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid); - } - } //thread_scope_->DropKids(); task->Update(); @@ -883,26 +899,38 @@ void HeterCpuWorker::TrainFilesWithProfiler() { ++done_cnt; total_inst += task->cur_batch_; object_pool_.Push(task); + + total_time += task->total_time; + read_time += task->read_time; + pack_time += task->pack_time; + pull_sparse_local_time += task->pull_sparse_local_time; + op_all_time += task->op_all_time; + xpu_op_time += task->xpu_op_time; + cpu_op_time += task->cpu_op_time; + collect_label_time += task->collect_label_time; + fill_sparse_time += task->fill_sparse_time; + push_sparse_time += task->push_sparse_time; //++batch_cnt; if (thread_id_ == 0) { // should be configured here if (done_cnt > 0 && done_cnt % 100 == 0) { - double op_sum_time = 0; - std::unordered_map op_to_time; - for (size_t i = 0; i < op_total_time.size(); ++i) { - fprintf(stderr, "op_name:[%zu][%s], 
op_mean_time:[%fs]\n", i, - op_name[i].c_str(), op_total_time[i] / done_cnt); - if (op_to_time.find(op_name[i]) == op_to_time.end()) { - op_to_time[op_name[i]] = 0.0; - } - op_to_time[op_name[i]] += op_total_time[i]; - op_sum_time += op_total_time[i]; - } - for (auto& i : op_to_time) { - fprintf(stderr, "op [%s] run total time: [%f]ms\n", i.first.c_str(), - i.second / done_cnt); - } - fprintf(stderr, "op run total time: %fs\n", op_sum_time / done_cnt); + //double op_sum_time = 0; + //std::unordered_map op_to_time; + //for (size_t i = 0; i < op_total_time.size(); ++i) { + // fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i, + // op_name[i].c_str(), op_total_time[i] / done_cnt); + // if (op_to_time.find(op_name[i]) == op_to_time.end()) { + // op_to_time[op_name[i]] = 0.0; + // } + // op_to_time[op_name[i]] += op_total_time[i]; + // op_sum_time += op_total_time[i]; + //} + //for (auto& i : op_to_time) { + // fprintf(stderr, "op [%s] run total time: [%f]ms\n", i.first.c_str(), + // i.second / done_cnt); + //} + fprintf(stderr, "cpu op run total time: %fs\n", cpu_op_time / done_cnt); + fprintf(stderr, "xpu op run total time: %fs\n", xpu_op_time / done_cnt); fprintf(stderr, "pack task time: %fs\n", pack_time / done_cnt); fprintf(stderr, "train total time: %fs\n", total_time / done_cnt); fprintf(stderr, "pull sparse local time: %fs\n", @@ -911,12 +939,12 @@ void HeterCpuWorker::TrainFilesWithProfiler() { fill_sparse_time / done_cnt); fprintf(stderr, "push sparse time: %fs\n", push_sparse_time / done_cnt); - fprintf(stderr, "push dense time: %fs\n", push_dense_time / done_cnt); fprintf(stderr, "collect label time: %fs\n", collect_label_time / done_cnt); fprintf(stderr, "mean read time: %fs\n", read_time / done_cnt); fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100); - fprintf(stderr, "op run percent: %f\n", op_sum_time / total_time * 100); + fprintf(stderr, "cpu op run percent: %f\n", cpu_op_time / total_time * 100); + fprintf(stderr, "xpu op 
run percent: %f\n", xpu_op_time / total_time * 100); fprintf(stderr, "pack task percent: %f\n", pack_time / total_time * 100); fprintf(stderr, "pull sparse local time percent: %f\n", pull_sparse_local_time / total_time * 100); @@ -926,8 +954,6 @@ void HeterCpuWorker::TrainFilesWithProfiler() { fill_sparse_time / total_time * 100); fprintf(stderr, "push sparse time percent: %f\n", push_sparse_time / total_time * 100); - fprintf(stderr, "push dense time percent: %f\n", - push_dense_time / total_time * 100); fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time); } } @@ -950,7 +976,7 @@ void HeterCpuWorker::TrainFiles() { int done_cnt = 0; int cur_batch; wait_queue_.SetCap(3); - //while ((cur_batch = device_reader_->Next()) > 0) { + need_to_push_dense_ = false; while (1) { //if (copy_table_config_.need_copy()) { // if (copy_table_config_.sparse_copy_by_feasign()) { @@ -968,22 +994,12 @@ void HeterCpuWorker::TrainFiles() { //} std::shared_ptr task; - //std::cout << "wait_queue size:" << wait_queue_.Size() << " run_queue size:" << run_queue_.Size() << std::endl; - //std::cout << "object pool size: " << object_pool_.Size() << std::endl; - // while (wait_queue_.Size() > 10) { - // std::cout << "sleep 10ms" << std::endl; - // usleep(10000); - // } task = run_queue_.Get(); - //std::cout << "wxx begin " << std::endl; if (!task) { - //std::cout << "wxx new pack " << std::endl; cur_batch = device_reader_->Next(); - //std::cout << "wxx " << cur_batch << " " << wait_queue_.Empty() << std::endl; if (cur_batch <= 0) { if (batch_cnt == done_cnt) { - //std::cout << "wxx pass done " << std::endl; break; } else { @@ -992,15 +1008,14 @@ void HeterCpuWorker::TrainFiles() { } batch_cnt += 1; int taskid = batch_cnt * worker_num_ + thread_id_; - //std::cout << "taskid " << taskid << " " << batch_cnt << " " << worker_num_ << " " << thread_id_ << std::endl; task = object_pool_.Get(); + task->Reset(); task->PackTask(thread_scope_, taskid, device_reader_, cur_batch, program_); } - 
//task->Show(); for (;;) { // pull sparse here if (task->state_ == PULL_SPARSE) { - //std::cout << "wxx pull sparse taskid = " << task->taskid_ << std::endl; + VLOG(3) << "pull sparse taskid = " << task->taskid_; for (int i = 0; i < param_.program_config(0).pull_sparse_table_id_size(); ++i) { uint64_t tid = static_cast( @@ -1017,11 +1032,11 @@ void HeterCpuWorker::TrainFiles() { table.fea_dim(), sparse_value_names_[tid]); } task->Update(); - JumpContext(task); - break; + //JumpContext(task); + //break; } else if (task->state_ == OP_RUN) { - //std::cout << "wxx oprun taskid = " << task->taskid_ << std::endl; + VLOG(3) << "oprun taskid = " << task->taskid_; for (int i = 0; i < param_.program_config(0).pull_sparse_table_id_size(); ++i) { uint64_t tid = static_cast( @@ -1038,7 +1053,32 @@ void HeterCpuWorker::TrainFiles() { VLOG(3) << "fill sparse value for all sparse table done."; // do computation here - for (auto& op : ops_) { + for (int i = 0; i < xpu_begin_op_index_; ++i) { + auto& op = ops_[i]; + bool need_skip = false; + for (auto t = 0u; t < skip_ops_.size(); ++t) { + if (op->Type().find(skip_ops_[t]) != std::string::npos) { + need_skip = true; + break; + } + } + if (!need_skip) { + VLOG(3) << "run op: " << op->Type(); + op->Run(*(task->scope_), place_); + } + } + task->Update(); + } + else if (task->state_ == XPU) { + VLOG(3) << "call remote xpu taskid = " << task->taskid_; + heter_ptr_->CallRemoteXpu(task, this, mpi_rank_); + task->Update(); + JumpContext(task); + break; + } + else if (task->state_ == OP_RUN_END) { + for (size_t i = xpu_end_op_index_ + 1; i < ops_.size(); ++i) { + auto& op = ops_[i]; bool need_skip = false; for (auto t = 0u; t < skip_ops_.size(); ++t) { if (op->Type().find(skip_ops_[t]) != std::string::npos) { @@ -1068,7 +1108,7 @@ void HeterCpuWorker::TrainFiles() { task->Update(); } else if (task->state_ == PUSH_GRAD) { - //std::cout << "wxx push grad taskid = " << task->taskid_ << std::endl; + VLOG(3) << "push grad taskid = " << 
task->taskid_; if (need_to_push_sparse_) { // push gradients here for (int i = 0; i < param_.program_config(0).push_sparse_table_id_size(); @@ -1089,34 +1129,6 @@ void HeterCpuWorker::TrainFiles() { dump_slot_, no_cvm_); } } - if (need_to_push_dense_) { - for (int i = 0; i < param_.program_config(0).push_dense_table_id_size(); - ++i) { - uint64_t tid = static_cast( - param_.program_config(0).push_dense_table_id(i)); - fleet_ptr_->PushDenseVarsAsync( - *(task->scope_), tid, dense_grad_names_[tid], &push_sparse_status_, - scale_datanorm_, task->cur_batch_); - } - VLOG(3) << "push dense gradient done."; - - // the following code should be more precise and clean - // TODO(guru4elephant) - int32_t tmp_push_dense_wait_times = -1; - static uint32_t push_dense_wait_times = - static_cast(tmp_push_dense_wait_times); - - if (push_dense_status_.size() >= push_dense_wait_times) { - for (auto& t : push_dense_status_) { - t.wait(); - } - push_dense_status_.resize(0); - } - - if (tmp_push_dense_wait_times == -1) { - push_dense_status_.resize(0); - } - } if (need_to_push_sparse_) { VLOG(3) << "push sparse gradient done."; @@ -1135,14 +1147,6 @@ void HeterCpuWorker::TrainFiles() { } } - if (need_to_push_dense_) { - for (int i = 0; i < param_.program_config(0).push_dense_table_id_size(); - ++i) { - uint64_t tid = static_cast( - param_.program_config(0).push_dense_table_id(i)); - pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid); - } - } //if (need_dump_field_) { // size_t batch_size = device_reader_->GetCurBatchSize(); // std::vector ars(batch_size); @@ -1189,7 +1193,7 @@ void HeterCpuWorker::TrainFiles() { task->Update(); } else if (task->state_ == DONE) { - //std::cout << "wxx done taskid = " << task->taskid_ << std::endl; + VLOG(3) << "done taskid = " << task->taskid_; object_pool_.Push(task); PrintFetchVars(); ++done_cnt; diff --git a/paddle/fluid/framework/heterxpu_trainer.cc b/paddle/fluid/framework/heterxpu_trainer.cc new file mode 100644 index 
00000000000000..63142b5b9ab8fd --- /dev/null +++ b/paddle/fluid/framework/heterxpu_trainer.cc @@ -0,0 +1,482 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include +#include "io/fs.h" +#include "paddle/fluid/framework/data_feed_factory.h" +#include "paddle/fluid/framework/data_set.h" +#include "paddle/fluid/framework/device_worker_factory.h" +#include "paddle/fluid/framework/fleet/fleet_wrapper.h" +#include "paddle/fluid/framework/trainer.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cuda_device_guard.h" + +namespace paddle { +namespace framework { + +void HeterXpuTrainer::Initialize(const TrainerDesc &trainer_desc, + Dataset *dataset) { + srand((unsigned)time(NULL)); + param_ = trainer_desc.downpour_param(); + for (int i = 0; i < param_.dense_table_size(); ++i) { + uint64_t table_id = static_cast(param_.dense_table(i).table_id()); + auto table = param_.dense_table(i); + dense_grad_names_[table_id].resize(table.dense_grad_name_size()); + for (int j = 0; j < table.dense_grad_name_size(); ++j) { + dense_grad_names_[table_id][j] = table.dense_grad_name(j); + } + } + scale_datanorm_ = trainer_desc.scale_datanorm(); + int place_num = trainer_desc.worker_places_size(); + for (int i = 0; i < place_num; ++i) { + int num = trainer_desc.worker_places(i); + platform::CUDAPlace place = platform::CUDAPlace(num); + platform::CUDADeviceGuard guard(place.device); + cudaStream_t stream; 
+ PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream)); + copy_streams_.push_back(stream); + places_.push_back(place); + cudaEvent_t event; + PADDLE_ENFORCE(cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); + events_.push_back(event); + } + + //thread_num_ = trainer_desc.thread_num(); + //SetDataset(dataset); + + //dump_fields_path_ = trainer_desc.dump_fields_path(); + //dump_converter_ = trainer_desc.dump_converter(); + //need_dump_field_ = false; + //if (trainer_desc.dump_fields_size() != 0 && dump_fields_path_ != "") { + // need_dump_field_ = true; + //} + //if (need_dump_field_) { + // auto &file_list = dataset->GetFileList(); + // if (file_list.size() == 0) { + // need_dump_field_ = false; + // } + //} + //mpi_rank_ = trainer_desc.mpi_rank(); + //mpi_size_ = trainer_desc.mpi_size(); + //dump_file_num_ = trainer_desc.dump_file_num(); + //const std::vector readers = + // dataset->GetReaders(); + //thread_num_ = readers.size(); + for (int i = 0; i < trainer_desc.downpour_param().stat_var_names_size(); + i++) { + need_merge_var_names_.push_back( + trainer_desc.downpour_param().stat_var_names(i)); + } + running_ = true; + VLOG(3) << "going to initialize pull dense worker"; + pull_dense_worker_ = PullDenseWorker::GetInstance(); + pull_dense_worker_->Initialize(trainer_desc); + VLOG(3) << "initialize pull dense worker"; + SetDebug(trainer_desc.debug()); + + fleet_ptr_ = FleetWrapper::GetInstance(); + heter_ptr_ = HeterWrapper::GetInstance(); + RegisterServiceHandler(); + //for (int i = 0; i < trainer_desc.worker_places_size(); ++i) { + // int num = trainer_desc.worker_places(i); + // platform::CUDAPlace place = platform::CUDAPlace(num); + // platform::CUDADeviceGuard guard(place.device); + // cudaStream_t stream; + // PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream)); + // copy_streams_.push_back(stream); + // places_.push_back(place); + //} + +} + +void HeterXpuTrainer::CreateThreadParam(const ProgramDesc& program, int num) { + auto place = 
places_[num]; + Scope* scope = place_scopes_[num]; + auto stream = copy_streams_[num]; + auto event = events_[num]; + + auto dev_id = boost::get(place).device; + platform::CUDADeviceGuard guard(dev_id); + auto &block = program.Block(0); + for (auto& var : block.AllVars()) { + if (var->Persistable()) { + auto name = var->Name(); + Variable* root_var = root_scope_->FindVar(name); + LoDTensor* root_tensor = root_var->GetMutable(); + auto *ptr = scope->Var(name); + InitializeVariable(ptr, proto::VarType::LOD_TENSOR); + LoDTensor* thread_tensor = ptr->GetMutable(); + +#define HeterMemcpyFunc(cpp_type, proto_type) \ + do { \ + if (root_tensor->type() == proto_type) { \ + HeterMemCpy(thread_tensor, root_tensor, place, stream); \ + } \ + } while (0) + _ForEachDataType_(HeterMemcpyFunc); + + } + } + PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, stream)); + cudaEventSynchronize(event); +} + +template +void HeterXpuTrainer::HeterMemCpy(LoDTensor *thread_tensor, LoDTensor *root_tensor, + const paddle::platform::Place& thread_place, + cudaStream_t stream) { + T* thread_ptr = thread_tensor->mutable_data(root_tensor->dims(), thread_place); + T* root_ptr = root_tensor->data(); + if (platform::is_cpu_place(root_tensor->place())) { + memory::Copy( + boost::get(thread_place), + thread_ptr, + platform::CPUPlace(), + root_ptr, sizeof(T) * root_tensor->numel(), stream); + } + else { + memory::Copy( + boost::get(thread_place), + thread_ptr, + boost::get(root_tensor->place()), + root_ptr, sizeof(T) * root_tensor->numel(), stream); + } +} + +void HeterXpuTrainer::DumpWork(int tid) { +} + +void HeterXpuTrainer::InitTrainerEnv(const ProgramDesc &main_program, + const platform::Place &place) { + CacheProgram(main_program); + place_ = place; + auto& profiler = paddle::ps::CostProfiler::instance(); + profiler.register_profiler("xpu_service_run_task"); +} + +void HeterXpuTrainer::InitOtherEnv(const ProgramDesc &main_program) { + auto &block = main_program.Block(0); + + 
pull_dense_worker_->SetRootScope(root_scope_); + pull_dense_worker_->CreatePinVar(); + + for (size_t i = 0; i < places_.size(); ++i) { + Scope* scope = &(root_scope_->NewScope()); + //for (auto &var : block.AllVars()) { + // if (var->Persistable()) { + // auto *ptr = scope->Var(var->Name()); + // InitializeVariable(ptr, var->GetType()); + // } + //} + place_scopes_.push_back(scope); + CreateThreadParam(main_program, i); + pull_dense_worker_->AddThreadScope(scope); + pull_dense_worker_->AddPlace(places_[i]); + pull_dense_worker_->AddStream(copy_streams_[i]); + } + + pull_dense_worker_->Start(); + for (auto& stream : copy_streams_) { + cudaStreamSynchronize(stream); + } + op_names_.clear(); + for (auto &op_desc : block.AllOps()) { + std::unique_ptr local_op = OpRegistry::CreateOp(*op_desc); + op_names_.push_back(op_desc->Type()); + OperatorBase *local_op_ptr = local_op.release(); + ops_.push_back(local_op_ptr); + continue; + } + + xpu_begin_op_index_ = xpu_end_op_index_ = -1; + for (size_t i = 0; i < ops_.size(); ++i) { + //if (!first && ops_[i]->Type() == "mul") { + // first = 1; + // xpu_begin_op_index_ = i; + // auto& in_map = ops_[i]->Inputs(); + // + // + // auto it = in_map.find("X"); + // if (it != in_map.end()) { + // for (auto& x : it->second) { + // send_var = x; + // } + // } + //} + //if (ops_[i]->Type() == "mul_grad") { + // xpu_end_op_index_ = i; + // auto& out_map = ops_[i]->Outputs(); + // auto it = out_map.find("X@GRAD"); + // if (it != out_map.end()) { + // for (auto& x : it->second) { + // recv_var_ = x; + // } + // } + //} + auto& out_map = ops_[i]->Outputs(); + + { + auto it = out_map.find("Out"); + if (it != out_map.end()) { + for (auto& x : it->second) { + if (x == "concat_1.tmp_0") { + xpu_begin_op_index_ = i + 1; + } + } + } + } + + { + auto it = out_map.find("X@GRAD"); + if (it != out_map.end()) { + for (auto& x : it->second) { + if (x == "concat_1.tmp_0@GRAD") { + xpu_end_op_index_ = i; + } + } + } + } + + { + auto it = out_map.find("Out"); 
+ if (it != out_map.end()) { + for (auto& x : it->second) { + if (x == "concat_1.tmp_0@GRAD") { + xpu_end_op_index_ = i; + } + } + } + } + } + + if (xpu_end_op_index_ == -1) { + xpu_end_op_index_ = ops_.size() - 1; + } + + VLOG(0) << "xpu begin: " << xpu_begin_op_index_ << " xpu end: " << xpu_end_op_index_; + VLOG(3) << "init other env done."; +} + +void HeterXpuTrainer::Run() { + //for (int thidx = 0; thidx < thread_num_; ++thidx) { + // if (!debug_) { + // threads_.push_back( + // std::thread(&DeviceWorker::TrainFiles, workers_[thidx].get())); + // } else { + // threads_.push_back(std::thread(&DeviceWorker::TrainFilesWithProfiler, + // workers_[thidx].get())); + // } + //} +} + +int HeterXpuTrainer::EndPass(const HeterRequest* request, HeterResponse* response) { + //int scope_num = object_pool_.Size(); + for (size_t i = 0; i < need_merge_var_names_.size(); i++) { + Variable *root_var = root_scope_->FindVar(need_merge_var_names_[i]); + if (root_var == nullptr) { + continue; + } + LoDTensor *root_tensor = root_var->GetMutable(); + + for (size_t j = 0; j < place_scopes_.size(); j++) { + Scope *cur_thread_scope = place_scopes_[j]; + Variable *thread_var = + cur_thread_scope->FindVar(need_merge_var_names_[i]); + if (thread_var == nullptr) { + continue; + } + LoDTensor *thread_tensor = thread_var->GetMutable(); +// if (root_tensor->numel() != thread_tensor->numel()) { +// continue; +// } +#define MergeCallback(cpp_type, proto_type) \ + do { \ + if (root_tensor->type() == proto_type) { \ + if (thread_tensor->type() != proto_type) { \ + VLOG(0) << "Error: thread id=" << j << ", need_merge_var_names_[" << i \ + << "] " << need_merge_var_names_[i] \ + << ", root tensor type=" << root_tensor->type() \ + << ", thread tensor type=" << thread_tensor->type(); \ + exit(-1); \ + } \ + MergeToRootScope(root_tensor, thread_tensor); \ + } \ + } while (0) + _ForEachDataType_(MergeCallback); + + if (platform::is_gpu_place(thread_tensor->place())) { + auto dev_id = 
boost::get(thread_tensor->place()).device; + platform::CUDADeviceGuard guard(dev_id); + cudaMemset(thread_tensor->data(), 0, thread_tensor->numel() * SizeOfType(thread_tensor->type())); + } + else { + memset(thread_tensor->data(), 0, thread_tensor->numel() * SizeOfType(thread_tensor->type())); + } + } + + auto* merge_var = response->add_vars(); + heter_ptr_->SerializeToReq(need_merge_var_names_[i], root_scope_, merge_var); + if (platform::is_gpu_place(root_tensor->place())) { + auto dev_id = boost::get(root_tensor->place()).device; + platform::CUDADeviceGuard guard(dev_id); + cudaMemset(root_tensor->data(), 0, root_tensor->numel() * SizeOfType(root_tensor->type())); + } + else { + memset(root_tensor->data(), 0, root_tensor->numel() * SizeOfType(root_tensor->type())); + } + } + return 0; +} + + + +template +void HeterXpuTrainer::MergeToRootScope(LoDTensor *root_tensor, + LoDTensor *tensor) { + LoDTensor tmp_root; + TensorCopy(*root_tensor, platform::CPUPlace(), &tmp_root); + T *tmp_root_data = tmp_root.data(); + + LoDTensor tmp_tensor; + TensorCopy(*tensor, platform::CPUPlace(), &tmp_tensor); + T *data = tmp_tensor.data(); + for (int i = 0; i < tmp_tensor.numel(); i++) { + tmp_root_data[i] += data[i]; + } + TensorCopy(tmp_root, root_tensor->place(), root_tensor); +} + +int HeterXpuTrainer::StopService(const HeterRequest* request, HeterResponse* response) { + std::unique_lock lock(mutex_); + running_ = false; + cond_.notify_one(); + return 0; +} + +int HeterXpuTrainer::RunTask(const HeterRequest* request, HeterResponse* response) { + auto timer = std::make_shared("xpu_service_run_task"); + std::shared_ptr context = object_pool_.Get(); + + if (!context->scope_) { + int num = rand() % places_.size(); + context->place_num_ = num; + auto place = places_[num]; + context->scope_ = &(place_scopes_[num]->NewScope()); + auto &block = program_.Block(0); + for (auto &var : block.AllVars()) { + if (!var->Persistable()) { + auto *ptr = context->scope_->Var(var->Name()); + 
InitializeVariable(ptr, var->GetType()); + } + } + for (auto& v : dense_grad_names_) { + for (auto& name : v.second) { + auto *ptr = context->scope_->Var(name + "pin"); + InitializeVariable(ptr, proto::VarType::LOD_TENSOR); + } + } + for (auto &op_desc : block.AllOps()) { + std::unique_ptr local_op = OpRegistry::CreateOp(*op_desc); + OperatorBase *local_op_ptr = local_op.release(); + (context->ops_).push_back(local_op_ptr); + } + + auto dev_id = boost::get(place).device; + platform::CUDADeviceGuard guard(dev_id); + PADDLE_ENFORCE(cudaEventCreateWithFlags(&context->event_, cudaEventDisableTiming)); + } + + context->Reset(); + auto place = places_[context->place_num_]; + for (int i = 0; i < request->vars_size(); ++i) { + heter_ptr_->DeSerializeToTensor(context->scope_, request->vars(i), place); + } + + for (int i = xpu_begin_op_index_; i <= xpu_end_op_index_; ++i) { + auto& op = (context->ops_)[i]; + op->Run(*(context->scope_), place); + } + auto* dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(context->event_, dev_ctx->stream())); + //cudaEventSynchronize(context->event_); + while (cudaEventQuery(context->event_) != cudaSuccess) { + VLOG(3) << "wait for kernel"; + bthread_yield(); + } + + std::string varname = "concat_1.tmp_0@GRAD"; + + auto* res_var = response->add_vars(); + heter_ptr_->SerializeToReq(varname, context->scope_, res_var); + + for (int i = 0; i < param_.program_config(0).push_dense_table_id_size(); + ++i) { + uint64_t tid = static_cast( + param_.program_config(0).push_dense_table_id(i)); + fleet_ptr_->PushDenseVarsAsync( + *(context->scope_), tid, dense_grad_names_[tid], &(context->push_dense_status_), + scale_datanorm_, request->cur_batch(), places_[context->place_num_], + copy_streams_[context->place_num_], context->event_); + } + + for (int i = 0; i < param_.program_config(0).push_dense_table_id_size(); + ++i) { + uint64_t tid = static_cast( + 
param_.program_config(0).push_dense_table_id(i)); + pull_dense_worker_->IncreaseThreadVersion(0, tid); + } + VLOG(3) << "push dense gradient done."; + context->scope_->DropKids(); + object_pool_.Push(context); + VLOG(0) << "pool size " << object_pool_.Size(); + return 0; +} + +void HeterXpuTrainer::RegisterServiceHandler() { + heter_ptr_->RegisterServiceHandler(0, + [this](const HeterRequest* request, HeterResponse* response) -> int { + return this->RunTask(request, response); + }); + heter_ptr_->RegisterServiceHandler(1, + [this](const HeterRequest* request, HeterResponse* response) -> int { + return this->EndPass(request, response); + }); + heter_ptr_->RegisterServiceHandler(2, + [this](const HeterRequest* request, HeterResponse* response) -> int { + return this->StopService(request, response); + }); +} + +Scope* HeterXpuTrainer::GetWorkerScope(int thread_id) { + return nullptr; +} + +void HeterXpuTrainer::Finalize() { + //for (auto &th : threads_) { + // th.join(); + //} + std::unique_lock lock(mutex_); + cond_.wait(lock, [this] { return !running_; }); + sleep(3); + pull_dense_worker_->Stop(); + root_scope_->DropKids(); +} + +} // namespace framework +} // namespace paddle +#endif diff --git a/paddle/fluid/framework/pull_dense_worker.cc b/paddle/fluid/framework/pull_dense_worker.cc index 8ae479cb19a28a..7b15205569eeb3 100644 --- a/paddle/fluid/framework/pull_dense_worker.cc +++ b/paddle/fluid/framework/pull_dense_worker.cc @@ -56,6 +56,26 @@ void PullDenseWorker::Initialize(const TrainerDesc& param) { current_version_[tid] = 0; } fleet_ptr_ = FleetWrapper::GetInstance(); + #ifdef PADDLE_WITH_CUDA + copy_streams_.clear(); + places_.clear(); + thread_scopes_.clear(); + #endif +} + +void PullDenseWorker::CreatePinVar() { + #ifdef PADDLE_WITH_CUDA + for (auto& v : dense_value_names_) { + for (auto& name : v.second) { + Variable* var = root_scope_->FindVar(name); + LoDTensor* tensor = var->GetMutable(); + auto *ptr = root_scope_->Var(name + "pin"); + 
InitializeVariable(ptr, proto::VarType::LOD_TENSOR); + LoDTensor* pin_tensor = ptr->GetMutable(); + pin_tensor->mutable_data(tensor->dims(), platform::CUDAPinnedPlace()); + } + } + #endif } void PullDenseWorker::Wait(std::vector<::std::future>* status_vec) { @@ -75,6 +95,29 @@ void PullDenseWorker::Wait(std::vector<::std::future>* status_vec) { exit(-1); } status_vec->resize(0); + #ifdef PADDLE_WITH_CUDA + + for (size_t i = 0; i < places_.size(); ++i) { + + for (auto& v : dense_value_names_) { + for (auto& name : v.second) { + + Variable* pin_var = root_scope_->FindVar(name + "pin"); + LoDTensor* pin_tensor = pin_var->GetMutable(); + float* pin_w = pin_tensor->data(); + Variable* var = thread_scopes_[i]->FindVar(name); + LoDTensor* tensor = var->GetMutable(); + float* w = tensor->data(); + memory::Copy( + boost::get(places_[i]), + w, + platform::CUDAPinnedPlace(), + pin_w, sizeof(float) * tensor->numel(), + copy_streams_[i]); + } + } + } + #endif } void PullDenseWorker::Stop() { @@ -91,8 +134,16 @@ void PullDenseWorker::PullDense(bool force_update) { uint64_t tid = static_cast( dwp_param_.program_config(0).pull_dense_table_id(i)); if (force_update || CheckUpdateParam(tid)) { + #ifdef PADDLE_WITH_CUDA + + + VLOG(3) << "pull dense " << force_update << " " << tid; + fleet_ptr_->PullDenseVarsAsync(*root_scope_, tid, dense_value_names_[tid], + &pull_dense_status_, false); + #else fleet_ptr_->PullDenseVarsAsync(*root_scope_, tid, dense_value_names_[tid], - &pull_dense_status_); + &pull_dense_status_, true); + #endif ResetThreadVersion(tid); } } diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index 17f6ca77f43c96..7e6f227347d4ec 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -31,6 +31,9 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/operators/reader/blocking_queue.h" #include "paddle/fluid/platform/port.h" +#include "paddle/fluid/framework/fleet/heter_wrapper.h" +#include "paddle/fluid/framework/heter_service.h" +#include namespace paddle { namespace framework { @@ -117,6 +120,82 @@ class DistMultiTrainer : public MultiTrainer { std::shared_ptr pull_dense_worker_; }; +#ifdef PADDLE_WITH_CUDA +class HeterServiceContext { +public: + HeterServiceContext() {} + virtual ~HeterServiceContext() { + for (OperatorBase* op : ops_) { + delete op; + } + std::vector().swap(ops_); + } + void Reset() { + push_dense_status_.clear(); + } + int place_num_; + Scope* scope_{nullptr}; + cudaEvent_t event_; + std::vector ops_; + std::vector<::std::future> push_dense_status_; +}; + +class HeterXpuTrainer : public TrainerBase { + public: + HeterXpuTrainer() {} + virtual ~HeterXpuTrainer() { + for (OperatorBase* op : ops_) { + delete op; + } + std::vector().swap(ops_); + } + virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set); + virtual void InitTrainerEnv(const ProgramDesc& main_program, + const platform::Place& place); + virtual void InitOtherEnv(const ProgramDesc& main_program); + virtual void Run(); + virtual void Finalize(); + virtual void DumpWork(int tid); + virtual void RegisterServiceHandler(); + virtual int RunTask(const HeterRequest* request, HeterResponse* response); + virtual Scope* GetWorkerScope(int thread_id); + virtual void CacheProgram(const ProgramDesc &main_program) { + new(&program_) ProgramDesc(main_program); + } + template + void HeterMemCpy(LoDTensor* tensor, LoDTensor* root_tensor, + const paddle::platform::Place& thread_place, + cudaStream_t stream); + void CreateThreadParam(const ProgramDesc& program, int num); + template + void MergeToRootScope(LoDTensor* root_tensor, LoDTensor* thread_tensor); + int EndPass(const HeterRequest* request, HeterResponse* response); + int StopService(const 
HeterRequest* request, HeterResponse* response); + protected: + DownpourWorkerParameter param_; + std::map> dense_grad_names_; + std::vector need_merge_var_names_; + float scale_datanorm_; + int xpu_begin_op_index_; + int xpu_end_op_index_; + bool running_; + paddle::platform::Place place_; + std::mutex mutex_; + ProgramDesc program_; + std::condition_variable cond_; + std::shared_ptr fleet_ptr_; + std::shared_ptr heter_ptr_; + std::shared_ptr pull_dense_worker_; + std::vector ops_; + std::vector op_names_; + std::vector place_scopes_; + HeterObjectPool object_pool_; + std::vector copy_streams_; + std::vector places_; + std::vector events_; +}; +#endif + #if defined(PADDLE_WITH_NCCL) class PipelineTrainer : public TrainerBase { public: diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto index f442063313f033..420cab8ca8e1ca 100644 --- a/paddle/fluid/framework/trainer_desc.proto +++ b/paddle/fluid/framework/trainer_desc.proto @@ -49,6 +49,7 @@ message TrainerDesc { optional bool no_cvm = 21 [ default = false ]; optional bool thread_barrier = 22; repeated string loss_names = 23; + repeated int32 worker_places = 24; // device worker parameters optional HogwildWorkerParameter hogwild_param = 101; diff --git a/paddle/fluid/framework/trainer_factory.cc b/paddle/fluid/framework/trainer_factory.cc index 23cfa11d4c9b2e..f6fd1b0d8476c1 100644 --- a/paddle/fluid/framework/trainer_factory.cc +++ b/paddle/fluid/framework/trainer_factory.cc @@ -63,6 +63,9 @@ std::shared_ptr TrainerFactory::CreateTrainer( REGISTER_TRAINER_CLASS(MultiTrainer); REGISTER_TRAINER_CLASS(DistMultiTrainer); +#ifdef PADDLE_WITH_CUDA +REGISTER_TRAINER_CLASS(HeterXpuTrainer); +#endif #if defined(PADDLE_WITH_NCCL) REGISTER_TRAINER_CLASS(PipelineTrainer); #endif diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 0fad32d160fd38..c5366e096af0ba 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ 
b/paddle/fluid/pybind/CMakeLists.txt @@ -1,7 +1,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper prune feed_fetch_method pass_builder parallel_executor profiler layer tracer engine scope_pool analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context - gloo_wrapper infer_io_utils) + gloo_wrapper infer_io_utils heter_wrapper) if (WITH_NCCL) set(PYBIND_DEPS ${PYBIND_DEPS} nccl_wrapper) @@ -31,6 +31,7 @@ set(PYBIND_SRCS global_value_getter_setter.cc reader_py.cc fleet_wrapper_py.cc + heter_wrapper_py.cc gloo_wrapper_py.cc box_helper_py.cc data_set_py.cc diff --git a/paddle/fluid/pybind/heter_wrapper_py.cc b/paddle/fluid/pybind/heter_wrapper_py.cc new file mode 100644 index 00000000000000..c8b6d6f2384801 --- /dev/null +++ b/paddle/fluid/pybind/heter_wrapper_py.cc @@ -0,0 +1,48 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ +#include + +#ifdef _POSIX_C_SOURCE +#undef _POSIX_C_SOURCE +#endif + +#ifdef _XOPEN_SOURCE +#undef _XOPEN_SOURCE +#endif + +#include +#include + +#include "google/protobuf/io/zero_copy_stream_impl.h" +#include "google/protobuf/text_format.h" +#include "paddle/fluid/framework/fleet/heter_wrapper.h" +#include "paddle/fluid/pybind/heter_wrapper_py.h" + +namespace py = pybind11; + +namespace paddle { +namespace pybind { +void BindHeterWrapper(py::module* m) { + py::class_>(*m, "Heter") + .def(py::init([]() { + return framework::HeterWrapper::GetInstance(); + })) + .def("create_client2xpu_connection", &framework::HeterWrapper::CreateClient2XpuConnection) + .def("set_xpu_list", &framework::HeterWrapper::SetXpuList) + .def("start_xpu_service", &framework::HeterWrapper::StartXpuService) + .def("end_pass", &framework::HeterWrapper::EndPass) + .def("stop_xpu_service", &framework::HeterWrapper::StopXpuService); +} // end HeterWrapper +} // end namespace pybind +} // end namespace paddle diff --git a/paddle/fluid/pybind/heter_wrapper_py.h b/paddle/fluid/pybind/heter_wrapper_py.h new file mode 100644 index 00000000000000..7f6e866893af1f --- /dev/null +++ b/paddle/fluid/pybind/heter_wrapper_py.h @@ -0,0 +1,28 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace py = pybind11; + +namespace paddle { +namespace pybind { + +void BindHeterWrapper(py::module* m); + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index a5c99aa6fce586..a3f9ccd6e74b45 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -63,6 +63,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/data_set_py.h" #include "paddle/fluid/pybind/exception.h" #include "paddle/fluid/pybind/fleet_wrapper_py.h" +#include "paddle/fluid/pybind/heter_wrapper_py.h" #include "paddle/fluid/pybind/global_value_getter_setter.h" #include "paddle/fluid/pybind/gloo_wrapper_py.h" #include "paddle/fluid/pybind/imperative.h" @@ -2310,6 +2311,7 @@ All parameter, weight, gradient are variables in Paddle. .def("device_count", &ParallelExecutor::DeviceCount); BindFleetWrapper(&m); + BindHeterWrapper(&m); BindGlooWrapper(&m); BindBoxHelper(&m); #ifdef PADDLE_WITH_BOX_PS diff --git a/python/paddle/fluid/device_worker.py b/python/paddle/fluid/device_worker.py index 8f91092ff5bc00..b571381d440a7c 100644 --- a/python/paddle/fluid/device_worker.py +++ b/python/paddle/fluid/device_worker.py @@ -216,7 +216,7 @@ def _gen_worker_desc(self, trainer_desc): dense_table_set.add(i) break - trainer_desc.device_worker_name = "HeterCpuWorker" + trainer_desc.device_worker_name = opt_info.get("worker_class", "DownpourWorker") pull_thread = trainer_desc.pull_dense_param pull_thread.device_num = trainer_desc.thread_num if opt_info.get("program_id_to_worker") is None: diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index ace58854d71d2f..7d5bd6c4b30782 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -1433,6 +1433,60 @@ def infer_from_dataset(self, debug, fetch_list, fetch_info, print_period, fetch_handler) + def start_heter_trainer(self, + 
program=None, + scope=None, + debug=False, + fetch_list=None, + fetch_info=None, + print_period=100, + fetch_handler=None): + return self._start_heter_trainer(program, scope, False, + debug, fetch_list, fetch_info, + print_period, fetch_handler) + + def _start_heter_trainer(self, + program=None, + scope=None, + is_infer=False, + debug=False, + fetch_list=None, + fetch_info=None, + print_period=100, + fetch_handler=None): + + scope, trainer = self._prepare_trainer( + program=program, + dataset=None, + scope=scope, + thread=1, + debug=debug, + fetch_list=fetch_list, + fetch_info=fetch_info, + print_period=print_period) + + trainer._set_infer(is_infer) + trainer._gen_trainer_desc() + + self._dump_debug_info(program=program, trainer=trainer) + + trainer_instance = self._default_executor.init_for_dataset( + program.desc, trainer._desc(), scope, None) + + #if fetch_handler is not None: + # scope0 = trainer_instance.get_worker_scope(0) + # fetch_monitor = FetchHandlerMonitor(scope0, fetch_handler) + # fetch_monitor.start() + # self._default_executor.run_from_dataset(trainer_instance) + # fetch_monitor.stop() + # self._default_executor.release_trainer(trainer_instance) + #else: + + self._default_executor.run_from_dataset(trainer_instance) + #self._default_executor.release_trainer(trainer_instance) + + return trainer_instance + def train_from_dataset(self, program=None, dataset=None, diff --git a/python/paddle/fluid/incubate/fleet/base/fleet_base.py b/python/paddle/fluid/incubate/fleet/base/fleet_base.py index 09a1bac85f04ae..b327165223cdb3 100644 --- a/python/paddle/fluid/incubate/fleet/base/fleet_base.py +++ b/python/paddle/fluid/incubate/fleet/base/fleet_base.py @@ -146,6 +146,16 @@ def is_server(self): """ return self._role_maker.is_server() + def is_xpu(self): + """ + Check whether the node is an instance of server. + + Returns: + bool: True if this is a node of server, + False if not. 
+ """ + return self._role_maker.is_xpu() + def split_files(self, files): """ split files before distributed training, diff --git a/python/paddle/fluid/incubate/fleet/base/role_maker.py b/python/paddle/fluid/incubate/fleet/base/role_maker.py index bada19abcc32d2..7f102f7bc1a953 100644 --- a/python/paddle/fluid/incubate/fleet/base/role_maker.py +++ b/python/paddle/fluid/incubate/fleet/base/role_maker.py @@ -27,6 +27,7 @@ class Role: WORKER = 1 SERVER = 2 + XPU = 3 class RoleMakerBase(object): @@ -538,7 +539,6 @@ def worker_num(self): self.generate_role() return self._trainers_num - class GeneralRoleMaker(RoleMakerBase): """ This role maker is for general use, you can set os.environ to customize: @@ -872,6 +872,147 @@ def __get_default_iface_from_interfaces(self): return intf_name return "lo" +class HeterRoleMaker(GeneralRoleMaker): + """ + This role maker is for general use, you can set os.environ to customize: + PADDLE_PSERVERS_IP_PORT_LIST : all pservers' ip:port, separated by ',' + PADDLE_TRAINER_ENDPOINTS : all trainers' ip:port, separated by ',' + TRAINING_ROLE : TRAINER or PSERVER + PADDLE_TRAINER_ID : current trainer id (only for trainer), + it is index in PADDLE_TRAINER_ENDPOINTS + PADDLE_PSERVER_ID : current pserver id (only for pserver) + it is index in PADDLE_PSERVERS_IP_PORT_LIST + """ + + def generate_role(self): + """ + generate role for general role maker + """ + if not self._role_is_generated: + eplist = os.environ["PADDLE_PSERVERS_IP_PORT_LIST"].split(",") + training_role = os.environ["TRAINING_ROLE"] + worker_endpoints = os.environ["PADDLE_TRAINER_ENDPOINTS"].split(",") + trainers_num = len(worker_endpoints) + xpu_endpoints = os.environ["PADDLE_XPU_ENDPOINTS"].split(",") + xpu_num = len(xpu_endpoints) + if training_role not in ["TRAINER", "PSERVER", "XPU"]: + raise ValueError("TRAINING_ROLE must be PSERVER or TRAINER or XPU") + if training_role == "TRAINER": + role = Role.WORKER + current_id = int(os.environ["PADDLE_TRAINER_ID"]) + self._node_type = 
1 + self._cur_endpoint = worker_endpoints[current_id] + gloo = fluid.core.Gloo() + gloo.init(current_id, + len(worker_endpoints), + self._hdfs_path.rstrip("/") + "/trainer", + self._hdfs_name, self._hdfs_ugi, self._iface, + self._prefix) + self._node_type_comm = gloo + elif training_role == "XPU": + role = Role.XPU + current_id = int(os.environ["PADDLE_XPU_ID"]) + self._node_type = 2 + self._cur_endpoint = xpu_endpoints[current_id] + gloo = fluid.core.Gloo() + gloo.init(current_id, + len(xpu_endpoints), + self._hdfs_path.rstrip("/") + "/xpu", + self._hdfs_name, self._hdfs_ugi, self._iface, + self._prefix) + self._node_type_comm = gloo + elif training_role == "PSERVER": + role = Role.SERVER + if os.environ.get("PADDLE_PSERVER_ID") is not None: + current_id = int(os.environ["PADDLE_PSERVER_ID"]) + cur_endpoint = eplist[current_id] + else: + # this is for compatible with paddlecloud + cur_ip = os.environ["POD_IP"] + cur_port = os.environ["PADDLE_PORT"] + cur_endpoint = ":".join([cur_ip, cur_port]) + current_id = eplist.index(cur_endpoint) + self._node_type = 0 + self._cur_endpoint = cur_endpoint + gloo = fluid.core.Gloo() + gloo.init(current_id, + len(eplist), + self._hdfs_path.rstrip("/") + "/pserver", + self._hdfs_name, self._hdfs_ugi, self._iface, + self._prefix) + self._node_type_comm = gloo + + if training_role == "TRAINER" or training_role == "XPU": + gloo = fluid.core.Gloo() + heter_list = worker_endpoints + xpu_endpoints + gloo.init(heter_list.index(self._cur_endpoint), + len(heter_list), + self._hdfs_path.rstrip("/") + "/heter", + self._hdfs_name, self._hdfs_ugi, self._iface, + self._prefix) + self._heter_comm = gloo + + gloo = fluid.core.Gloo() + all_list = worker_endpoints + eplist + xpu_endpoints + gloo.init( + all_list.index(self._cur_endpoint), + len(all_list), + self._hdfs_path.rstrip("/") + "/all", self._hdfs_name, + self._hdfs_ugi, self._iface, self._prefix) + + self._all_comm = gloo + self._trainers_num = trainers_num + self._server_endpoints = 
eplist + self._role = role + self._current_id = current_id + self._rank = all_list.index(self._cur_endpoint) + self._size = len(all_list) + self._worker_endpoints = worker_endpoints + self._xpu_endpoints = xpu_endpoints + self._role_is_generated = True + + def is_xpu(self): + """ + whether current process is server + """ + if not self._role_is_generated: + self.generate_role() + return self._role == Role.XPU + + def is_first_xpu(self): + """ + whether current process is worker of rank 0 + """ + if not self._role_is_generated: + self.generate_role() + return self._role == Role.XPU and self._current_id == 0 + + def _barrier_xpu(self): + """ + barrier all workers in current distributed job + """ + if not self._role_is_generated: + self.generate_role() + if self.is_xpu(): + self._node_type_comm.barrier() + + def _barrier_heter(self): + """ + barrier all workers in current distributed job + """ + if not self._role_is_generated: + self.generate_role() + if self.is_xpu() or self.is_worker: + self._heter_comm.barrier() + + def xpu_num(self): + """ + """ + if not self._role_is_generated: + self.generate_role() + return len(self._xpu_endpoints) + + class UserDefinedRoleMaker(RoleMakerBase): """ diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py index 7dfe8f7e7d7178..e1ce2fde2ddba5 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py @@ -23,6 +23,7 @@ from paddle.fluid.incubate.fleet.base.fleet_base import Mode from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer from paddle.fluid.incubate.fleet.base.role_maker import MPISymetricRoleMaker +from paddle.fluid.incubate.fleet.base.role_maker import HeterRoleMaker class PSLib(Fleet): @@ -44,6 +45,7 @@ def init(self, role_maker=None): role_maker = MPISymetricRoleMaker() super(PSLib, 
self).init(role_maker) self._fleet_ptr = fluid.core.Fleet() + self._heter_ptr = fluid.core.Heter() def _set_client_communication_config(self, request_timeout_ms, connect_timeout_ms, max_retry): @@ -78,23 +80,34 @@ def init_worker(self): raise Exception( "You should run DistributedOptimizer.minimize() first") # barrier_all for init_server, wait for server starts + if isinstance(self._role_maker, HeterRoleMaker): + if self._role_maker.is_xpu(): + local_endpoint = self._role_maker.get_local_endpoint() + local_endpoint = local_endpoint.split(":") + self._heter_ptr.start_xpu_service(str(local_endpoint[0]), int(local_endpoint[1])) self._role_maker._barrier_all() self.all_ips_ = self._role_maker._all_gather(self._local_ip) # worker_index * 2 is for compatible with older versions of pslib self._fleet_ptr.init_worker(self._dist_desc_str, self.all_ips_, self._role_maker._get_size(), self._role_maker.worker_index() * 2) + if isinstance(self._role_maker, HeterRoleMaker): + if self._role_maker.is_worker(): + self._heter_ptr.set_xpu_list(self._role_maker._xpu_endpoints) + self._heter_ptr.create_client2xpu_connection() # barrier_all for init_worker self._role_maker._barrier_all() # prepare for client to client communication - info = self._fleet_ptr.get_clients_info() - all_info = self._role_maker._worker_gather(info[0]) - self._fleet_ptr.gather_clients(all_info) - self._fleet_ptr.set_client2client_config( - self._client2client_request_timeout_ms, - self._client2client_connect_timeout_ms, - self._client2client_max_retry) - self._fleet_ptr.create_client2client_connection() + if self._role_maker.is_worker(): + info = self._fleet_ptr.get_clients_info() + all_info = self._role_maker._worker_gather(info[0]) + self._fleet_ptr.gather_clients(all_info) + self._fleet_ptr.set_client2client_config( + self._client2client_request_timeout_ms, + self._client2client_connect_timeout_ms, + self._client2client_max_retry) + self._fleet_ptr.create_client2client_connection() + # barrier for init model 
self._role_maker._barrier_worker() if self._role_maker.is_first_worker(): @@ -148,10 +161,16 @@ def init_server(self, model_dir=None, **kwargs): """ mode = kwargs.get("mode", 0) - self._role_maker._barrier_worker() - if self._role_maker.is_first_worker(): - self._fleet_ptr.load_model(model_dir, mode) - self._role_maker._barrier_worker() + if isinstance(self._role_maker, HeterRoleMaker): + self._role_maker._barrier_xpu() + if self._role_maker.is_first_xpu(): + self._fleet_ptr.load_model(model_dir, mode) + self._role_maker._barrier_xpu() + else: + self._role_maker._barrier_worker() + if self._role_maker.is_first_worker(): + self._fleet_ptr.load_model(model_dir, mode) + self._role_maker._barrier_worker() def run_server(self): """ @@ -189,6 +208,54 @@ def run_server(self): raise Exception( "You should run DistributedOptimizer.minimize() first") + def end_pass(self, scope): + if self._role_maker.worker_index() < self._role_maker.xpu_num(): + self._heter_ptr.end_pass(scope, self._role_maker.worker_index()) + self._heter_ptr.stop_xpu_service(self._role_maker.worker_index()) + + def train_from_dataset(self, + executor, + program=None, + dataset=None, + scope=None, + thread=0, + debug=False, + fetch_list=None, + fetch_info=None, + print_period=100, + fetch_handler=None): + """ + + """ + + if self._role_maker.is_worker(): + self._role_maker._barrier_heter() + executor.train_from_dataset(program, dataset, scope, thread, + debug, fetch_list, fetch_info, + print_period, fetch_handler) + + def start_heter_trainer(self, + executor, + program=None, + scope=None, + debug=False, + fetch_list=None, + fetch_info=None, + print_period=100, + fetch_handler=None): + """ + + """ + + trainer_instance = executor.start_heter_trainer(program, scope, + debug, fetch_list, fetch_info, + print_period, fetch_handler) + if self._role_maker.is_xpu(): + print("barrier heter") + self._role_maker._barrier_heter() + print("barrier heter") + executor._default_executor.release_trainer(trainer_instance) + 
def stop_worker(self): """ stop(): will be called after a user finishes his/her training task. Fleet instance will be @@ -201,6 +268,7 @@ def stop_worker(self): self._role_maker._barrier_worker() if self._role_maker.is_first_worker(): self._fleet_ptr.stop_server() + self._heter_ptr.stop_xpu_service() self._role_maker._barrier_worker() self._role_maker._barrier_all() self._role_maker._finalize() diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py index 11d56e84913a6c..413a3acf00e2d8 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py @@ -470,10 +470,10 @@ def _minimize(self, strategy.get("scale_datanorm", -1) }) - program_configs[program_id]["pull_dense"].extend( - [dense_table_index]) - program_configs[program_id]["push_dense"].extend( - [dense_table_index]) + program_configs[program_id]["pull_dense"].extend( + [dense_table_index]) + program_configs[program_id]["push_dense"].extend( + [dense_table_index]) dense_table_index += 1 # Todo(guru4elephant): figure out how to support more sparse parameters @@ -509,13 +509,14 @@ def _minimize(self, opt_info = {} opt_info["program_id_to_worker"] = prog_id_to_worker opt_info["program_configs"] = program_configs - opt_info["trainer"] = "DistMultiTrainer" + opt_info["trainer"] = strategy.get("trainer", "DistMultiTrainer") opt_info["device_worker"] = strategy.get("device_worker", "DownpourSGD") opt_info["optimizer"] = "DownpourSGD" opt_info["fleet_desc"] = ps_param opt_info["worker_skipped_ops"] = worker_skipped_ops opt_info["use_cvm"] = strategy.get("use_cvm", False) opt_info["no_cvm"] = strategy.get("no_cvm", False) + opt_info["worker_class"] = strategy.get("worker_class", "DownpourWorker") opt_info["stat_var_names"] = strategy.get("stat_var_names", []) opt_info["local_tables"] = 
strategy.get("local_tables", []) opt_info["async_tables"] = strategy.get("async_tables", []) @@ -529,6 +530,7 @@ def _minimize(self, opt_info["dump_file_num"] = strategy.get("dump_file_num", 16) opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "") opt_info["dump_param"] = strategy.get("dump_param", []) + opt_info["worker_places"] = strategy.get("worker_places", []) if server._server.downpour_server_param.downpour_table_param[ 0].accessor.accessor_class in [ "DownpourCtrAccessor", "DownpourCtrDoubleAccessor" diff --git a/python/paddle/fluid/trainer_desc.py b/python/paddle/fluid/trainer_desc.py index 4a27ea3fd8871d..47f588975e9ee7 100644 --- a/python/paddle/fluid/trainer_desc.py +++ b/python/paddle/fluid/trainer_desc.py @@ -15,7 +15,7 @@ import sys from os import path -__all__ = ['TrainerDesc', 'MultiTrainer', 'DistMultiTrainer', 'PipelineTrainer'] +__all__ = ['TrainerDesc', 'MultiTrainer', 'DistMultiTrainer', 'PipelineTrainer', 'HeterXpuTrainer'] class TrainerDesc(object): @@ -107,7 +107,11 @@ def _set_dump_converter(self, converter): def _set_dump_param(self, dump_param): for param in dump_param: self.proto_desc.dump_param.append(param) - + + def _set_worker_places(self, worker_places): + for place in worker_places: + self.proto_desc.worker_places.append(place) + def _set_thread_barrier(self, thread_barrier): self.proto_desc.thread_barrier = thread_barrier @@ -258,6 +262,30 @@ def _gen_trainer_desc(self): self._device_worker._gen_worker_desc(self.proto_desc) +class HeterXpuTrainer(TrainerDesc): + """ + Implement of HeterXpuTrainer. + It's for Distributed training. 
+ """ + + def __init__(self): + super(HeterXpuTrainer, self).__init__() + pass + + def _set_program(self, program): + super(HeterXpuTrainer, self)._set_program(program) + self._program = program + + def _gen_trainer_desc(self): + super(HeterXpuTrainer, self)._gen_trainer_desc() + self.proto_desc.class_name = "HeterXpuTrainer" + if self._program == None: + raise RuntimeError("None Program") + self._device_worker._set_infer(self._infer) + self._device_worker._set_program(self._program) + self._device_worker._gen_worker_desc(self.proto_desc) + + class PipelineTrainer(TrainerDesc): """ Implement of PipelineTrainer. diff --git a/python/paddle/fluid/trainer_factory.py b/python/paddle/fluid/trainer_factory.py index 0e071251bb2cd3..1619212ad87708 100644 --- a/python/paddle/fluid/trainer_factory.py +++ b/python/paddle/fluid/trainer_factory.py @@ -22,7 +22,7 @@ logging.basicConfig(level=logging.INFO, format=FORMAT) local_logger = logging.getLogger(__name__) -from .trainer_desc import MultiTrainer, DistMultiTrainer, PipelineTrainer +from .trainer_desc import MultiTrainer, DistMultiTrainer, PipelineTrainer, HeterXpuTrainer from .device_worker import Hogwild, DownpourSGD, Section, DownpourSGDOPT from .framework import Variable from multiprocessing import Process, Manager @@ -72,6 +72,8 @@ def _create_trainer(self, opt_info=None): trainer._set_dump_converter(opt_info["dump_converter"]) if opt_info.get("dump_param") is not None: trainer._set_dump_param(opt_info["dump_param"]) + if opt_info.get("worker_places") is not None: + trainer._set_worker_places(opt_info["worker_places"]) if "fleet_desc" in opt_info: device_worker._set_fleet_desc(opt_info["fleet_desc"]) From d125c60f7bca743d07a3112b5d3cf30a7b10aeeb Mon Sep 17 00:00:00 2001 From: xjqbest <173596896@qq.com> Date: Tue, 9 Jun 2020 14:35:31 +0800 Subject: [PATCH 3/8] multi phase --- paddle/fluid/framework/device_worker.h | 1 + paddle/fluid/framework/fleet/heter_wrapper.cc | 18 +-- paddle/fluid/framework/fleet/heter_wrapper.h | 2 
+- paddle/fluid/framework/hetercpu_worker.cc | 27 +++- paddle/fluid/framework/heterxpu_trainer.cc | 22 +++- paddle/fluid/framework/trainer.h | 1 + paddle/fluid/framework/trainer_desc.proto | 15 +++ python/paddle/fluid/executor.py | 10 +- .../fluid/incubate/fleet/utils/fleet_util.py | 117 ++++++++++++++++++ python/paddle/fluid/trainer_desc.py | 36 ++++++ 10 files changed, 229 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index 3b5f57b9ade1ce..b31fa2cc172701 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -175,6 +175,7 @@ class DeviceWorker { FetchConfig fetch_config_; bool use_cvm_; bool no_cvm_; + TrainerDesc trainer_desc_; }; class CPUWorkerBase : public DeviceWorker { diff --git a/paddle/fluid/framework/fleet/heter_wrapper.cc b/paddle/fluid/framework/fleet/heter_wrapper.cc index 9636201b5ca548..6d598253334929 100644 --- a/paddle/fluid/framework/fleet/heter_wrapper.cc +++ b/paddle/fluid/framework/fleet/heter_wrapper.cc @@ -226,7 +226,7 @@ void HeterWrapper::EndPass(Scope* scope, int num) { //} } -void HeterWrapper::CallRemoteXpu(std::shared_ptr task, HeterCpuWorker* worker, int mpi_rank) { +void HeterWrapper::CallRemoteXpu(std::shared_ptr task, HeterCpuWorker* worker, int mpi_rank, std::vector& send_vars) { HeterRequest request; request.set_cmd(0); request.set_cur_batch(task->cur_batch_); @@ -247,15 +247,15 @@ void HeterWrapper::CallRemoteXpu(std::shared_ptr task, HeterCpuWorker worker->Schedule(task->taskid_); }); - std::vector varnames = {"click", "12345"}; - //varnames.push_back(send_var); - //if (send_var == "_generated_var_412") { - varnames.push_back("filter_by_instag_0.tmp_0"); - varnames.push_back("filter_by_instag_2.tmp_0"); - varnames.push_back("filter_by_instag_0.tmp_1"); - varnames.push_back("concat_1.tmp_0"); +// std::vector varnames = {"click", "12345"}; +// //varnames.push_back(send_var); +// //if (send_var == 
"_generated_var_412") { +// varnames.push_back("filter_by_instag_0.tmp_0"); +// varnames.push_back("filter_by_instag_2.tmp_0"); +// varnames.push_back("filter_by_instag_0.tmp_1"); +// varnames.push_back("concat_1.tmp_0"); //} - for (auto& varname : varnames) { + for (auto& varname : send_vars) { auto* req_var = request.add_vars(); SerializeToReq(varname, task->scope_, req_var); } diff --git a/paddle/fluid/framework/fleet/heter_wrapper.h b/paddle/fluid/framework/fleet/heter_wrapper.h index 1db3e18b4b8198..0ddded02c69259 100644 --- a/paddle/fluid/framework/fleet/heter_wrapper.h +++ b/paddle/fluid/framework/fleet/heter_wrapper.h @@ -67,7 +67,7 @@ class HeterWrapper { void StartXpuService(const std::string& ip, uint32_t port); - void CallRemoteXpu(std::shared_ptr task, HeterCpuWorker* worker, int mpi_rank); + void CallRemoteXpu(std::shared_ptr task, HeterCpuWorker* worker, int mpi_rank, std::vector& send_vars); void CallRemoteXpuSync(std::shared_ptr task, HeterCpuWorker* worker); diff --git a/paddle/fluid/framework/hetercpu_worker.cc b/paddle/fluid/framework/hetercpu_worker.cc index 5b373b6a75ab6a..b601262d37ccdd 100644 --- a/paddle/fluid/framework/hetercpu_worker.cc +++ b/paddle/fluid/framework/hetercpu_worker.cc @@ -62,7 +62,6 @@ void HeterTask::PackTask(Scope* thread_scope, int taskid, DataFeed* reader, int void HeterCpuWorker::GetXpuOpIndex() { xpu_begin_op_index_ = xpu_end_op_index_ = -1; for (size_t i = 0; i < ops_.size(); ++i) { - //if (!first && ops_[i]->Type() == "mul") { // first = 1; // xpu_begin_op_index_ = i; // auto& in_map = ops_[i]->Inputs(); @@ -86,6 +85,8 @@ void HeterCpuWorker::GetXpuOpIndex() { // // } // //} //} + +/* auto& out_map = ops_[i]->Outputs(); { @@ -121,10 +122,19 @@ void HeterCpuWorker::GetXpuOpIndex() { } } } + if (xpu_end_op_index_ == -1) { xpu_end_op_index_ = ops_.size() - 1; - } + }*/ + xpu_begin_op_index_ = trainer_desc_.xpu_start_idx(); + xpu_end_op_index_ = trainer_desc_.xpu_end_idx(); VLOG(0) << "xpu begin: " << 
xpu_begin_op_index_ << " xpu end: " << xpu_end_op_index_; + //CHECK(xpu_begin_op_index_ == trainer_desc_.xpu_start_idx()); + // CHECK(xpu_end_op_index_ == trainer_desc_.xpu_end_idx()); + // CHECK(trainer_desc_.op_run_start_idx() == 0); + // CHECK(trainer_desc_.op_run_end_idx() == xpu_begin_op_index_ - 1); + // CHECK(trainer_desc_.op_run_end_start_idx() == xpu_end_op_index_ + 1); + // CHECK(trainer_desc_.op_run_end_end_idx() == ops_.size() - 1); } void HeterCpuWorker::Schedule(int taskid) { @@ -145,6 +155,7 @@ void HeterCpuWorker::JumpContext(std::shared_ptr task) { void HeterCpuWorker::Initialize(const TrainerDesc& desc) { param_ = desc.downpour_param(); mpi_rank_ = desc.mpi_rank(); + trainer_desc_ = desc; for (int i = 0; i < param_.sparse_table_size(); ++i) { uint64_t table_id = static_cast(param_.sparse_table(i).table_id()); @@ -804,7 +815,11 @@ void HeterCpuWorker::TrainFilesWithProfiler() { else if (task->state_ == XPU) { timeline.Start(); VLOG(3) << "call remote xpu taskid = " << task->taskid_; - heter_ptr_->CallRemoteXpu(task, this, mpi_rank_); + std::vector send_var_list; + for (int i = 0; i < trainer_desc_.xpu_recv_list_size(); ++i) { + send_var_list.push_back(trainer_desc_.xpu_recv_list(i)); + } + heter_ptr_->CallRemoteXpu(task, this, mpi_rank_, send_var_list); task->Update(); JumpContext(task); timeline.Pause(); @@ -1071,7 +1086,11 @@ void HeterCpuWorker::TrainFiles() { } else if (task->state_ == XPU) { VLOG(3) << "call remote xpu taskid = " << task->taskid_; - heter_ptr_->CallRemoteXpu(task, this, mpi_rank_); + std::vector send_var_list; + for (int i = 0; i < trainer_desc_.xpu_recv_list_size(); ++i) { + send_var_list.push_back(trainer_desc_.xpu_recv_list(i)); + } + heter_ptr_->CallRemoteXpu(task, this, mpi_rank_, send_var_list); task->Update(); JumpContext(task); break; diff --git a/paddle/fluid/framework/heterxpu_trainer.cc b/paddle/fluid/framework/heterxpu_trainer.cc index 63142b5b9ab8fd..e03a7728cdd4a4 100644 --- 
a/paddle/fluid/framework/heterxpu_trainer.cc +++ b/paddle/fluid/framework/heterxpu_trainer.cc @@ -100,7 +100,7 @@ void HeterXpuTrainer::Initialize(const TrainerDesc &trainer_desc, // copy_streams_.push_back(stream); // places_.push_back(place); //} - + trainer_desc_ = trainer_desc; } void HeterXpuTrainer::CreateThreadParam(const ProgramDesc& program, int num) { @@ -203,6 +203,7 @@ void HeterXpuTrainer::InitOtherEnv(const ProgramDesc &main_program) { } xpu_begin_op_index_ = xpu_end_op_index_ = -1; +/* for (size_t i = 0; i < ops_.size(); ++i) { //if (!first && ops_[i]->Type() == "mul") { // first = 1; @@ -266,8 +267,12 @@ void HeterXpuTrainer::InitOtherEnv(const ProgramDesc &main_program) { if (xpu_end_op_index_ == -1) { xpu_end_op_index_ = ops_.size() - 1; } - + */ + xpu_begin_op_index_ = trainer_desc_.xpu_start_idx; + xpu_end_op_index_ = trainer_desc_.xpu_end_idx; VLOG(0) << "xpu begin: " << xpu_begin_op_index_ << " xpu end: " << xpu_end_op_index_; + //CHECK(xpu_begin_op_index_ == 0); + //CHECK(xpu_end_op_index_ = ops_.size() - 1); VLOG(3) << "init other env done."; } @@ -419,10 +424,17 @@ int HeterXpuTrainer::RunTask(const HeterRequest* request, HeterResponse* respons bthread_yield(); } - std::string varname = "concat_1.tmp_0@GRAD"; + for (int i = 0; i < trainer_desc_.xpu_send_list_size(); ++i) { + string& varname = trainer_desc_.xpu_send_list(i); + //CHECK(varname == "concat_1.tmp_0@GRAD"); + auto* res_var = response->add_vars(); + heter_ptr_->SerializeToReq(varname, context->scope_, res_var); + } - auto* res_var = response->add_vars(); - heter_ptr_->SerializeToReq(varname, context->scope_, res_var); + //std::string varname = "concat_1.tmp_0@GRAD"; + // + //auto* res_var = response->add_vars(); + //heter_ptr_->SerializeToReq(varname, context->scope_, res_var); for (int i = 0; i < param_.program_config(0).push_dense_table_id_size(); ++i) { diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index 7e6f227347d4ec..e9fc88d58d0c82 100644 
--- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -59,6 +59,7 @@ class TrainerBase { Scope* root_scope_; bool debug_; Dataset* dataset_ptr_; + TrainerDesc trainer_desc_; }; // general trainer for async execution diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto index 420cab8ca8e1ca..245e2d9473f874 100644 --- a/paddle/fluid/framework/trainer_desc.proto +++ b/paddle/fluid/framework/trainer_desc.proto @@ -50,6 +50,21 @@ message TrainerDesc { optional bool thread_barrier = 22; repeated string loss_names = 23; repeated int32 worker_places = 24; + +// repeated string op_run_send_list = 25; +// repeated string op_run_recv_list = 26; +// repeated string op_run_start_idx = 27; +// repeated string op_run_end_idx = 28; + + repeated string xpu_send_list = 25; + repeated string xpu_recv_list = 26; + repeated string xpu_start_idx = 27; + repeated string xpu_end_idx = 28; + +// repeated string op_run_end_send_list = 33; +// repeated string op_run_end_recv_list = 34; +// repeated string op_run_end_start_idx = 35; +// repeated string op_run_end_end_idx = 36; // device worker parameters optional HogwildWorkerParameter hogwild_param = 101; diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 7d5bd6c4b30782..ddb4844763bb5b 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -1276,6 +1276,10 @@ def _prepare_trainer(self, fetch_info = [] assert len(fetch_list) == len(fetch_info) compiled = isinstance(program, compiler.CompiledProgram) + from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet + fu = FleetUtil() + ret = fu.split_program_by_device(program) + #start_list, end_list, send_list, recv_list, program_list = fu.split_program_by_device(program) if not compiled: # TODO: Need a better way to distinguish and specify different execution mode if program._pipeline_opt: @@ -1284,7 +1288,11 @@ def _prepare_trainer(self, else: trainer 
= TrainerFactory()._create_trainer(program._fleet_opt) trainer._set_thread_barrier(program._is_distributed) - trainer._set_program(program) + if fleet.is_worker(): + trainer._set_program(program) + elif fleet.is_xpu() and ret is not None: + trainer._set_program(ret[4]) + trainer._set_heter_info(ret) else: if program._pipeline_opt: trainer = TrainerFactory()._create_trainer( diff --git a/python/paddle/fluid/incubate/fleet/utils/fleet_util.py b/python/paddle/fluid/incubate/fleet/utils/fleet_util.py index 2b46459280b614..73554e366b2f11 100644 --- a/python/paddle/fluid/incubate/fleet/utils/fleet_util.py +++ b/python/paddle/fluid/incubate/fleet/utils/fleet_util.py @@ -1615,3 +1615,120 @@ def parse_program_proto(self, prog_path, is_text, output_dir): """ program = utils.load_program(prog_path, is_text) utils.parse_program(program, output_dir) + + def split_program_by_device(self, program): + ops_list = [] + type_list = [] + pre = None + type_cpu = "cpu" + for op in program.global_block().ops: + if op.has_attr("op_device"): + if pre is None or pre != op.attr("op_device"): + ops_list.append([]) + type_list.append(op.attr("op_device") if op.attr("op_device") != "" else type_cpu) + ops_list[-1].append(op) + pre = op.attr("op_device") + l = len(type_list) + i = 0 + type_heter = None + while i < l: + while i < l and type_list[i] == type_cpu: + i += 1 + if i == l: + break + + type_heter = type_list[i] + i += 1 + start = i + valid = True + while i < l and type_list[i] != type_heter: + if type_list[i] != type_cpu: + valid = False + break + i += 1 + + if i == l: + break + elif not valid: + continue + + for j in range(start, i): + for op in ops_list[j]: + op._set_attr("op_device", type_heter) + type_list[j] = type_heter + j += 1 + + pre = None + merged_ops_list = [] + merged_type_list = [] + for i in range(l): + if pre is None or pre != type_list[i]: + merged_ops_list.append([]) + merged_type_list.append(type_list[i]) + merged_ops_list[-1].extend(ops_list[i]) + pre = type_list[i] + 
+ data_vars = set() + for k in program.global_block().vars: + var = program.global_block().var(k) + if not var.persistable: + data_vars.add(var.name) + + l = len(merged_ops_list) + inputs_pre = set() + outputs_pre = set() + in_from_pre = [[] for i in range(l)] + for i in range(l): + inputs = set() + outputs = set() + for op in merged_ops_list[i]: + for input in op.input_names: + for tmp in op.input(input): + if tmp not in outputs: + inputs.add(tmp) + for output in op.output_names: + for tmp in op.output(output): + outputs.add(tmp) + if i == 0: + in_from_pre[i] = [] + elif i == 1: + in_from_pre[i] = (outputs_pre | data_vars) & inputs + else: + in_from_pre[i] = outputs_pre & inputs + inputs_pre = copy.deepcopy(inputs) + outputs_pre = copy.deepcopy(outputs) + + l = len(in_from_pre) + start_list = [] + end_list = [] + send_list = [[] for i in range(l)] + sum = 0 + program_list = [] + for i in range(l): + start_list.append(sum) + end_list.append(sum + len(merged_ops_list[i]) - 1) + sum += len(merged_ops_list[i]) + if i < l - 1: + send_list[i].extend(list(in_from_pre[i + 1])) + prog = program.clone() + if merged_type_list[i] != type_cpu: + prog = prog._prune_with_input(list(in_from_pre[i]), list(send_list[i])) + program_list.append(prog) + else: + program_list.append(prog) + recv_list = in_from_pre + found = False + heter_index = None + for i in range(len(merged_type_list)): + t = merged_type_list[i] + if t != type_cpu: + if found: + print("only one region of program can be heter") + found = True + heter_index = i + if heter_index is None: + print("warning: non heter program") + return None + else: + return [start_list[heter_index], end_list[heter_index], send_list[heter_index], \ + recv_list[index], program_list[heter_index]] diff --git a/python/paddle/fluid/trainer_desc.py b/python/paddle/fluid/trainer_desc.py index 47f588975e9ee7..67b98c04a58031 100644 --- a/python/paddle/fluid/trainer_desc.py +++ b/python/paddle/fluid/trainer_desc.py @@ -15,6 +15,7 @@ import sys from 
os import path +from paddle.fluid.incubate.fleet.utils.fleet_util import FleetUtil __all__ = ['TrainerDesc', 'MultiTrainer', 'DistMultiTrainer', 'PipelineTrainer', 'HeterXpuTrainer'] @@ -45,6 +46,41 @@ def __init__(self): self._program = None self._infer = False + def _set_heter_info(self, ret): + #ret = = fu.split_program_by_device(program) + #start_list, end_list, send_list, recv_list, program_list = fu.split_program_by_device(program) + #if len(start_list) != 3: + # print("start_list len=", len(start_list), " will not set heter info") + # return + #for i in start_list[0]: + # self.proto_desc.op_run_start_idx.append(i) + #for i in end_list[0]: + # self.proto_desc.op_run_end_idx.append(i) + #for i in send_list[0]: + # self.proto_desc.op_run_send_list.append(i) + #for i in recv_list[0]: + # self.proto_desc.op_run_recv_list.append(i) + if ret is None: + return + for i in ret[0]: # start_list[1]: + self.proto_desc.xpu_start_idx.append(i) + for i in ret[1]: #end_list[1]: + self.proto_desc.o_end_idx.append(i) + for i in ret[2]: #send_list[1]: + self.proto_desc.op_run_send_list.append(i) + for i in ret[3]: # recv_list[1]: + self.proto_desc.op_run_recv_list.append(i) + + #for i in start_list[2]: + # self.proto_desc.op_run_end_start_idx.append(i) + #for i in end_list[2]: + # self.proto_desc.op_run_end_idx.append(i) + #for i in send_list[2]: + # self.proto_desc.op_run_end_send_list.append(i) + #for i in recv_list[2]: + # self.proto_desc.op_run_end_recv_list.append(i) + + def _set_fetch_var_and_info(self, fetch_vars, fetch_info, print_period): for i, v in enumerate(fetch_vars): self.proto_desc.fetch_config.fetch_var_names.extend([v.name]) From bf94bc9ba9b708d2e43373ad2d46d6be4125bbac Mon Sep 17 00:00:00 2001 From: xjqbest <173596896@qq.com> Date: Tue, 9 Jun 2020 14:44:01 +0800 Subject: [PATCH 4/8] multi phase, temporily remove prune program --- python/paddle/fluid/executor.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git 
a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index ddb4844763bb5b..91c528e42e731c 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -1288,10 +1288,11 @@ def _prepare_trainer(self, else: trainer = TrainerFactory()._create_trainer(program._fleet_opt) trainer._set_thread_barrier(program._is_distributed) - if fleet.is_worker(): - trainer._set_program(program) - elif fleet.is_xpu() and ret is not None: - trainer._set_program(ret[4]) + #if fleet.is_worker(): + # trainer._set_program(program) + #elif fleet.is_xpu() and ret is not None: + # trainer._set_program(ret[4]) + trainer._set_program(program) trainer._set_heter_info(ret) else: if program._pipeline_opt: From 4251ef8b6029f30d0064a68b9cfc787461acb587 Mon Sep 17 00:00:00 2001 From: xjqbest <173596896@qq.com> Date: Tue, 9 Jun 2020 15:36:28 +0800 Subject: [PATCH 5/8] multi phase, fix compile error --- paddle/fluid/framework/hetercpu_worker.cc | 4 ++-- paddle/fluid/framework/trainer_desc.proto | 7 +++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/hetercpu_worker.cc b/paddle/fluid/framework/hetercpu_worker.cc index b601262d37ccdd..504734e4c3f2b6 100644 --- a/paddle/fluid/framework/hetercpu_worker.cc +++ b/paddle/fluid/framework/hetercpu_worker.cc @@ -60,8 +60,8 @@ void HeterTask::PackTask(Scope* thread_scope, int taskid, DataFeed* reader, int } void HeterCpuWorker::GetXpuOpIndex() { - xpu_begin_op_index_ = xpu_end_op_index_ = -1; - for (size_t i = 0; i < ops_.size(); ++i) { +// xpu_begin_op_index_ = xpu_end_op_index_ = -1; +// for (size_t i = 0; i < ops_.size(); ++i) { // first = 1; // xpu_begin_op_index_ = i; // auto& in_map = ops_[i]->Inputs(); diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto index 245e2d9473f874..c6f3c476be52c5 100644 --- a/paddle/fluid/framework/trainer_desc.proto +++ b/paddle/fluid/framework/trainer_desc.proto @@ -58,8 +58,11 @@ message TrainerDesc { 
repeated string xpu_send_list = 25; repeated string xpu_recv_list = 26; - repeated string xpu_start_idx = 27; - repeated string xpu_end_idx = 28; + optional int32 xpu_start_idx = 27; + optional int32 xpu_end_idx = 28; + +//repeated string xpu_start_idx = 27; + //repeated string xpu_end_idx = 28; // repeated string op_run_end_send_list = 33; // repeated string op_run_end_recv_list = 34; From bc595c5343c78ff3ce68bde47b07e8d91f25e7ba Mon Sep 17 00:00:00 2001 From: xjqbest <173596896@qq.com> Date: Tue, 9 Jun 2020 15:45:12 +0800 Subject: [PATCH 6/8] multi phase, fix start end list --- .../paddle/fluid/incubate/fleet/utils/fleet_util.py | 2 +- python/paddle/fluid/trainer_desc.py | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/incubate/fleet/utils/fleet_util.py b/python/paddle/fluid/incubate/fleet/utils/fleet_util.py index 73554e366b2f11..b30c356b1e246d 100644 --- a/python/paddle/fluid/incubate/fleet/utils/fleet_util.py +++ b/python/paddle/fluid/incubate/fleet/utils/fleet_util.py @@ -1716,7 +1716,7 @@ def split_program_by_device(self, program): program_list.append(prog) else: program_list.append(prog) - recv_list = in_from_pre + recv_list = [list(i) for i in in_from_pre] found = False heter_index = None for i in range(len(merged_type_list)): diff --git a/python/paddle/fluid/trainer_desc.py b/python/paddle/fluid/trainer_desc.py index 67b98c04a58031..c8de38fd4fc4be 100644 --- a/python/paddle/fluid/trainer_desc.py +++ b/python/paddle/fluid/trainer_desc.py @@ -62,10 +62,13 @@ def _set_heter_info(self, ret): # self.proto_desc.op_run_recv_list.append(i) if ret is None: return - for i in ret[0]: # start_list[1]: - self.proto_desc.xpu_start_idx.append(i) - for i in ret[1]: #end_list[1]: - self.proto_desc.o_end_idx.append(i) + #for i in ret[0]: # start_list[1]: + # self.proto_desc.xpu_start_idx.append(i) + self.proto_desc.xpu_start_idx = ret[0] + + #for i in ret[1]: #end_list[1]: + # self.proto_desc.o_end_idx.append(i) + 
self.proto_desc.xpu_end_idx = ret[1] for i in ret[2]: #send_list[1]: self.proto_desc.op_run_send_list.append(i) for i in ret[3]: # recv_list[1]: From d28065b4c1d797fc7382dc604fb2acc229a1617d Mon Sep 17 00:00:00 2001 From: xjqbest <173596896@qq.com> Date: Tue, 9 Jun 2020 17:12:40 +0800 Subject: [PATCH 7/8] multi phase, fix import FleetUtil --- python/paddle/fluid/trainer_desc.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/paddle/fluid/trainer_desc.py b/python/paddle/fluid/trainer_desc.py index c8de38fd4fc4be..7ef6672e29d9c8 100644 --- a/python/paddle/fluid/trainer_desc.py +++ b/python/paddle/fluid/trainer_desc.py @@ -15,7 +15,6 @@ import sys from os import path -from paddle.fluid.incubate.fleet.utils.fleet_util import FleetUtil __all__ = ['TrainerDesc', 'MultiTrainer', 'DistMultiTrainer', 'PipelineTrainer', 'HeterXpuTrainer'] From 9a2364fa121c4b9d1888db9857290d774b51242d Mon Sep 17 00:00:00 2001 From: xjqbest <173596896@qq.com> Date: Tue, 9 Jun 2020 20:47:12 +0800 Subject: [PATCH 8/8] heter split, fix import --- python/paddle/fluid/executor.py | 1 + python/paddle/fluid/incubate/fleet/utils/fleet_util.py | 3 ++- python/paddle/fluid/trainer_desc.py | 4 ++-- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 91c528e42e731c..2cc81038561256 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -1277,6 +1277,7 @@ def _prepare_trainer(self, assert len(fetch_list) == len(fetch_info) compiled = isinstance(program, compiler.CompiledProgram) from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet + from paddle.fluid.incubate.fleet.utils.fleet_util import FleetUtil fu = FleetUtil() ret = fu.split_program_by_device(program) #start_list, end_list, send_list, recv_list, program_list = fu.split_program_by_device(program) diff --git a/python/paddle/fluid/incubate/fleet/utils/fleet_util.py b/python/paddle/fluid/incubate/fleet/utils/fleet_util.py 
index b30c356b1e246d..8ec06d33034c3e 100644 --- a/python/paddle/fluid/incubate/fleet/utils/fleet_util.py +++ b/python/paddle/fluid/incubate/fleet/utils/fleet_util.py @@ -14,6 +14,7 @@ """Fleet Utils.""" import collections +import copy import json import logging import math @@ -1731,4 +1732,4 @@ def split_program_by_device(self, program): return None else: return [start_list[heter_index], end_list[heter_index], send_list[heter_index], \ - recv_list[index], program_list[heter_index]] + recv_list[heter_index], program_list[heter_index]] diff --git a/python/paddle/fluid/trainer_desc.py b/python/paddle/fluid/trainer_desc.py index 7ef6672e29d9c8..226e12ad39cd7a 100644 --- a/python/paddle/fluid/trainer_desc.py +++ b/python/paddle/fluid/trainer_desc.py @@ -69,9 +69,9 @@ def _set_heter_info(self, ret): # self.proto_desc.o_end_idx.append(i) self.proto_desc.xpu_end_idx = ret[1] for i in ret[2]: #send_list[1]: - self.proto_desc.op_run_send_list.append(i) + self.proto_desc.xpu_send_list.append(i) for i in ret[3]: # recv_list[1]: - self.proto_desc.op_run_recv_list.append(i) + self.proto_desc.xpu_recv_list.append(i) #for i in start_list[2]: # self.proto_desc.op_run_end_start_idx.append(i)