Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/framework/fleet/fleet_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License. */
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"

#include "glog/logging.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace framework {
Expand Down
68 changes: 68 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/place.h"
#include "thrust/pair.h"

Expand Down Expand Up @@ -68,7 +69,30 @@ class HeterComm {
void push_sparse(int num, KeyType* d_keys, GradType* d_grads, size_t len,
Sgd& sgd);

template <typename Sgd>
void push_sparse_multi_node(int num, KeyType* d_keys, GradType* d_grads,
size_t len, Sgd& sgd);

template <typename Sgd>
void update_one_table(int num, KeyType* d_keys, GradType* d_grads, size_t len,
Sgd& sgd);

int gather_one_node_grad(int num, KeyType* d_keys, GradType* d_grads,
int len);

int gather_multi_node_grad(int num, KeyType* d_keys, GradType* d_grads,
int len);

int log2i(int x);

void set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
const std::vector<ncclComm_t>& inter_comms,
int comm_size) {
nccl_inner_comms_ = inner_comms;
nccl_inter_comms_ = inter_comms;
node_size_ = comm_size;
}

bool need_transfer(int send_id, int receive_id) {
return ((send_id / 4 != receive_id / 4) && (send_id + 4) % 8 != receive_id);
}
Expand All @@ -94,6 +118,44 @@ class HeterComm {
std::vector<Node> nodes_;
};

struct LocalStorage {
LocalStorage() {}
void init(int size, int dev_id) {
place_ = platform::CUDAPlace(dev_id);
alloc(size, true);
}

void alloc(int size, bool force = false) {
if (force || size > all_keys_mem->size()) {
all_keys_mem.reset();
all_grads_mem.reset();
all_keys_mem = memory::AllocShared(place_, size * sizeof(KeyType));
all_grads_mem = memory::AllocShared(place_, size * sizeof(GradType));
all_keys = reinterpret_cast<KeyType*>(all_keys_mem->ptr());
all_grads = reinterpret_cast<GradType*>(all_grads_mem->ptr());
}
if (force || size > local_keys_mem->size()) {
local_keys_mem.reset();
local_grads_mem.reset();
local_keys_mem = memory::AllocShared(place_, size * sizeof(KeyType));
local_grads_mem = memory::AllocShared(place_, size * sizeof(GradType));
local_keys = reinterpret_cast<KeyType*>(local_keys_mem->ptr());
local_grads = reinterpret_cast<GradType*>(local_grads_mem->ptr());
}
}

platform::CUDAPlace place_;
std::shared_ptr<memory::Allocation> all_keys_mem;
std::shared_ptr<memory::Allocation> all_grads_mem;
KeyType* all_keys;
GradType* all_grads;

std::shared_ptr<memory::Allocation> local_keys_mem;
std::shared_ptr<memory::Allocation> local_grads_mem;
KeyType* local_keys;
GradType* local_grads;
};

void init_path();
void create_storage(
int start_index, int end_index, int keylen, int vallen,
Expand All @@ -111,6 +173,12 @@ class HeterComm {
CustomGradMerger merger_;
int topo_aware_{1};
std::vector<std::vector<Path>> path_;
std::vector<LocalStorage> storage_;
int feanum_{1800 * 2048};
int multi_node_{1};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

写成可配置的形式

std::vector<ncclComm_t> nccl_inner_comms_;
std::vector<ncclComm_t> nccl_inter_comms_;
int node_size_;
};

} // end namespace framework
Expand Down
184 changes: 184 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,14 @@ template <typename KeyType, typename ValType, typename GradType>
HeterComm<KeyType, ValType, GradType>::HeterComm(
size_t capacity, std::shared_ptr<HeterPsResource> resource) {
resource_ = resource;
storage_.resize(resource_->total_gpu());
for (int i = 0; i < resource_->total_gpu(); ++i) {
platform::CUDADeviceGuard guard(resource_->dev_id(i));
auto table = new Table(capacity / load_factor_);
tables_.push_back(table);
if (multi_node_) {
storage_[i].init(feanum_, resource_->dev_id(i));
}
}
init_path();
}
Expand Down Expand Up @@ -595,6 +599,186 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
}
}

template <typename KeyType, typename ValType, typename GradType>
template <typename Sgd>
void HeterComm<KeyType, ValType, GradType>::update_one_table(
int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len, Sgd& sgd) {
if (len == 0) {
return;
}

int dev_id = resource_->dev_id(gpu_num);
platform::CUDADeviceGuard guard(dev_id);
tables_[gpu_num]->update(d_keys, d_grads, len, sgd,
resource_->remote_stream(gpu_num));
cudaStreamSynchronize(resource_->remote_stream(gpu_num));
}

