Commit fc28b23

Merge pull request #3 from seemingwang/gpu_graph_engine2
split graph table
2 parents 63e501d + a9b5445

9 files changed: +235 −268 lines
paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h

Lines changed: 11 additions & 7 deletions
@@ -123,21 +123,25 @@ node_list[8]-> node_id:17, neighbor_size:1, neighbor_offset:15
 */
 struct NeighborSampleQuery {
   int gpu_id;
-  int64_t *key;
-  int sample_size;
+  int table_idx;
+  int64_t *src_nodes;
   int len;
-  void initialize(int gpu_id, int64_t key, int sample_size, int len) {
+  int sample_size;
+  void initialize(int gpu_id, int table_idx, int64_t src_nodes, int sample_size,
+                  int len) {
+    this->table_idx = table_idx;
     this->gpu_id = gpu_id;
-    this->key = (int64_t *)key;
+    this->src_nodes = (int64_t *)src_nodes;
     this->sample_size = sample_size;
     this->len = len;
   }
   void display() {
     int64_t *sample_keys = new int64_t[len];
     VLOG(0) << "device_id " << gpu_id << " sample_size = " << sample_size;
-    VLOG(0) << "there are " << len << " keys ";
+    VLOG(0) << "there are " << len << " keys to sample for graph " << table_idx;
     std::string key_str;
-    cudaMemcpy(sample_keys, key, len * sizeof(int64_t), cudaMemcpyDeviceToHost);
+    cudaMemcpy(sample_keys, src_nodes, len * sizeof(int64_t),
+               cudaMemcpyDeviceToHost);
 
     for (int i = 0; i < len; i++) {
       if (key_str.size() > 0) key_str += ";";
@@ -212,7 +216,7 @@ struct NeighborSampleResult {
     std::vector<int64_t> graph;
     int64_t *sample_keys = new int64_t[q.len];
     std::string key_str;
-    cudaMemcpy(sample_keys, q.key, q.len * sizeof(int64_t),
+    cudaMemcpy(sample_keys, q.src_nodes, q.len * sizeof(int64_t),
               cudaMemcpyDeviceToHost);
     int64_t *res = new int64_t[sample_size * key_size];
     cudaMemcpy(res, val, sample_size * key_size * sizeof(int64_t),
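
With the reworked struct, a caller names the target graph via table_idx and passes the device buffer of source nodes as src_nodes. A minimal sketch of a call site follows; the buffer contents, sizes, and the `table` pointer (assumed to be a GpuPsGraphTable that has already been built) are illustrative assumptions, not part of this commit. Note that initialize() takes the device pointer cast to int64_t, matching the signature above.

    // Hypothetical call site: sample up to 5 neighbors for each of 3 source
    // nodes, reading from edge table 0 on GPU 0.
    int64_t h_src[3] = {17, 42, 96};
    int64_t *d_src = NULL;
    cudaMalloc(&d_src, sizeof(h_src));
    cudaMemcpy(d_src, h_src, sizeof(h_src), cudaMemcpyHostToDevice);

    NeighborSampleQuery q;
    q.initialize(/*gpu_id=*/0, /*table_idx=*/0, /*src_nodes=*/(int64_t)d_src,
                 /*sample_size=*/5, /*len=*/3);
    NeighborSampleResult res =
        table->graph_neighbor_sample_v3(q, /*cpu_switch=*/false);

The sampling entry points themselves are declared in graph_gpu_ps_table.h below.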

paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h

Lines changed: 34 additions & 15 deletions
@@ -23,19 +23,35 @@
 #ifdef PADDLE_WITH_HETERPS
 namespace paddle {
 namespace framework {
+enum GraphTableType { EDGE_TABLE, FEATURE_TABLE };
 class GpuPsGraphTable : public HeterComm<uint64_t, int64_t, int> {
  public:
-  GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource, int topo_aware)
+  int get_table_offset(int gpu_id, GraphTableType type, int idx) {
+    int type_id = type;
+    return gpu_id * (graph_table_num_ + feature_table_num_) +
+           type_id * graph_table_num_ + idx;
+  }
+  GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource, int topo_aware,
+                  int graph_table_num, int feature_table_num)
       : HeterComm<uint64_t, int64_t, int>(1, resource) {
     load_factor_ = 0.25;
     rw_lock.reset(new pthread_rwlock_t());
+    this->graph_table_num_ = graph_table_num;
+    this->feature_table_num_ = feature_table_num;
     gpu_num = resource_->total_device();
     memset(global_device_map, -1, sizeof(global_device_map));
+    for (auto &table : tables_) {
+      delete table;
+      table = NULL;
+    }
+    tables_ = std::vector<Table *>(
+        gpu_num * (graph_table_num + feature_table_num), NULL);
+    sample_status = std::vector<int *>(gpu_num * graph_table_num, NULL);
     for (int i = 0; i < gpu_num; i++) {
-      gpu_graph_list.push_back(GpuPsCommGraph());
       global_device_map[resource_->dev_id(i)] = i;
-      sample_status.push_back(NULL);
-      tables_.push_back(NULL);
+      for (int j = 0; j < graph_table_num; j++) {
+        gpu_graph_list_.push_back(GpuPsCommGraph());
+      }
     }
     cpu_table_status = -1;
     if (topo_aware) {
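
The new get_table_offset encodes a fixed per-GPU layout: each device owns graph_table_num_ edge tables followed by feature_table_num_ feature tables, so EDGE_TABLE (type_id 0) indexes the first band and FEATURE_TABLE (type_id 1) is shifted past all edge tables. A standalone arithmetic check of that layout, with table counts chosen for illustration only:

    #include <cassert>

    // Mirror of get_table_offset with assumed counts: 2 edge tables and
    // 1 feature table per GPU, i.e. 3 slots per device in tables_.
    int main() {
      const int graph_table_num = 2, feature_table_num = 1;
      auto offset = [&](int gpu_id, int type_id, int idx) {
        return gpu_id * (graph_table_num + feature_table_num) +
               type_id * graph_table_num + idx;
      };
      assert(offset(0, /*EDGE_TABLE*/ 0, 0) == 0);     // gpu 0, 1st edge table
      assert(offset(1, /*EDGE_TABLE*/ 0, 1) == 4);     // gpu 1, 2nd edge table
      assert(offset(1, /*FEATURE_TABLE*/ 1, 0) == 5);  // gpu 1, feature table
      return 0;
    }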
@@ -89,21 +105,23 @@ class GpuPsGraphTable : public HeterComm<uint64_t, int64_t, int> {
   // end_graph_sampling();
   // }
   }
-  void build_graph_on_single_gpu(GpuPsCommGraph &g, int gpu_id);
-  void clear_graph_info(int gpu_id);
-  void build_graph_from_cpu(std::vector<GpuPsCommGraph> &cpu_node_list);
+  void build_graph_on_single_gpu(GpuPsCommGraph &g, int gpu_id, int idx);
+  void clear_graph_info(int gpu_id, int index);
+  void clear_graph_info(int index);
+  void build_graph_from_cpu(std::vector<GpuPsCommGraph> &cpu_node_list,
+                            int idx);
   NodeQueryResult graph_node_sample(int gpu_id, int sample_size);
   NeighborSampleResult graph_neighbor_sample_v3(NeighborSampleQuery q,
                                                 bool cpu_switch);
-  NeighborSampleResult graph_neighbor_sample(int gpu_id, int64_t *key,
+  NeighborSampleResult graph_neighbor_sample(int gpu_id, int idx, int64_t *key,
                                              int sample_size, int len);
-  NeighborSampleResult graph_neighbor_sample_v2(int gpu_id, int64_t *key,
-                                                int sample_size, int len,
-                                                bool cpu_query_switch);
+  NeighborSampleResult graph_neighbor_sample_v2(int gpu_id, int idx,
+                                                int64_t *key, int sample_size,
+                                                int len, bool cpu_query_switch);
   void init_sample_status();
   void free_sample_status();
-  NodeQueryResult query_node_list(int gpu_id, int start, int query_size);
-  void clear_graph_info();
+  NodeQueryResult query_node_list(int gpu_id, int idx, int start,
+                                  int query_size);
   void display_sample_res(void *key, void *val, int len, int sample_len);
   void move_neighbor_sample_result_to_source_gpu(int gpu_id, int gpu_num,
                                                  int sample_size, int *h_left,
@@ -112,12 +130,13 @@ class GpuPsGraphTable : public HeterComm<uint64_t, int64_t, int> {
                                                  int *actual_sample_size);
   int init_cpu_table(const paddle::distributed::GraphParameter &graph);
   int gpu_num;
-  std::vector<GpuPsCommGraph> gpu_graph_list;
+  int graph_table_num_, feature_table_num_;
+  std::vector<GpuPsCommGraph> gpu_graph_list_;
   int global_device_map[32];
   std::vector<int *> sample_status;
   const int parallel_sample_size = 1;
   const int dim_y = 256;
-  std::shared_ptr<paddle::distributed::GraphTable> cpu_graph_table;
+  std::shared_ptr<paddle::distributed::GraphTable> cpu_graph_table_;
   std::shared_ptr<pthread_rwlock_t> rw_lock;
   mutable std::mutex mutex_;
   std::condition_variable cv_;
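
Taken together, the header changes thread the table index through both the build and query paths. A hypothetical sketch of the resulting call pattern, with all argument values assumed (`g` is a populated GpuPsCommGraph, `d_src` a device buffer of node ids as in the earlier sketch):

    // Build edge table 0 on GPU 0, then sample from that same table.
    table->build_graph_on_single_gpu(g, /*gpu_id=*/0, /*idx=*/0);
    NeighborSampleResult r = table->graph_neighbor_sample(
        /*gpu_id=*/0, /*idx=*/0, d_src, /*sample_size=*/5, /*len=*/3);
    // Node queries likewise name a table instead of assuming one per GPU.
    NodeQueryResult nodes = table->query_node_list(
        /*gpu_id=*/0, /*idx=*/0, /*start=*/0, /*query_size=*/100);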
