|
13 | 13 | // limitations under the License. |
14 | 14 |
|
15 | 15 | #include <thrust/device_vector.h> |
| 16 | +#include <thrust/reduce.h> |
| 17 | +#include <thrust/scan.h> |
16 | 18 | #include <functional> |
17 | 19 | #pragma once |
18 | 20 | #ifdef PADDLE_WITH_HETERPS |
@@ -374,6 +376,18 @@ __global__ void fill_dvalues(int64_t* d_shard_vals, int64_t* d_vals, |
374 | 376 | } |
375 | 377 | } |
376 | 378 |
|
// Packs the used prefix of every fixed-width row of `vals` into the dense
// array `actual_vals`.
//
// `vals` is read as `len` rows of `sample_size` slots each. For row i only the
// first `actual_sample_size[i]` slots are copied; they are written
// contiguously into `actual_vals` starting at element offset
// `cumsum_actual_sample_size[i]`. For the output to be gap-free and
// non-overlapping, the offsets are expected to be the exclusive prefix sum of
// `actual_sample_size`.
//
// Parameters:
//   vals                       source rows, `len * sample_size` elements.
//   actual_vals                destination; must hold at least the sum of all
//                              `actual_sample_size` entries.
//   actual_sample_size         per-row count of slots to copy, `len` entries.
//   cumsum_actual_sample_size  per-row start offset into `actual_vals`,
//                              `len` entries.
//   sample_size                row stride of `vals`, in elements.
//   len                        number of rows; values <= 0 make the kernel a
//                              no-op.
//
// Launch layout: 1-D grid of 1-D blocks. Rows are visited with a stride equal
// to the total number of launched threads, so every row is processed for any
// grid size (one row per thread when the grid covers `len`). No shared memory
// is used. All pointers are dereferenced on the device.
__global__ void fill_actual_vals(int64_t* vals, int64_t* actual_vals,
                                 int* actual_sample_size,
                                 int* cumsum_actual_sample_size,
                                 int sample_size, int len) {
  // Convert the row count once, with an explicit sign check: comparing a
  // size_t index directly against a negative int would turn the bound into a
  // huge unsigned value and let every thread past the guard.
  const size_t num_rows = len > 0 ? static_cast<size_t>(len) : 0;

  // Widen before multiplying so the index math is done in 64 bits.
  const size_t first_row =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t row_stride = static_cast<size_t>(gridDim.x) * blockDim.x;

  for (size_t row = first_row; row < num_rows; row += row_stride) {
    // Read the per-row metadata once instead of on every inner iteration.
    const int count = actual_sample_size[row];
    const int64_t* src = vals + static_cast<size_t>(sample_size) * row;
    int64_t* dst = actual_vals + cumsum_actual_sample_size[row];

    for (int j = 0; j < count; ++j) {
      dst[j] = src[j];
    }
  }
}
| 390 | + |
377 | 391 | __global__ void node_query_example(GpuPsCommGraph graph, int start, int size, |
378 | 392 | int64_t* res) { |
379 | 393 | const size_t i = blockIdx.x * blockDim.x + threadIdx.x; |
@@ -846,6 +860,22 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( |
846 | 860 | fill_dvalues<<<grid_size, block_size_, 0, stream>>>( |
847 | 861 | d_shard_vals_ptr, val, d_shard_actual_sample_size_ptr, actual_sample_size, |
848 | 862 | d_idx_ptr, sample_size, len); |
| 863 | + |
| 864 | + thrust::device_ptr<int> t_actual_sample_size(actual_sample_size); |
| 865 | + int total_sample_size = |
| 866 | + thrust::reduce(t_actual_sample_size, t_actual_sample_size + len); |
| 867 | + result.actual_val_mem = |
| 868 | + memory::AllocShared(place, total_sample_size * sizeof(int64_t)); |
| 869 | + result.actual_val = (int64_t*)(result.actual_val_mem)->ptr(); |
| 870 | + |
| 871 | + thrust::device_vector<int> cumsum_actual_sample_size(len); |
| 872 | + thrust::exclusive_scan(t_actual_sample_size, t_actual_sample_size + len, |
| 873 | + cumsum_actual_sample_size.begin(), 0); |
| 874 | + fill_actual_vals<<<grid_size, block_size_, 0, stream>>>( |
| 875 | + val, result.actual_val, actual_sample_size, |
| 876 | + thrust::raw_pointer_cast(cumsum_actual_sample_size.data()), sample_size, |
| 877 | + len); |
| 878 | + |
849 | 879 | for (int i = 0; i < total_gpu; ++i) { |
850 | 880 | int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1; |
851 | 881 | if (shard_len == 0) { |
|
0 commit comments