template <typename KeyType, typename ValType, typename GradType>
template <typename Sgd>
void HeterComm<KeyType, ValType, GradType>::push_sparse_multi_node(
int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len, Sgd& sgd) {
if (len == 0) {
return;
}

int uniq_len = len;
merge_grad(gpu_num, d_keys, d_grads, len, uniq_len);

uniq_len = gather_one_node_grad(gpu_num, d_keys, d_grads, uniq_len);

uniq_len = gather_multi_node_grad(gpu_num, storage_[gpu_num].local_keys,
storage_[gpu_num].local_grads, uniq_len);

update_one_table(gpu_num, storage_[gpu_num].local_keys,
storage_[gpu_num].local_grads, uniq_len, sgd);
}

template <typename KeyType, typename ValType, typename GradType>
int HeterComm<KeyType, ValType, GradType>::gather_one_node_grad(
int gpu_num, KeyType* d_keys, GradType* d_grads, int len) {
int total_gpu = resource_->total_gpu();
int dev_id = resource_->dev_id(gpu_num);
auto& storage = storage_[gpu_num];
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
platform::CUDADeviceGuard guard(dev_id);
auto stream = resource_->local_stream(gpu_num, 0);
int max_size = 0;

ncclComm_t nccl_inner_comm = nccl_inner_comms_[gpu_num];
// alloc for size
int h_node_len[total_gpu];
auto d_node_len_mem = memory::AllocShared(place, total_gpu * sizeof(int));
int* d_node_len = reinterpret_cast<int*>(d_node_len_mem->ptr());
h_node_len[gpu_num] = len;

cudaMemcpy(d_node_len + gpu_num, h_node_len + gpu_num, sizeof(int),
cudaMemcpyHostToDevice);

// allgather grad len
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
(const void*)(d_node_len + gpu_num), (void*)d_node_len, 1, ncclInt,
nccl_inner_comm, stream));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
cudaMemcpy(h_node_len, d_node_len, sizeof(int) * total_gpu,
cudaMemcpyDeviceToHost);

for (int i = 0; i < total_gpu; ++i) {
if (h_node_len[i] > max_size) {
max_size = h_node_len[i];
}
}
storage.alloc(max_size * total_gpu);

// allgather keys and grads
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
d_keys, storage.all_keys, max_size, ncclUint64, nccl_inner_comm, stream));

PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
d_grads, storage.all_grads, max_size * sizeof(GradType), ncclUint8,
nccl_inner_comm, stream));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));

int h_left[total_gpu];
int h_right[total_gpu];
auto d_left = memory::AllocShared(place, total_gpu * sizeof(int));
auto d_right = memory::AllocShared(place, total_gpu * sizeof(int));
int* d_left_ptr = reinterpret_cast<int*>(d_left->ptr());
int* d_right_ptr = reinterpret_cast<int*>(d_right->ptr());

int merge_num = 0;
for (int i = 0; i < total_gpu; ++i) {
int index = i * max_size;
auto d_idx = memory::AllocShared(place, h_node_len[i] * sizeof(int));
int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());

cudaMemset(d_left_ptr, -1, total_gpu * sizeof(int));
cudaMemset(d_right_ptr, -1, total_gpu * sizeof(int));

split_input_to_shard(storage.all_keys + index, d_idx_ptr, h_node_len[i],
d_left_ptr, d_right_ptr, gpu_num);
cudaMemcpy(h_left, d_left_ptr, total_gpu * sizeof(int),
cudaMemcpyDeviceToHost);
cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int),
cudaMemcpyDeviceToHost);

int grid_size = (h_node_len[i] - 1) / block_size_ + 1;
fill_shard_grads<<<grid_size, block_size_, 0, stream>>>(
storage.local_keys + merge_num, storage.all_keys + index,
storage.local_grads + merge_num, storage.all_grads + index,
d_idx_ptr + h_left[gpu_num], h_right[gpu_num] - h_left[gpu_num] + 1);
merge_num = merge_num + h_right[gpu_num] - h_left[gpu_num] + 1;
}

int ret = merge_num;
merge_grad(gpu_num, storage.local_keys, storage.local_grads, merge_num, ret);
return ret;
}

