Skip to content

Commit 7cfe661

Browse files
authored
Merge pull request #26 from DesmonDay/gpu_graph_engine2
Add actual neighbor sample result
2 parents acb8ac0 + 7762561 commit 7cfe661

File tree

2 files changed

+47
-2
lines changed

2 files changed

+47
-2
lines changed

paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,12 @@ struct NeighborSampleQuery {
139139
};
140140
struct NeighborSampleResult {
141141
int64_t *val;
142+
int64_t *actual_val;
142143
int *actual_sample_size, sample_size, key_size;
143144
std::shared_ptr<memory::Allocation> val_mem, actual_sample_size_mem;
145+
std::shared_ptr<memory::Allocation> actual_val_mem;
144146
int64_t *get_val() { return val; }
147+
int64_t *get_actual_val() { return actual_val; }
145148
int *get_actual_sample_size() { return actual_sample_size; }
146149
int get_sample_size() { return sample_size; }
147150
int get_key_size() { return key_size; }
@@ -165,18 +168,30 @@ struct NeighborSampleResult {
165168
int *ac_size = new int[key_size];
166169
cudaMemcpy(ac_size, actual_sample_size, key_size * sizeof(int),
167170
cudaMemcpyDeviceToHost); // 3, 1, 3
171+
int total_sample_size = 0;
172+
for (int i = 0; i < key_size; i++) {
173+
total_sample_size += ac_size[i];
174+
}
175+
int64_t *res2 = new int64_t[total_sample_size];
176+
cudaMemcpy(res2, actual_val, total_sample_size * sizeof(int64_t),
177+
cudaMemcpyDeviceToHost);
168178

179+
int start = 0;
169180
for (int i = 0; i < key_size; i++) {
170181
VLOG(0) << "actual sample size for " << i << "th key is " << ac_size[i];
171182
VLOG(0) << "sampled neighbors are ";
172-
std::string neighbor;
183+
std::string neighbor, neighbor2;
173184
for (int j = 0; j < ac_size[i]; j++) {
174185
if (neighbor.size() > 0) neighbor += ";";
186+
if (neighbor2.size() > 0) neighbor2 += ";";
175187
neighbor += std::to_string(res[i * sample_size + j]);
188+
neighbor2 += std::to_string(res2[start + j]);
176189
}
177-
VLOG(0) << neighbor;
190+
VLOG(0) << neighbor << " " << neighbor2;
191+
start += ac_size[i];
178192
}
179193
delete[] res;
194+
delete[] res2;
180195
delete[] ac_size;
181196
VLOG(0) << " ------------------";
182197
}

paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
// limitations under the License.
1414

1515
#include <thrust/device_vector.h>
16+
#include <thrust/reduce.h>
17+
#include <thrust/scan.h>
1618
#include <functional>
1719
#pragma once
1820
#ifdef PADDLE_WITH_HETERPS
@@ -374,6 +376,18 @@ __global__ void fill_dvalues(int64_t* d_shard_vals, int64_t* d_vals,
374376
}
375377
}
376378

379+
// Compacts the per-key sample rows of `vals` into the dense buffer
// `actual_vals`.
//
// Each thread handles one key, identified by its flat 1D global index
// (blockIdx.x * blockDim.x + threadIdx.x); threads whose index is not below
// `len` do nothing. For key k the first actual_sample_size[k] entries of the
// row starting at vals[sample_size * k] are copied to actual_vals, beginning
// at offset cumsum_actual_sample_size[k].
//
// Callers are expected to pass in cumsum_actual_sample_size the starting
// offset of every key inside actual_vals (e.g. an exclusive prefix sum of
// actual_sample_size) and to size actual_vals accordingly; the kernel itself
// performs no bounds checking on the destination.
__global__ void fill_actual_vals(int64_t* vals, int64_t* actual_vals,
                                 int* actual_sample_size,
                                 int* cumsum_actual_sample_size,
                                 int sample_size, int len) {
  const size_t key = blockIdx.x * blockDim.x + threadIdx.x;
  // Tail threads of the last block have no key to process.
  if (!(key < len)) {
    return;
  }

  // Read the per-key metadata once instead of on every loop iteration.
  const int count = actual_sample_size[key];
  const int dst_begin = cumsum_actual_sample_size[key];
  const size_t src_begin = sample_size * key;

  int k = 0;
  while (k < count) {
    actual_vals[dst_begin + k] = vals[src_begin + k];
    ++k;
  }
}
390+
377391
__global__ void node_query_example(GpuPsCommGraph graph, int start, int size,
378392
int64_t* res) {
379393
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
@@ -846,6 +860,22 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
846860
fill_dvalues<<<grid_size, block_size_, 0, stream>>>(
847861
d_shard_vals_ptr, val, d_shard_actual_sample_size_ptr, actual_sample_size,
848862
d_idx_ptr, sample_size, len);
863+
864+
thrust::device_ptr<int> t_actual_sample_size(actual_sample_size);
865+
int total_sample_size =
866+
thrust::reduce(t_actual_sample_size, t_actual_sample_size + len);
867+
result.actual_val_mem =
868+
memory::AllocShared(place, total_sample_size * sizeof(int64_t));
869+
result.actual_val = (int64_t*)(result.actual_val_mem)->ptr();
870+
871+
thrust::device_vector<int> cumsum_actual_sample_size(len);
872+
thrust::exclusive_scan(t_actual_sample_size, t_actual_sample_size + len,
873+
cumsum_actual_sample_size.begin(), 0);
874+
fill_actual_vals<<<grid_size, block_size_, 0, stream>>>(
875+
val, result.actual_val, actual_sample_size,
876+
thrust::raw_pointer_cast(cumsum_actual_sample_size.data()), sample_size,
877+
len);
878+
849879
for (int i = 0; i < total_gpu; ++i) {
850880
int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
851881
if (shard_len == 0) {

0 commit comments

Comments
 (0)