From 57b34b6b67d3415f6b53abf0ee540fd36a40c9a5 Mon Sep 17 00:00:00 2001 From: danleifeng <52735331+danleifeng@users.noreply.github.com> Date: Tue, 20 Jul 2021 18:32:05 +0800 Subject: [PATCH 1/4] psgpu:add cuda remote_streams; test=develop --- .../fluid/framework/fleet/heter_ps/hashtable.h | 3 +++ .../framework/fleet/heter_ps/hashtable_inl.h | 1 + .../framework/fleet/heter_ps/heter_comm_inl.h | 18 ++++++++++++------ .../framework/fleet/heter_ps/heter_resource.cc | 14 ++++++++------ .../framework/fleet/heter_ps/heter_resource.h | 6 +++--- paddle/fluid/framework/fleet/ps_gpu_wrapper.cu | 6 +++--- paddle/fluid/framework/fleet/ps_gpu_wrapper.h | 1 + 7 files changed, 31 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable.h b/paddle/fluid/framework/fleet/heter_ps/hashtable.h index 3782e14ad41a5e..61292660acaedb 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable.h +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable.h @@ -24,6 +24,7 @@ limitations under the License. 
*/ #include "paddle/fluid/distributed/table/depends/large_scale_kv.h" #endif #include "thrust/pair.h" +#include "paddle/fluid/framework/rw_lock.h" //#include "cudf/concurrent_unordered_map.cuh.h" #include "paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h" #ifdef PADDLE_WITH_HETERPS @@ -63,11 +64,13 @@ class HashTable { int size() { return container_->size(); } + std::unique_ptr rwlock_{nullptr}; private: TableContainer* container_; int BLOCK_SIZE_{256}; float LOAD_FACTOR{0.75f}; size_t capacity_; + }; } // end namespace framework } // end namespace paddle diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h b/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h index 098c795fc7e1f9..9facbff1f25269 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h @@ -73,6 +73,7 @@ __global__ void update_kernel(Table* table, template HashTable::HashTable(size_t capacity) { container_ = new TableContainer(capacity); + rwlock_.reset(new RWLock); } template diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index a2e09b7e08132f..02714978526bc5 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -525,12 +525,14 @@ void HeterComm::pull_sparse(int num, auto& node = path_[num][i].nodes_.back(); cudaStreamSynchronize(node.in_stream); platform::CUDADeviceGuard guard(resource_->dev_id(i)); + tables_[i]->rwlock_->RDLock(); tables_[i]->get(reinterpret_cast(node.key_storage), reinterpret_cast(node.val_storage), - h_right[i] - h_left[i] + 1, resource_->remote_stream(i)); + h_right[i] - h_left[i] + 1, resource_->remote_stream(num, i)); } for (int i = 0; i < total_gpu; ++i) { - cudaStreamSynchronize(resource_->remote_stream(i)); + cudaStreamSynchronize(resource_->remote_stream(num, i)); + tables_[i]->rwlock_->UNLock(); } 
walk_to_src(num, total_gpu, h_left, h_right, d_shard_vals_ptr); @@ -621,13 +623,15 @@ void HeterComm::push_sparse(int gpu_num, cudaStreamSynchronize(node.in_stream); platform::CUDADeviceGuard guard(resource_->dev_id(i)); + tables_[i]->rwlock_->WRLock(); tables_[i]->update(reinterpret_cast(node.key_storage), reinterpret_cast(node.val_storage), h_right[i] - h_left[i] + 1, sgd, - resource_->remote_stream(i)); + resource_->remote_stream(gpu_num, i)); } for (int i = 0; i < total_gpu; ++i) { - cudaStreamSynchronize(resource_->remote_stream(i)); + cudaStreamSynchronize(resource_->remote_stream(gpu_num, i)); + tables_[i]->rwlock_->UNLock(); } } @@ -641,9 +645,11 @@ void HeterComm::update_one_table( int dev_id = resource_->dev_id(gpu_num); platform::CUDADeviceGuard guard(dev_id); + tables_[gpu_num]->rwlock_->WRLock(); tables_[gpu_num]->update(d_keys, d_grads, len, sgd, - resource_->remote_stream(gpu_num)); - cudaStreamSynchronize(resource_->remote_stream(gpu_num)); + resource_->remote_stream(gpu_num, gpu_num)); + cudaStreamSynchronize(resource_->remote_stream(gpu_num, gpu_num)); + tables_[gpu_num]->rwlock_->UNLock(); } template diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_resource.cc b/paddle/fluid/framework/fleet/heter_ps/heter_resource.cc index 0f2af2a522e287..a369a612d4935d 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_resource.cc +++ b/paddle/fluid/framework/fleet/heter_ps/heter_resource.cc @@ -27,16 +27,16 @@ GPUResource::GPUResource(std::vector& dev_ids, int index) { platform::CUDADeviceGuard guard(dev_id_); local_streams_.resize(dev_ids_.size()); comm_streams_.resize(dev_ids_.size()); + remote_streams_.resize(dev_ids_.size()); for (size_t i = 0; i < dev_ids_.size(); ++i) { PADDLE_ENFORCE_CUDA_SUCCESS( cudaStreamCreateWithFlags(&local_streams_[i], cudaStreamNonBlocking)); PADDLE_ENFORCE_CUDA_SUCCESS( cudaStreamCreateWithFlags(&comm_streams_[i], cudaStreamNonBlocking)); + PADDLE_ENFORCE_CUDA_SUCCESS( + 
cudaStreamCreateWithFlags(&remote_streams_[i], cudaStreamNonBlocking)); } - - PADDLE_ENFORCE_CUDA_SUCCESS( - cudaStreamCreateWithFlags(&remote_stream_, cudaStreamNonBlocking)); } GPUResource::~GPUResource() { @@ -47,7 +47,9 @@ GPUResource::~GPUResource() { for (size_t i = 0; i < comm_streams_.size(); ++i) { PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(comm_streams_[i])); } - PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(remote_stream_)); + for (size_t i = 0; i < remote_streams_.size(); ++i) { + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(remote_streams_[i])); + } } void HeterPsResource::enable_p2p() { @@ -90,8 +92,8 @@ cudaStream_t HeterPsResource::local_stream(int gpu_num, int stream_num) { return resources_[gpu_num]->local_stream(stream_num); } -cudaStream_t HeterPsResource::remote_stream(int gpu_num) { - return resources_[gpu_num]->remote_stream(); +cudaStream_t HeterPsResource::remote_stream(int gpu_num, int stream_num) { + return resources_[gpu_num]->remote_stream(stream_num); } int HeterPsResource::dev_id(int num) { return dev_ids_[num]; } diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_resource.h b/paddle/fluid/framework/fleet/heter_ps/heter_resource.h index 7b23379994c735..7bc52e52e6887d 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_resource.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_resource.h @@ -35,13 +35,13 @@ class GPUResource { int dev_id() const { return dev_id_; } int index() const { return index_; } gpuStream_t local_stream(int num) { return local_streams_[num]; } - gpuStream_t remote_stream() { return remote_stream_; } + gpuStream_t remote_stream(int num) { return remote_streams_[num]; } gpuStream_t comm_stream(int num) { return comm_streams_[num]; } int dev_id_; int index_; std::vector dev_ids_; - gpuStream_t remote_stream_; + std::vector remote_streams_; std::vector local_streams_; std::vector comm_streams_; }; @@ -57,7 +57,7 @@ class HeterPsResource { int get_index_by_devid(int devid); int dev_id(int num); 
gpuStream_t local_stream(int gpu_num, int stream_num); - gpuStream_t remote_stream(int gpu_num); + gpuStream_t remote_stream(int gpu_num, int stream_num); gpuStream_t comm_stream(int gpu_num, int stream_num); std::vector> resources_; diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu index 5ff41d818012e2..6519a514ff3b69 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu @@ -121,7 +121,7 @@ void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place, cudaMemcpy(gpu_values, values.data(), values.size() * sizeof(float*), cudaMemcpyHostToDevice); - PullCopy<<<(total_length + 512 - 1) / 512, 512, 0, stream>>>( + PullCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>( gpu_values, total_values_gpu, gpu_len, hidden_size, slot_num, total_length, gpu_keys); cudaStreamSynchronize(stream); @@ -135,7 +135,7 @@ void PSGPUWrapper::CopyKeys(const paddle::platform::Place& place, platform::DeviceContextPool::Instance().Get( BOOST_GET_CONST(platform::CUDAPlace, place))) ->stream(); - CopyKeysKernel<<<(total_len + 512 - 1) / 512, 512, 0, stream>>>( + CopyKeysKernel<<<(total_len + 1024 - 1) / 1024, 1024, 0, stream>>>( origin_keys, total_keys, gpu_len, slot_num, total_len); cudaStreamSynchronize(stream); } @@ -173,7 +173,7 @@ void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place, cudaMemcpy(d_slot_vector, slot_vector_.data(), slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice); - PushCopy<<<(total_length + 512 - 1) / 512, 512, 0, stream>>>( + PushCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>( total_grad_values_gpu, gpu_values, gpu_len, hidden_size, slot_lengths.size(), total_length, batch_size, d_slot_vector); cudaStreamSynchronize(stream); diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index b7e8bbb3694922..5b5d697268649d 100644 --- 
a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -34,6 +34,7 @@ limitations under the License. */ #include "paddle/fluid/framework/fleet/heter_context.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h" +#include "paddle/fluid/distributed/thirdparty/round_robin.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/variable_helper.h" From fb6fbef6c71a832f0999d465c412817643711820 Mon Sep 17 00:00:00 2001 From: danleifeng Date: Wed, 21 Jul 2021 14:13:25 +0800 Subject: [PATCH 2/4] psgpu:add cuda remote_streams; test=develop --- paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 02714978526bc5..535211fe2acd7b 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -528,10 +528,10 @@ void HeterComm::pull_sparse(int num, tables_[i]->rwlock_->RDLock(); tables_[i]->get(reinterpret_cast(node.key_storage), reinterpret_cast(node.val_storage), - h_right[i] - h_left[i] + 1, resource_->remote_stream(num, i)); + h_right[i] - h_left[i] + 1, resource_->remote_stream(i, num)); } for (int i = 0; i < total_gpu; ++i) { - cudaStreamSynchronize(resource_->remote_stream(num, i)); + cudaStreamSynchronize(resource_->remote_stream(i, num)); tables_[i]->rwlock_->UNLock(); } @@ -627,10 +627,10 @@ void HeterComm::push_sparse(int gpu_num, tables_[i]->update(reinterpret_cast(node.key_storage), reinterpret_cast(node.val_storage), h_right[i] - h_left[i] + 1, sgd, - resource_->remote_stream(gpu_num, i)); + resource_->remote_stream(i, gpu_num)); } for (int i = 0; i < total_gpu; ++i) { - 
cudaStreamSynchronize(resource_->remote_stream(gpu_num, i)); + cudaStreamSynchronize(resource_->remote_stream(i, gpu_num)); tables_[i]->rwlock_->UNLock(); } } From cb052da6098139e31e46e7c6ffabf55c3f006415 Mon Sep 17 00:00:00 2001 From: danleifeng Date: Wed, 21 Jul 2021 14:18:59 +0800 Subject: [PATCH 3/4] psgpu:add cuda remote_streams; test=develop --- paddle/fluid/framework/fleet/ps_gpu_wrapper.h | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index 5b5d697268649d..b7e8bbb3694922 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -34,7 +34,6 @@ limitations under the License. */ #include "paddle/fluid/framework/fleet/heter_context.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h" -#include "paddle/fluid/distributed/thirdparty/round_robin.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/variable_helper.h" From f655ecabd58a1f386e09dba0eb1aa8ff816f6b28 Mon Sep 17 00:00:00 2001 From: danleifeng Date: Wed, 21 Jul 2021 20:29:18 +0800 Subject: [PATCH 4/4] fix typo;test=develop --- paddle/fluid/framework/fleet/heter_ps/hashtable.h | 8 ++++---- paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable.h b/paddle/fluid/framework/fleet/heter_ps/hashtable.h index 61292660acaedb..646a2e97d319fb 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable.h +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable.h @@ -23,9 +23,9 @@ limitations under the License. 
*/ #ifdef PADDLE_WITH_PSCORE #include "paddle/fluid/distributed/table/depends/large_scale_kv.h" #endif -#include "thrust/pair.h" #include "paddle/fluid/framework/rw_lock.h" -//#include "cudf/concurrent_unordered_map.cuh.h" +#include "thrust/pair.h" +// #include "cudf/concurrent_unordered_map.cuh.h" #include "paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h" #ifdef PADDLE_WITH_HETERPS #include "paddle/fluid/platform/type_defs.h" @@ -64,13 +64,13 @@ class HashTable { int size() { return container_->size(); } - std::unique_ptr rwlock_{nullptr}; + std::unique_ptr rwlock_{nullptr}; + private: TableContainer* container_; int BLOCK_SIZE_{256}; float LOAD_FACTOR{0.75f}; size_t capacity_; - }; } // end namespace framework } // end namespace paddle diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 535211fe2acd7b..d199a39162ba6e 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -528,7 +528,8 @@ void HeterComm::pull_sparse(int num, tables_[i]->rwlock_->RDLock(); tables_[i]->get(reinterpret_cast(node.key_storage), reinterpret_cast(node.val_storage), - h_right[i] - h_left[i] + 1, resource_->remote_stream(i, num)); + h_right[i] - h_left[i] + 1, + resource_->remote_stream(i, num)); } for (int i = 0; i < total_gpu; ++i) { cudaStreamSynchronize(resource_->remote_stream(i, num));