2323#ifdef PADDLE_WITH_HETERPS
2424namespace paddle {
2525namespace framework {
// Kinds of per-GPU tables held by GpuPsGraphTable: edge-adjacency
// tables and node-feature tables. The numeric values participate in
// table-offset arithmetic (see get_table_offset), so EDGE_TABLE must
// stay 0 and FEATURE_TABLE must stay 1.
enum GraphTableType { EDGE_TABLE = 0, FEATURE_TABLE = 1 };
2627class GpuPsGraphTable : public HeterComm <uint64_t , int64_t , int > {
2728 public:
28- GpuPsGraphTable (std::shared_ptr<HeterPsResource> resource, int topo_aware)
29+ int get_table_offset (int gpu_id, GraphTableType type, int idx) {
30+ int type_id = type;
31+ return gpu_id * (graph_table_num_ + feature_table_num_) +
32+ type_id * graph_table_num_ + idx;
33+ }
34+ GpuPsGraphTable (std::shared_ptr<HeterPsResource> resource, int topo_aware,
35+ int graph_table_num, int feature_table_num)
2936 : HeterComm<uint64_t , int64_t , int >(1 , resource) {
3037 load_factor_ = 0.25 ;
3138 rw_lock.reset (new pthread_rwlock_t ());
39+ this ->graph_table_num_ = graph_table_num;
40+ this ->feature_table_num_ = feature_table_num;
3241 gpu_num = resource_->total_device ();
3342 memset (global_device_map, -1 , sizeof (global_device_map));
43+ for (auto &table : tables_) {
44+ delete table;
45+ table = NULL ;
46+ }
47+ tables_ = std::vector<Table *>(
48+ gpu_num * (graph_table_num + feature_table_num), NULL );
49+ sample_status = std::vector<int *>(gpu_num * graph_table_num, NULL );
3450 for (int i = 0 ; i < gpu_num; i++) {
35- gpu_graph_list.push_back (GpuPsCommGraph ());
3651 global_device_map[resource_->dev_id (i)] = i;
37- sample_status.push_back (NULL );
38- tables_.push_back (NULL );
52+ for (int j = 0 ; j < graph_table_num; j++) {
53+ gpu_graph_list_.push_back (GpuPsCommGraph ());
54+ }
3955 }
4056 cpu_table_status = -1 ;
4157 if (topo_aware) {
@@ -89,21 +105,23 @@ class GpuPsGraphTable : public HeterComm<uint64_t, int64_t, int> {
89105 // end_graph_sampling();
90106 // }
91107 }
// Graph build / teardown declarations (unified-diff view: '-' lines
// are the removed single-table signatures, '+' lines the multi-table
// replacements that take a per-GPU graph-table index).
92- void build_graph_on_single_gpu (GpuPsCommGraph &g, int gpu_id);
93- void clear_graph_info (int gpu_id);
94- void build_graph_from_cpu (std::vector<GpuPsCommGraph> &cpu_node_list);
// NOTE(review): the table-index parameter is named 'idx' here but
// 'index' in clear_graph_info — presumably both mean the same index
// fed to get_table_offset; confirm and unify the naming.
108+ void build_graph_on_single_gpu (GpuPsCommGraph &g, int gpu_id, int idx);
109+ void clear_graph_info (int gpu_id, int index);
110+ void clear_graph_info (int index);
111+ void build_graph_from_cpu (std::vector<GpuPsCommGraph> &cpu_node_list,
112+ int idx);
// Sampling API. The new overloads thread the graph-table index (idx)
// through neighbor sampling so a query can target one of several
// edge tables per GPU.
95113 NodeQueryResult graph_node_sample (int gpu_id, int sample_size);
96114 NeighborSampleResult graph_neighbor_sample_v3 (NeighborSampleQuery q,
97115 bool cpu_switch);
98- NeighborSampleResult graph_neighbor_sample (int gpu_id, int64_t *key,
116+ NeighborSampleResult graph_neighbor_sample (int gpu_id, int idx, int64_t *key,
99117 int sample_size, int len);
// v2 adds cpu_query_switch — presumably a fallback to the CPU graph
// table for keys not resident on the GPU; confirm in the definition.
100- NeighborSampleResult graph_neighbor_sample_v2 (int gpu_id, int64_t *key ,
101- int sample_size , int len ,
102- bool cpu_query_switch);
118+ NeighborSampleResult graph_neighbor_sample_v2 (int gpu_id, int idx ,
119+ int64_t *key , int sample_size ,
120+ int len, bool cpu_query_switch);
// Sample-status bookkeeping and node queries. query_node_list gains
// the table index like the sampling API; the zero-argument
// clear_graph_info overload is removed in this diff in favor of the
// indexed overloads.
103121 void init_sample_status ();
104122 void free_sample_status ();
105- NodeQueryResult query_node_list (int gpu_id, int start , int query_size);
106- void clear_graph_info ( );
123+ NodeQueryResult query_node_list (int gpu_id, int idx , int start,
124+ int query_size );
// Debug helper — name suggests it dumps sampled keys/values; see the
// definition for the exact output format.
107125 void display_sample_res (void *key, void *val, int len, int sample_len);
108126 void move_neighbor_sample_result_to_source_gpu (int gpu_id, int gpu_num,
109127 int sample_size, int *h_left,
@@ -112,12 +130,13 @@ class GpuPsGraphTable : public HeterComm<uint64_t, int64_t, int> {
112130 int *actual_sample_size);
113131 int init_cpu_table (const paddle::distributed::GraphParameter &graph);
// Data members. Trailing-underscore names are the fields renamed or
// added by this diff.
114132 int gpu_num;
// gpu_graph_list is renamed gpu_graph_list_; the table counts define
// the per-GPU layout used by get_table_offset (graph/edge tables
// first, then feature tables).
115- std::vector<GpuPsCommGraph> gpu_graph_list;
133+ int graph_table_num_, feature_table_num_;
134+ std::vector<GpuPsCommGraph> gpu_graph_list_;
116135 int global_device_map[32 ];
// One status pointer per (gpu, edge table); entries start NULL (set
// that way in the constructor).
117136 std::vector<int *> sample_status;
118137 const int parallel_sample_size = 1 ;
119138 const int dim_y = 256 ;
// CPU-side graph table (renamed with trailing underscore in the diff).
120- std::shared_ptr<paddle::distributed::GraphTable> cpu_graph_table ;
139+ std::shared_ptr<paddle::distributed::GraphTable> cpu_graph_table_ ;
121140 std::shared_ptr<pthread_rwlock_t > rw_lock;
122141 mutable std::mutex mutex_;
123142 std::condition_variable cv_;
0 commit comments