Commit 676a92c

Merge pull request #1 from Thunderbrook/gpugraph_0523
[GpuGraph] GraphInsGenerator
2 parents: e726960 + 3f3185e

File tree: 15 files changed (+367, -116 lines)

paddle/fluid/distributed/ps/table/common_graph_table.h

Lines changed: 1 addition & 1 deletion
@@ -566,7 +566,7 @@ class GraphTable : public Table {
   int32_t dump_edges_to_ssd(int idx);
   int32_t get_partition_num(int idx) { return partitions[idx].size(); }
   std::vector<int64_t> get_partition(int idx, int index) {
-    if (idx >= partitions.size() || index >= partitions[idx].size())
+    if (idx >= (int)partitions.size() || index >= (int)partitions[idx].size())
       return std::vector<int64_t>();
     return partitions[idx][index];
   }
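
Note: the casts above silence a signed/unsigned comparison. idx and index are int, while std::vector::size() returns an unsigned size_t, so the unmodified comparison triggers -Wsign-compare under -Wall -Wextra (and fails -Werror builds). A standalone illustration of the warning and the style of fix used here (example code, not part of the commit):

#include <cstdint>
#include <vector>

// Comparing a signed int against the unsigned value returned by size()
// draws a -Wsign-compare warning; casting the size to int mirrors the
// change applied to GraphTable::get_partition above.
bool out_of_range(int idx, const std::vector<std::vector<int64_t>>& parts) {
  // return idx >= parts.size();      // warns: different signedness
  return idx >= (int)parts.size();    // warning-free, as in the patch
}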

paddle/fluid/framework/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
@@ -321,7 +321,9 @@ if(WITH_DISTRIBUTE)
     device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
     index_sampler index_wrapper sampler index_dataset_proto
     lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper metrics lodtensor_printer feed_fetch_method
-    graph_to_program_pass variable_helper timer monitor heter_service_proto fleet heter_server brpc fleet_executor)
+    graph_to_program_pass variable_helper timer monitor
+    heter_service_proto fleet heter_server brpc fleet_executor
+    graph_gpu_wrapper)
   set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=parentheses")
   if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
     set(DISTRIBUTE_COMPILE_FLAGS

paddle/fluid/framework/data_feed.cc

Lines changed: 53 additions & 18 deletions
@@ -38,6 +38,34 @@ DLManager& global_dlmanager_pool() {
   return manager;
 }
 
+void GraphDataGenerator::AllocResource(const paddle::platform::Place& place,
+                                       std::vector<LoDTensor*> feed_vec,
+                                       std::vector<int64_t>* h_device_keys) {
+  place_ = place;
+  gpuid_ = place_.GetDeviceId();
+  VLOG(3) << "gpuid " << gpuid_;
+  stream_ = dynamic_cast<platform::CUDADeviceContext*>(
+                platform::DeviceContextPool::Instance().Get(place))
+                ->stream();
+  feed_vec_ = feed_vec;
+  h_device_keys_ = h_device_keys;
+  device_key_size_ = h_device_keys_->size();
+  d_device_keys_ =
+      memory::AllocShared(place_, device_key_size_ * sizeof(int64_t));
+  CUDA_CHECK(cudaMemcpyAsync(d_device_keys_->ptr(), h_device_keys_->data(),
+                             device_key_size_ * sizeof(int64_t),
+                             cudaMemcpyHostToDevice, stream_));
+  d_prefix_sum_ =
+      memory::AllocShared(place_, (sample_key_size_ + 1) * sizeof(int64_t));
+  int64_t* d_prefix_sum_ptr = reinterpret_cast<int64_t*>(d_prefix_sum_->ptr());
+  cudaMemsetAsync(d_prefix_sum_ptr, 0, (sample_key_size_ + 1) * sizeof(int64_t),
+                  stream_);
+  cursor_ = 0;
+  device_keys_ = reinterpret_cast<int64_t*>(d_device_keys_->ptr());
+  ;
+  cudaStreamSynchronize(stream_);
+}
+
 class BufferedLineFileReader {
   typedef std::function<bool()> SampleFunc;
   static const int MAX_FILE_BUFF_SIZE = 4 * 1024 * 1024;
@@ -2065,6 +2093,7 @@ void SlotRecordInMemoryDataFeed::Init(const DataFeedDesc& data_feed_desc) {
   } else {
     so_parser_name_.clear();
   }
+  gpu_graph_data_generator_.SetConfig(data_feed_desc);
 }
 
 void SlotRecordInMemoryDataFeed::LoadIntoMemory() {
@@ -2589,34 +2618,40 @@ bool SlotRecordInMemoryDataFeed::Start() {
 #if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
   CHECK(paddle::platform::is_gpu_place(this->place_));
   pack_ = BatchGpuPackMgr().get(this->GetPlace(), used_slots_info_);
+  gpu_graph_data_generator_.AllocResource(this->place_, feed_vec_,
+                                          h_device_keys_);
 #endif
   return true;
 }
 
 int SlotRecordInMemoryDataFeed::Next() {
 #ifdef _LINUX
   this->CheckStart();
-
-  VLOG(3) << "enable heter next: " << offset_index_
-          << " batch_offsets: " << batch_offsets_.size();
-  if (offset_index_ >= batch_offsets_.size()) {
-    VLOG(3) << "offset_index: " << offset_index_
+  if (!gpu_graph_mode_) {
+    VLOG(3) << "enable heter next: " << offset_index_
            << " batch_offsets: " << batch_offsets_.size();
-    return 0;
-  }
-  auto& batch = batch_offsets_[offset_index_++];
-  this->batch_size_ = batch.second;
-  VLOG(3) << "batch_size_=" << this->batch_size_
-          << ", thread_id=" << thread_id_;
-  if (this->batch_size_ != 0) {
-    PutToFeedVec(&records_[batch.first], this->batch_size_);
+    if (offset_index_ >= batch_offsets_.size()) {
+      VLOG(3) << "offset_index: " << offset_index_
+              << " batch_offsets: " << batch_offsets_.size();
+      return 0;
+    }
+    auto& batch = batch_offsets_[offset_index_++];
+    this->batch_size_ = batch.second;
+    VLOG(3) << "batch_size_=" << this->batch_size_
+            << ", thread_id=" << thread_id_;
+    if (this->batch_size_ != 0) {
+      PutToFeedVec(&records_[batch.first], this->batch_size_);
+    } else {
+      VLOG(3) << "finish reading for heterps, batch size zero, thread_id="
+              << thread_id_;
+    }
+    VLOG(3) << "enable heter next: " << offset_index_
+            << " batch_offsets: " << batch_offsets_.size()
+            << " baych_size: " << this->batch_size_;
   } else {
-    VLOG(3) << "finish reading for heterps, batch size zero, thread_id="
-            << thread_id_;
+    VLOG(3) << "datafeed in gpu graph mode";
+    this->batch_size_ = gpu_graph_data_generator_.GenerateBatch();
   }
-  VLOG(3) << "enable heter next: " << offset_index_
-          << " batch_offsets: " << batch_offsets_.size()
-          << " baych_size: " << this->batch_size_;
 
   return this->batch_size_;
 #else

paddle/fluid/framework/data_feed.cu

Lines changed: 86 additions & 1 deletion
@@ -17,8 +17,10 @@ limitations under the License. */
 #endif
 #if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
 
+#include "cub/cub.cuh"
 #include "paddle/fluid/framework/data_feed.h"
-
+#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
+#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h"
 namespace paddle {
 namespace framework {
 
@@ -144,6 +146,89 @@ void SlotRecordInMemoryDataFeed::CopyForTensor(
   cudaStreamSynchronize(stream);
 }
 
+__global__ void GraphFillIdKernel(int64_t *id_tensor, int *actual_sample_size,
+                                  int64_t *prefix_sum, int64_t *device_key,
+                                  int64_t *neighbors, int sample_size,
+                                  int len) {
+  CUDA_KERNEL_LOOP(idx, len) {
+    for (int k = 0; k < actual_sample_size[idx]; k++) {
+      int offset = (prefix_sum[idx] + k) * 2;
+      id_tensor[offset] = device_key[idx];
+      id_tensor[offset + 1] = neighbors[idx * sample_size + k];
+    }
+  }
+}
+
+__global__ void GraphFillCVMKernel(int64_t *tensor, int len) {
+  CUDA_KERNEL_LOOP(idx, len) { tensor[idx] = 1; }
+}
+
+void GraphDataGenerator::FeedGraphIns(size_t cursor, int len,
+                                      NeighborSampleResult &sample_res) {
+  size_t temp_storage_bytes = 0;
+  int *d_actual_sample_size = sample_res.actual_sample_size;
+  int64_t *d_neighbors = sample_res.val;
+  int64_t *d_prefix_sum = reinterpret_cast<int64_t *>(d_prefix_sum_->ptr());
+  CUDA_CHECK(cub::DeviceScan::InclusiveSum(NULL, temp_storage_bytes,
+                                           d_actual_sample_size,
+                                           d_prefix_sum + 1, len, stream_));
+  auto d_temp_storage = memory::Alloc(place_, temp_storage_bytes);
+
+  CUDA_CHECK(cub::DeviceScan::InclusiveSum(
+      d_temp_storage->ptr(), temp_storage_bytes, d_actual_sample_size,
+      d_prefix_sum + 1, len, stream_));
+  cudaStreamSynchronize(stream_);
+  int64_t total_ins = 0;
+  cudaMemcpyAsync(&total_ins, d_prefix_sum + len, sizeof(int64_t),
+                  cudaMemcpyDeviceToHost, stream_);
+
+  total_ins *= 2;
+  id_tensor_ptr_ =
+      feed_vec_[0]->mutable_data<int64_t>({total_ins, 1}, this->place_);
+  show_tensor_ptr_ =
+      feed_vec_[1]->mutable_data<int64_t>({total_ins}, this->place_);
+  clk_tensor_ptr_ =
+      feed_vec_[2]->mutable_data<int64_t>({total_ins}, this->place_);
+
+  GraphFillIdKernel<<<GET_BLOCKS(len), CUDA_NUM_THREADS, 0, stream_>>>(
+      id_tensor_ptr_, d_actual_sample_size, d_prefix_sum,
+      device_keys_ + cursor_, d_neighbors, walk_degree_, len);
+  GraphFillCVMKernel<<<GET_BLOCKS(len), CUDA_NUM_THREADS, 0, stream_>>>(
+      show_tensor_ptr_, total_ins);
+  GraphFillCVMKernel<<<GET_BLOCKS(len), CUDA_NUM_THREADS, 0, stream_>>>(
+      clk_tensor_ptr_, total_ins);
+
+  offset_.clear();
+  offset_.push_back(0);
+  offset_.push_back(total_ins);
+  LoD lod{offset_};
+  feed_vec_[0]->set_lod(lod);
+  // feed_vec_[1]->set_lod(lod);
+  // feed_vec_[2]->set_lod(lod);
+  cudaStreamSynchronize(stream_);
+}
+
+int GraphDataGenerator::GenerateBatch() {
+  // GpuPsGraphTable *g = (GpuPsGraphTable *)(gpu_graph_ptr->graph_table);
+  platform::CUDADeviceGuard guard(gpuid_);
+  auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
+  int tmp_len = cursor_ + sample_key_size_ > device_key_size_
+                    ? device_key_size_ - cursor_
+                    : sample_key_size_;
+  VLOG(3) << "device key size: " << device_key_size_
+          << " this batch: " << tmp_len << " cursor: " << cursor_
+          << " sample_key_size_: " << sample_key_size_;
+  if (tmp_len == 0) {
+    return 0;
+  }
+  int total_instance = 1;
+  auto sample_res = gpu_graph_ptr->graph_neighbor_sample(
+      gpuid_, device_keys_ + cursor_, walk_degree_, tmp_len);
+  FeedGraphIns(cursor_, tmp_len, sample_res);
+  cursor_ += tmp_len;
+  return 1;
+}
+
 }  // namespace framework
 }  // namespace paddle
 #endif
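
Note: FeedGraphIns above uses the standard two-pass CUB pattern: cub::DeviceScan::InclusiveSum is first called with a null temp-storage pointer to query the scratch size, the scratch buffer is allocated, and then the scan runs for real. Writing the result at d_prefix_sum + 1 (with element 0 pre-zeroed in AllocResource) turns the inclusive scan into per-key write offsets, and the last element is the total number of sampled (key, neighbor) pairs. A self-contained sketch of that pattern outside the Paddle codebase (illustrative only; names are not from the patch):

#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <vector>

// Two-pass CUB inclusive scan, as used by FeedGraphIns: the prefix sum of the
// per-key actual sample sizes gives each key's write offset into the flat
// (key, neighbor) id tensor; the last element is the total pair count.
int main() {
  const int len = 3;
  std::vector<int> h_sizes = {2, 0, 3};  // actual_sample_size per key
  int* d_sizes = nullptr;
  int64_t* d_prefix = nullptr;           // len + 1 entries, d_prefix[0] == 0
  cudaMalloc(&d_sizes, len * sizeof(int));
  cudaMalloc(&d_prefix, (len + 1) * sizeof(int64_t));
  cudaMemset(d_prefix, 0, (len + 1) * sizeof(int64_t));
  cudaMemcpy(d_sizes, h_sizes.data(), len * sizeof(int), cudaMemcpyHostToDevice);

  // Pass 1: query the required temporary storage size.
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceScan::InclusiveSum(d_temp, temp_bytes, d_sizes, d_prefix + 1, len);
  cudaMalloc(&d_temp, temp_bytes);
  // Pass 2: run the scan; d_prefix becomes {0, 2, 2, 5}.
  cub::DeviceScan::InclusiveSum(d_temp, temp_bytes, d_sizes, d_prefix + 1, len);

  int64_t total = 0;
  cudaMemcpy(&total, d_prefix + len, sizeof(int64_t), cudaMemcpyDeviceToHost);
  printf("total sampled pairs: %ld\n", (long)total);  // 5 for {2, 0, 3}

  cudaFree(d_temp);
  cudaFree(d_sizes);
  cudaFree(d_prefix);
  return 0;
}

GraphFillIdKernel then writes key idx at id_tensor[2 * prefix_sum[idx] + 2 * k] and its k-th sampled neighbor immediately after it, which is why total_ins is doubled before the feed tensors are resized.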

paddle/fluid/framework/data_feed.h

Lines changed: 42 additions & 0 deletions
@@ -56,6 +56,7 @@ namespace framework {
 class DataFeedDesc;
 class Scope;
 class Variable;
+class NeighborSampleResult;
 } // namespace framework
 } // namespace paddle
 
@@ -774,6 +775,38 @@ class DLManager {
   std::map<std::string, DLHandle> handle_map_;
 };
 
+class GraphDataGenerator {
+ public:
+  GraphDataGenerator() {}
+  ~GraphDataGenerator() {}
+  void SetConfig(const paddle::framework::DataFeedDesc& data_feed_desc) {
+    walk_degree_ = 1;
+    walk_len_ = 1;
+    sample_key_size_ = 8000;
+  }
+  void AllocResource(const paddle::platform::Place& place,
+                     std::vector<LoDTensor*> feed_vec,
+                     std::vector<int64_t>* h_device_keys);
+  void FeedGraphIns(size_t cursor, int len, NeighborSampleResult& sample_res);
+  int GenerateBatch();
+
+ protected:
+  int walk_degree_ = 1;
+  int walk_len_ = 1;
+  int sample_key_size_;
+  int gpuid_;
+  size_t device_key_size_;
+  size_t cursor_;
+  int64_t* device_keys_;
+  int64_t* id_tensor_ptr_;
+  int64_t* show_tensor_ptr_;
+  int64_t* clk_tensor_ptr_;
+  cudaStream_t stream_;
+  paddle::platform::Place place_;
+  std::vector<LoDTensor*> feed_vec_;
+  std::vector<int64_t>* h_device_keys_;
+  std::vector<size_t> offset_;
+  std::shared_ptr<phi::Allocation> d_prefix_sum_ = nullptr;
+  std::shared_ptr<phi::Allocation> d_device_keys_ = nullptr;
+};
+
 class DataFeed {
  public:
   DataFeed() {
@@ -836,6 +869,12 @@ class DataFeed {
   virtual void SetParseLogKey(bool parse_logkey) {}
   virtual void SetEnablePvMerge(bool enable_pv_merge) {}
   virtual void SetCurrentPhase(int current_phase) {}
+  virtual void SetDeviceKeys(std::vector<int64_t>* device_keys) {
+    h_device_keys_ = device_keys;
+  }
+  virtual void SetGpuGraphMode(int gpu_graph_mode) {
+    gpu_graph_mode_ = gpu_graph_mode;
+  }
   virtual void SetFileListMutex(std::mutex* mutex) {
     mutex_for_pick_file_ = mutex;
   }
@@ -919,6 +958,9 @@ class DataFeed {
 
   // The input type of pipe reader, 0 for one sample, 1 for one batch
   int input_type_;
+  int gpu_graph_mode_ = 0;
+  std::vector<int64_t>* h_device_keys_;
+  GraphDataGenerator gpu_graph_data_generator_;
 };
 
 // PrivateQueueDataFeed is the base virtual class for ohther DataFeeds.
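
Note: read together with the data_feed.cc and data_set.cc changes, the wiring of GraphDataGenerator appears to follow the sequence below. This is a comment-level sketch of the call order implied by the patch, not literal code from it:

// Call order implied by the patch (sketch, not literal code):
//
// 1. DatasetImpl::LoadIntoMemory (gpu_graph_mode_): hand each reader its
//    slice of graph node ids and switch it into graph mode.
reader->SetDeviceKeys(&gpu_graph_device_keys_[i]);
reader->SetGpuGraphMode(gpu_graph_mode_);
// 2. SlotRecordInMemoryDataFeed::Init: forward the proto to the generator
//    (currently hard-coded to walk_degree_ = 1, sample_key_size_ = 8000).
gpu_graph_data_generator_.SetConfig(data_feed_desc);
// 3. SlotRecordInMemoryDataFeed::Start: bind place, feed tensors, and keys,
//    and copy the keys to the GPU.
gpu_graph_data_generator_.AllocResource(place_, feed_vec_, h_device_keys_);
// 4. SlotRecordInMemoryDataFeed::Next in graph mode: sample neighbors for the
//    next slice of up to sample_key_size_ keys and fill feed_vec_[0..2]
//    (ids, show, click); returns 0 once cursor_ reaches device_key_size_.
this->batch_size_ = gpu_graph_data_generator_.GenerateBatch();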

paddle/fluid/framework/data_set.cc

Lines changed: 25 additions & 6 deletions
@@ -25,6 +25,7 @@
 
 #ifdef PADDLE_WITH_PSCORE
 #include "paddle/fluid/distributed/ps/wrapper/fleet.h"
+#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h"
 #endif
 
 #if defined _WIN32 || defined __APPLE__
@@ -417,12 +418,30 @@ void DatasetImpl<T>::LoadIntoMemory() {
   platform::Timer timeline;
   timeline.Start();
   std::vector<std::thread> load_threads;
-  for (int64_t i = 0; i < thread_num_; ++i) {
-    load_threads.push_back(std::thread(
-        &paddle::framework::DataFeed::LoadIntoMemory, readers_[i].get()));
-  }
-  for (std::thread& t : load_threads) {
-    t.join();
+  if (gpu_graph_mode_) {
+    VLOG(0) << "in gpu_graph_mode";
+    auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
+    gpu_graph_device_keys_ = gpu_graph_ptr->get_all_id(0, 0, thread_num_);
+
+    for (size_t i = 0; i < gpu_graph_device_keys_.size(); i++) {
+      VLOG(0) << "gpu_graph_device_keys_[" << i << "] = " << gpu_graph_device_keys_[i].size();
+      for (size_t j = 0; j < gpu_graph_device_keys_[i].size(); j++) {
+        gpu_graph_total_keys_.push_back(gpu_graph_device_keys_[i][j]);
+      }
+    }
+    for (size_t i = 0; i < readers_.size(); i++) {
+      readers_[i]->SetDeviceKeys(&gpu_graph_device_keys_[i]);
+      readers_[i]->SetGpuGraphMode(gpu_graph_mode_);
+    }
+
+  } else {
+    for (int64_t i = 0; i < thread_num_; ++i) {
+      load_threads.push_back(std::thread(
+          &paddle::framework::DataFeed::LoadIntoMemory, readers_[i].get()));
+    }
+    for (std::thread& t : load_threads) {
+      t.join();
+    }
   }
   input_channel_->Close();
   int64_t in_chan_size = input_channel_->Size();
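
Note: the gpu_graph_mode_ branch above bypasses file loading entirely; GraphGpuWrapper::get_all_id(0, 0, thread_num_) is expected to return one id vector per reader (its partitioning policy is not part of this diff), and each reader then receives its slice via SetDeviceKeys. A hypothetical sketch of that per-reader shape, assuming a simple round-robin split (illustrative only, not the wrapper's actual policy):

#include <cstdint>
#include <vector>

// Hypothetical illustration: distribute a flat list of graph node ids across
// thread_num reader slots, mirroring the shape of gpu_graph_device_keys_
// (one id vector per reader). The real split is done inside
// GraphGpuWrapper::get_all_id, whose policy is not shown in this commit.
std::vector<std::vector<int64_t>> PartitionKeys(
    const std::vector<int64_t>& all_ids, int thread_num) {
  std::vector<std::vector<int64_t>> per_reader(thread_num);
  for (size_t i = 0; i < all_ids.size(); ++i) {
    per_reader[i % thread_num].push_back(all_ids[i]);  // round-robin
  }
  return per_reader;
}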

paddle/fluid/framework/data_set.h

Lines changed: 6 additions & 2 deletions
@@ -158,7 +158,6 @@
   virtual void DynamicAdjustReadersNum(int thread_num) = 0;
   // set fleet send sleep seconds
   virtual void SetFleetSendSleepSeconds(int seconds) = 0;
-
  protected:
   virtual int ReceiveFromClient(int msg_type, int client_id,
                                 const std::string& msg) = 0;
@@ -263,7 +262,9 @@ class DatasetImpl : public Dataset {
       return multi_consume_channel_;
     }
   }
-
+  std::vector<int64_t>& GetGpuGraphTotalKeys() {
+    return gpu_graph_total_keys_;
+  }
   Channel<T>& GetInputChannelRef() { return input_channel_; }
 
  protected:
@@ -322,6 +323,9 @@ class DatasetImpl : public Dataset {
   std::vector<std::shared_ptr<ThreadPool>> consume_task_pool_;
   std::vector<T> input_records_;  // only for paddleboxdatafeed
   bool enable_heterps_ = false;
+  int gpu_graph_mode_ = 1;
+  std::vector<std::vector<int64_t>> gpu_graph_device_keys_;
+  std::vector<int64_t> gpu_graph_total_keys_;
 };
 
 // use std::vector<MultiSlotType> or Record as data type
