@@ -17,8 +17,10 @@ limitations under the License. */
 #endif
 #if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
 
+#include "cub/cub.cuh"
 #include "paddle/fluid/framework/data_feed.h"
-
+#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
+#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h"
 namespace paddle {
 namespace framework {
 
@@ -144,6 +146,89 @@ void SlotRecordInMemoryDataFeed::CopyForTensor(
   cudaStreamSynchronize(stream);
 }
 
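+// Writes one (src_key, neighbor) id pair per sampled edge into id_tensor:
+// prefix_sum[idx] is the starting pair offset of key idx, so sampled
+// neighbor k of key idx lands at pair index prefix_sum[idx] + k.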
+__global__ void GraphFillIdKernel(int64_t *id_tensor, int *actual_sample_size,
+                                  int64_t *prefix_sum, int64_t *device_key,
+                                  int64_t *neighbors, int sample_size,
+                                  int len) {
+  CUDA_KERNEL_LOOP(idx, len) {
+    for (int k = 0; k < actual_sample_size[idx]; k++) {
+      int offset = (prefix_sum[idx] + k) * 2;
+      id_tensor[offset] = device_key[idx];
+      id_tensor[offset + 1] = neighbors[idx * sample_size + k];
+    }
+  }
+}
+
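+// Fills a tensor with the constant 1; used for the show/clk feeds below.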
+__global__ void GraphFillCVMKernel(int64_t *tensor, int len) {
+  CUDA_KERNEL_LOOP(idx, len) { tensor[idx] = 1; }
+}
+
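+// Turns one neighbor-sample result into a batch: an inclusive prefix sum
+// over actual_sample_size yields the per-key pair offsets, then the feed
+// tensors (id/show/clk) are sized and filled on stream_.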
+void GraphDataGenerator::FeedGraphIns(size_t cursor, int len,
+                                      NeighborSampleResult &sample_res) {
+  size_t temp_storage_bytes = 0;
+  int *d_actual_sample_size = sample_res.actual_sample_size;
+  int64_t *d_neighbors = sample_res.val;
+  int64_t *d_prefix_sum = reinterpret_cast<int64_t *>(d_prefix_sum_->ptr());
+  // GraphFillIdKernel reads prefix_sum[0] as the offset of the first key,
+  // and the scan below only writes elements 1..len, so zero it explicitly
+  // in case the buffer is not pre-zeroed at allocation.
+  cudaMemsetAsync(d_prefix_sum, 0, sizeof(int64_t), stream_);
+  // cub's two-phase pattern: the first call with a NULL workspace only
+  // reports the required temp_storage_bytes, the second performs the scan.
+  CUDA_CHECK(cub::DeviceScan::InclusiveSum(NULL, temp_storage_bytes,
+                                           d_actual_sample_size,
+                                           d_prefix_sum + 1, len, stream_));
+  auto d_temp_storage = memory::Alloc(place_, temp_storage_bytes);
+
+  CUDA_CHECK(cub::DeviceScan::InclusiveSum(
+      d_temp_storage->ptr(), temp_storage_bytes, d_actual_sample_size,
+      d_prefix_sum + 1, len, stream_));
+  int64_t total_ins = 0;
+  cudaMemcpyAsync(&total_ins, d_prefix_sum + len, sizeof(int64_t),
+                  cudaMemcpyDeviceToHost, stream_);
+  // Synchronize after the copy: total_ins is consumed on the host below.
+  cudaStreamSynchronize(stream_);
+
+  // Each sampled edge contributes a (src, dst) pair of ids.
+  total_ins *= 2;
+  id_tensor_ptr_ =
+      feed_vec_[0]->mutable_data<int64_t>({total_ins, 1}, this->place_);
+  show_tensor_ptr_ =
+      feed_vec_[1]->mutable_data<int64_t>({total_ins}, this->place_);
+  clk_tensor_ptr_ =
+      feed_vec_[2]->mutable_data<int64_t>({total_ins}, this->place_);
+
+  GraphFillIdKernel<<<GET_BLOCKS(len), CUDA_NUM_THREADS, 0, stream_>>>(
+      id_tensor_ptr_, d_actual_sample_size, d_prefix_sum,
+      device_keys_ + cursor, d_neighbors, walk_degree_, len);
+  // CUDA_KERNEL_LOOP is grid-stride, so a grid sized for len threads still
+  // covers all total_ins elements of the show/clk tensors.
+  GraphFillCVMKernel<<<GET_BLOCKS(len), CUDA_NUM_THREADS, 0, stream_>>>(
+      show_tensor_ptr_, total_ins);
+  GraphFillCVMKernel<<<GET_BLOCKS(len), CUDA_NUM_THREADS, 0, stream_>>>(
+      clk_tensor_ptr_, total_ins);
+
+  // A single-sequence LoD spanning all total_ins instances.
+  offset_.clear();
+  offset_.push_back(0);
+  offset_.push_back(total_ins);
+  LoD lod{offset_};
+  feed_vec_[0]->set_lod(lod);
+  // feed_vec_[1]->set_lod(lod);
+  // feed_vec_[2]->set_lod(lod);
+  cudaStreamSynchronize(stream_);
+}
+
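+// Consumes the next slice of up to sample_key_size_ device keys, samples
+// walk_degree_ neighbors for each, and feeds the result as one batch.
+// Returns 1 while keys remain, 0 once device_key_size_ keys are consumed.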
+int GraphDataGenerator::GenerateBatch() {
+  // GpuPsGraphTable *g = (GpuPsGraphTable *)(gpu_graph_ptr->graph_table);
+  platform::CUDADeviceGuard guard(gpuid_);
+  auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
+  int tmp_len = cursor_ + sample_key_size_ > device_key_size_
+                    ? device_key_size_ - cursor_
+                    : sample_key_size_;
+  VLOG(3) << "device key size: " << device_key_size_
+          << " this batch: " << tmp_len << " cursor: " << cursor_
+          << " sample_key_size_: " << sample_key_size_;
+  if (tmp_len == 0) {
+    return 0;
+  }
+  auto sample_res = gpu_graph_ptr->graph_neighbor_sample(
+      gpuid_, device_keys_ + cursor_, walk_degree_, tmp_len);
+  FeedGraphIns(cursor_, tmp_len, sample_res);
+  cursor_ += tmp_len;
+  return 1;
+}
+
 }  // namespace framework
 }  // namespace paddle
 #endif