From ff5fa32b56773ab56f4719768e95f4fcd43a2255 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Fri, 29 Apr 2022 07:58:51 +0000 Subject: [PATCH 1/6] add actual_val --- .../framework/fleet/heter_ps/gpu_graph_node.h | 13 ++++++++-- .../fleet/heter_ps/graph_gpu_ps_table_inl.h | 25 +++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h index a8fde3f36bc6d8..262bddc0676c11 100644 --- a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h +++ b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h @@ -138,10 +138,12 @@ struct NeighborSampleQuery { } }; struct NeighborSampleResult { - int64_t *val; + int64_t *val, *actual_val; int *actual_sample_size, sample_size, key_size; std::shared_ptr val_mem, actual_sample_size_mem; + std::shared_ptr actual_val_mem; int64_t *get_val() { return val; } + int64_t *get_actual_val() { return actual_val; } int *get_actual_sample_size() { return actual_sample_size; } int get_sample_size() { return sample_size; } int get_key_size() { return key_size; } @@ -160,23 +162,30 @@ struct NeighborSampleResult { void display() { VLOG(0) << "in node sample result display ------------------"; int64_t *res = new int64_t[sample_size * key_size]; + int64_t *res2 = new int64_t[7]; cudaMemcpy(res, val, sample_size * key_size * sizeof(int64_t), cudaMemcpyDeviceToHost); + cudaMemcpy(res2, actual_val, 7 * sizeof(int64_t), cudaMemcpyDeviceToHost); int *ac_size = new int[key_size]; cudaMemcpy(ac_size, actual_sample_size, key_size * sizeof(int), cudaMemcpyDeviceToHost); // 3, 1, 3 + int start = 0; for (int i = 0; i < key_size; i++) { VLOG(0) << "actual sample size for " << i << "th key is " << ac_size[i]; VLOG(0) << "sampled neighbors are "; - std::string neighbor; + std::string neighbor, neighbor2; for (int j = 0; j < ac_size[i]; j++) { if (neighbor.size() > 0) neighbor += ";"; + if (neighbor2.size() > 0) neighbor2 += ";"; neighbor += std::to_string(res[i * sample_size + j]); + neighbor += std::to_string(res2[start + j]); } VLOG(0) << neighbor; + start += ac_size[i]; } delete[] res; + delete[] res2; delete[] ac_size; VLOG(0) << " ------------------"; } diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h index 1c59f318517d0d..faf5485b736f42 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h @@ -13,6 +13,8 @@ // limitations under the License. #include +#include +#include #include #pragma once #ifdef PADDLE_WITH_HETERPS @@ -374,6 +376,18 @@ __global__ void fill_dvalues(int64_t* d_shard_vals, int64_t* d_vals, } } +__global__ void fill_actual_vals(int64_t* vals, int64_t* actual_vals, + int* actual_sample_size, + int* cumsum_actual_sample_size, + int sample_size, int len) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < len) { + for (int j = 0; j < actual_sample_size[i]; j++) { + actual_vals[cumsum_actual_sample_size[i] + j] = vals[sample_size * i + j]; + } + } +} + __global__ void node_query_example(GpuPsCommGraph graph, int start, int size, int64_t* res) { const size_t i = blockIdx.x * blockDim.x + threadIdx.x; @@ -846,6 +860,17 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( fill_dvalues<<>>( d_shard_vals_ptr, val, d_shard_actual_sample_size_ptr, actual_sample_size, d_idx_ptr, sample_size, len); + + thrust::device_ptr t_actual_sample_size(actual_sample_size); + result->actual_val_mem = + memory::AllocShared(place, total_sample_size * sizeof(int64_t)); + result->actual_val = (int64_t*)(result->actual_val_mem)->ptr(); + + fill_actual_vals<<>>( + val, result->actual_val, actual_sample_size, + thrust::raw_pointer_cast(cumsum_actual_sample_size.data()), sample_size, + len); + for (int i = 0; i < total_gpu; ++i) { int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1; if (shard_len == 0) { From 795008634f4406a313ce2886dfe53b4ce9172174 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Fri, 29 Apr 2022 08:01:24 +0000 Subject: [PATCH 2/6] change vlog --- paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h index 262bddc0676c11..2e0d7e73e0ba83 100644 --- a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h +++ b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h @@ -179,9 +179,9 @@ struct NeighborSampleResult { if (neighbor.size() > 0) neighbor += ";"; if (neighbor2.size() > 0) neighbor2 += ";"; neighbor += std::to_string(res[i * sample_size + j]); - neighbor += std::to_string(res2[start + j]); + neighbor2 += std::to_string(res2[start + j]); } - VLOG(0) << neighbor; + VLOG(0) << neighbor << " " << neighbor2; start += ac_size[i]; } delete[] res; From 24cb2597f7a386c5c821a9d3349b36e7886bc2d8 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Fri, 29 Apr 2022 09:01:19 +0000 Subject: [PATCH 3/6] fix bug --- paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h | 3 ++- .../fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h index 2e0d7e73e0ba83..2d998387be9823 100644 --- a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h +++ b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h @@ -138,7 +138,8 @@ struct NeighborSampleQuery { } }; struct NeighborSampleResult { - int64_t *val, *actual_val; + int64_t *val; + int64_t *actual_val; int *actual_sample_size, sample_size, key_size; std::shared_ptr val_mem, actual_sample_size_mem; std::shared_ptr actual_val_mem; diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h index faf5485b736f42..9e5fe1b9f8af2f 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h @@ -862,10 +862,15 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( d_idx_ptr, sample_size, len); thrust::device_ptr t_actual_sample_size(actual_sample_size); + int total_sample_size = + thrust::reduce(t_actual_sample_size, t_actual_sample_size + len); result->actual_val_mem = memory::AllocShared(place, total_sample_size * sizeof(int64_t)); result->actual_val = (int64_t*)(result->actual_val_mem)->ptr(); + thrust::device_vector cumsum_actual_sample_size(len); + thrust::exclusive_scan(t_actual_sample_size, t_actual_sample_size + len, + cumsum_actual_sample_size.begin(), 0); fill_actual_vals<<>>( val, result->actual_val, actual_sample_size, thrust::raw_pointer_cast(cumsum_actual_sample_size.data()), sample_size, From 7798771b7d94bda0530f3891787b0645a7fe542b Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Fri, 29 Apr 2022 09:14:25 +0000 Subject: [PATCH 4/6] bug fix --- .../fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h index 9e5fe1b9f8af2f..7826974cbe9a70 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h @@ -864,15 +864,15 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( thrust::device_ptr t_actual_sample_size(actual_sample_size); int total_sample_size = thrust::reduce(t_actual_sample_size, t_actual_sample_size + len); - result->actual_val_mem = + result.actual_val_mem = memory::AllocShared(place, total_sample_size * sizeof(int64_t)); - result->actual_val = (int64_t*)(result->actual_val_mem)->ptr(); + result.actual_val = (int64_t*)(result->actual_val_mem)->ptr(); thrust::device_vector cumsum_actual_sample_size(len); thrust::exclusive_scan(t_actual_sample_size, t_actual_sample_size + len, cumsum_actual_sample_size.begin(), 0); fill_actual_vals<<>>( - val, result->actual_val, actual_sample_size, + val, result.actual_val, actual_sample_size, thrust::raw_pointer_cast(cumsum_actual_sample_size.data()), sample_size, len); From 641fcac1981cefddd2b3577756101b76945e4114 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Fri, 29 Apr 2022 09:23:28 +0000 Subject: [PATCH 5/6] bug fix --- paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h index 7826974cbe9a70..d619ca65ec4613 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h @@ -866,7 +866,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( thrust::reduce(t_actual_sample_size, t_actual_sample_size + len); result.actual_val_mem = memory::AllocShared(place, total_sample_size * sizeof(int64_t)); - result.actual_val = (int64_t*)(result->actual_val_mem)->ptr(); + result.actual_val = (int64_t*)(result.actual_val_mem)->ptr(); thrust::device_vector cumsum_actual_sample_size(len); thrust::exclusive_scan(t_actual_sample_size, t_actual_sample_size + len, From 7762561bc3a7854f3158b43c9938a7133e01c485 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Fri, 29 Apr 2022 10:24:51 +0000 Subject: [PATCH 6/6] fix display test --- paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h index 2d998387be9823..f3609fd13105d0 100644 --- a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h +++ b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h @@ -163,13 +163,18 @@ struct NeighborSampleResult { void display() { VLOG(0) << "in node sample result display ------------------"; int64_t *res = new int64_t[sample_size * key_size]; - int64_t *res2 = new int64_t[7]; cudaMemcpy(res, val, sample_size * key_size * sizeof(int64_t), cudaMemcpyDeviceToHost); - cudaMemcpy(res2, actual_val, 7 * sizeof(int64_t), cudaMemcpyDeviceToHost); int *ac_size = new int[key_size]; cudaMemcpy(ac_size, actual_sample_size, key_size * sizeof(int), cudaMemcpyDeviceToHost); // 3, 1, 3 + int total_sample_size = 0; + for (int i = 0; i < key_size; i++) { + total_sample_size += ac_size[i]; + } + int64_t *res2 = new int64_t[total_sample_size]; + cudaMemcpy(res2, actual_val, total_sample_size * sizeof(int64_t), + cudaMemcpyDeviceToHost); int start = 0; for (int i = 0; i < key_size; i++) {