template <typename KeyType, typename ValType, typename GradType>
int HeterComm<KeyType, ValType, GradType>::gather_multi_node_grad(
int gpu_num, KeyType* d_keys, GradType* d_grads, int len) {
int dev_id = resource_->dev_id(gpu_num);
auto& storage = storage_[gpu_num];
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
platform::CUDADeviceGuard guard(dev_id);
auto stream = resource_->local_stream(gpu_num, 0);
int max_size = 0;
ncclComm_t nccl_inter_comm = nccl_inter_comms_[gpu_num];
// alloc for size
int h_node_len[node_size_];
auto d_node_len_mem = memory::AllocShared(place, node_size_ * sizeof(int));
int* d_node_len = reinterpret_cast<int*>(d_node_len_mem->ptr());
h_node_len[0] = len;

cudaMemcpy(d_node_len, h_node_len, sizeof(int), cudaMemcpyHostToDevice);

// allgather grad len
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
d_node_len, d_node_len, 1, ncclInt, nccl_inter_comm, stream));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
cudaMemcpy(h_node_len, d_node_len, sizeof(int) * node_size_,
cudaMemcpyDeviceToHost);

for (int i = 0; i < node_size_; ++i) {
if (h_node_len[i] > max_size) {
max_size = h_node_len[i];
}
}
storage.alloc(max_size * node_size_);

// allgather keys and grads
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
d_keys, storage.all_keys, max_size, ncclUint64, nccl_inter_comm, stream));

PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
d_grads, storage.all_grads, max_size * sizeof(GradType), ncclUint8,
nccl_inter_comm, stream));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));

int merge_num = 0;
for (int i = 0; i < node_size_; ++i) {
int index = i * max_size;
cudaMemcpyAsync(storage.local_keys + merge_num, storage.all_keys + index,
h_node_len[i], cudaMemcpyDefault, stream);
cudaMemcpyAsync(storage.local_grads + merge_num, storage.all_grads + index,
h_node_len[i], cudaMemcpyDefault, stream);
merge_num += h_node_len[i];
}

int ret = merge_num;
merge_grad(gpu_num, storage.local_keys, storage.local_grads, merge_num, ret);
return ret;
}

template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::end_pass() {
int total_gpu = resource_->total_gpu();
Expand Down
9 changes: 8 additions & 1 deletion paddle/fluid/framework/fleet/heter_ps/heter_ps.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,14 @@ void HeterPs::show_one_table(int gpu_num) { comm_->show_one_table(gpu_num); }

void HeterPs::push_sparse(int num, FeatureKey* d_keys,
FeaturePushValue* d_grads, size_t len) {
comm_->push_sparse(num, d_keys, d_grads, len, opt_);
// comm_->push_sparse(num, d_keys, d_grads, len, opt_);
comm_->push_sparse_multi_node(num, d_keys, d_grads, len, opt_);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需要加入单机多机的判断,走push_sparse 或 push_sparse_multi_node

}

void HeterPs::set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
const std::vector<ncclComm_t>& inter_comms,
int comm_size) {
comm_->set_nccl_comm_and_size(inner_comms, inter_comms, comm_size);
}

} // end namespace framework
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/heter_ps.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class HeterPs : public HeterPsBase {
size_t len) override;
virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals,
size_t len, size_t chunk_size, int stream_num) override;
virtual void set_nccl_comm_and_size(
const std::vector<ncclComm_t>& inner_comms,
const std::vector<ncclComm_t>& inter_comms, int comm_size) override;
virtual void end_pass() override;
virtual int get_index_by_devid(int devid) override;
virtual void show_one_table(int gpu_num) override;
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class HeterPsBase {
virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals,
size_t len, size_t chunk_size, int stream_num) = 0;
virtual int get_index_by_devid(int devid) = 0;
virtual void set_nccl_comm_and_size(
const std::vector<ncclComm_t>& inner_comms,
const std::vector<ncclComm_t>& inter_comms, int comm_size) = 0;
virtual void end_pass() = 0;
virtual void show_one_table(int gpu_num) = 0;
virtual void push_sparse(int num, FeatureKey* d_keys,
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/fleet/heter_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ limitations under the License. */

#include "paddle/fluid/framework/fleet/heter_wrapper.h"
#ifdef PADDLE_WITH_PSLIB
#include "paddle/fluid/framework/device_worker.h"

namespace paddle {
namespace framework {
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim) {
}
std::vector<std::thread> threads(device_num);
HeterPs_ = HeterPsBase::get_instance(size_max, resource_);
HeterPs_->set_nccl_comm_and_size(inner_comms_, inter_comms_, node_size_);
auto build_func = [this, &gpu_task, &feature_keys_count](int i) {
std::cout << "building table: " << i << std::endl;
this->HeterPs_->build_ps(i, gpu_task->device_keys_[i].data(),
Expand Down
Loading