Skip to content

Commit 6c38fa0

Browse files
authored
Merge pull request #7 from seemingwang/develop
Merge
2 parents b08a36f + 832cab8 commit 6c38fa0

File tree

6 files changed

+160
-130
lines changed

6 files changed

+160
-130
lines changed

paddle/fluid/distributed/service/graph_brpc_client.cc

Lines changed: 65 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -35,90 +35,99 @@ int GraphBrpcClient::get_server_index_by_id(uint64_t id) {
3535
return id % shard_num / shard_per_server;
3636
}
3737
// char* &buffer,int &actual_size
38-
std::future<int32_t> GraphBrpcClient::batch_sample(uint32_t table_id,
39-
std::vector<uint64_t> node_ids, int sample_size,
40-
std::vector<std::vector<std::pair<uint64_t, float> > > &res) {
41-
38+
std::future<int32_t> GraphBrpcClient::batch_sample(
39+
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
40+
std::vector<std::vector<std::pair<uint64_t, float>>> &res) {
4241
std::vector<int> request2server;
4342
std::vector<int> server2request(server_size, -1);
4443
res.clear();
45-
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx){
44+
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
4645
int server_index = get_server_index_by_id(node_ids[query_idx]);
47-
if(server2request[server_index] == -1){
46+
if (server2request[server_index] == -1) {
4847
server2request[server_index] = request2server.size();
4948
request2server.push_back(server_index);
5049
}
51-
//res.push_back(std::vector<GraphNode>());
50+
// res.push_back(std::vector<GraphNode>());
5251
res.push_back(std::vector<std::pair<uint64_t, float>>());
5352
}
5453
size_t request_call_num = request2server.size();
55-
std::vector<std::vector<uint64_t> > node_id_buckets(request_call_num);
56-
std::vector<std::vector<int> > query_idx_buckets(request_call_num);
57-
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx){
54+
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
55+
std::vector<std::vector<int>> query_idx_buckets(request_call_num);
56+
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
5857
int server_index = get_server_index_by_id(node_ids[query_idx]);
5958
int request_idx = server2request[server_index];
6059
node_id_buckets[request_idx].push_back(node_ids[query_idx]);
6160
query_idx_buckets[request_idx].push_back(query_idx);
6261
}
6362

64-
DownpourBrpcClosure *closure = new DownpourBrpcClosure(request_call_num, [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
65-
int ret = 0;
66-
auto *closure = (DownpourBrpcClosure *)done;
67-
int fail_num = 0;
68-
for (int request_idx = 0; request_idx < request_call_num; ++request_idx){
69-
if (closure->check_response(request_idx, PS_GRAPH_SAMPLE) != 0) {
70-
++fail_num;
71-
} else {
72-
VLOG(0) << "check sample response: "
73-
<< " " << closure->check_response(request_idx, PS_GRAPH_SAMPLE);
74-
auto &res_io_buffer = closure->cntl(request_idx)->response_attachment();
75-
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
76-
size_t bytes_size = io_buffer_itr.bytes_left();
77-
char *buffer = new char[bytes_size];
78-
io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
63+
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
64+
request_call_num,
65+
[&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
66+
int ret = 0;
67+
auto *closure = (DownpourBrpcClosure *)done;
68+
int fail_num = 0;
69+
for (int request_idx = 0; request_idx < request_call_num;
70+
++request_idx) {
71+
if (closure->check_response(request_idx, PS_GRAPH_SAMPLE) != 0) {
72+
++fail_num;
73+
} else {
74+
auto &res_io_buffer =
75+
closure->cntl(request_idx)->response_attachment();
76+
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
77+
size_t bytes_size = io_buffer_itr.bytes_left();
78+
// char buffer[bytes_size];
79+
std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
80+
char *buffer = buffer_wrapper.get();
81+
io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
7982

80-
size_t node_num = *(size_t *)buffer;
81-
int *actual_sizes = (int *)(buffer + sizeof(size_t));
82-
char *node_buffer = buffer + sizeof(size_t) + sizeof(int) * node_num;
83-
84-
int offset = 0;
85-
for (size_t node_idx = 0; node_idx < node_num; ++node_idx){
86-
int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
87-
int actual_size = actual_sizes[node_idx];
88-
int start = 0;
89-
while (start < actual_size) {
90-
res[query_idx].push_back({*(uint64_t *)(node_buffer + offset + start),
91-
*(float *)(node_buffer + offset + start + GraphNode::id_size)});
92-
start += GraphNode::id_size + GraphNode::weight_size;
83+
size_t node_num = *(size_t *)buffer;
84+
int *actual_sizes = (int *)(buffer + sizeof(size_t));
85+
char *node_buffer =
86+
buffer + sizeof(size_t) + sizeof(int) * node_num;
87+
88+
int offset = 0;
89+
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
90+
int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
91+
int actual_size = actual_sizes[node_idx];
92+
int start = 0;
93+
while (start < actual_size) {
94+
res[query_idx].push_back(
95+
{*(uint64_t *)(node_buffer + offset + start),
96+
*(float *)(node_buffer + offset + start +
97+
GraphNode::id_size)});
98+
start += GraphNode::id_size + GraphNode::weight_size;
99+
}
100+
offset += actual_size;
101+
}
102+
}
103+
if (fail_num == request_call_num) {
104+
ret = -1;
93105
}
94-
offset += actual_size;
95106
}
96-
}
97-
if (fail_num == request_call_num){
98-
ret = -1;
99-
}
100-
}
101-
closure->set_promise_value(ret);
102-
});
107+
closure->set_promise_value(ret);
108+
});
103109

104110
auto promise = std::make_shared<std::promise<int32_t>>();
105111
closure->add_promise(promise);
106112
std::future<int> fut = promise->get_future();
107-
108-
for (int request_idx = 0; request_idx < request_call_num; ++request_idx){
113+
114+
for (int request_idx = 0; request_idx < request_call_num; ++request_idx) {
109115
int server_index = request2server[request_idx];
110116
closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE);
111117
closure->request(request_idx)->set_table_id(table_id);
112118
closure->request(request_idx)->set_client_id(_client_id);
113119
// std::string type_str = GraphNode::node_type_to_string(type);
114120
size_t node_num = node_id_buckets[request_idx].size();
115-
116-
closure->request(request_idx)->add_params((char *)node_id_buckets[request_idx].data(), sizeof(uint64_t)*node_num);
117-
closure->request(request_idx)->add_params((char *)&sample_size, sizeof(int));
121+
122+
closure->request(request_idx)
123+
->add_params((char *)node_id_buckets[request_idx].data(),
124+
sizeof(uint64_t) * node_num);
125+
closure->request(request_idx)
126+
->add_params((char *)&sample_size, sizeof(int));
118127
PsService_Stub rpc_stub(get_cmd_channel(server_index));
119128
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
120-
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), closure->response(request_idx),
121-
closure);
129+
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
130+
closure->response(request_idx), closure);
122131
}
123132

124133
return fut;
@@ -133,12 +142,12 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
133142
if (closure->check_response(0, PS_PULL_GRAPH_LIST) != 0) {
134143
ret = -1;
135144
} else {
136-
VLOG(0) << "check sample response: "
137-
<< " " << closure->check_response(0, PS_PULL_GRAPH_LIST);
145+
// VLOG(0) << "check sample response: "
146+
// << " " << closure->check_response(0, PS_PULL_GRAPH_LIST);
138147
auto &res_io_buffer = closure->cntl(0)->response_attachment();
139148
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
140149
size_t bytes_size = io_buffer_itr.bytes_left();
141-
char *buffer = new char[bytes_size];
150+
char buffer[bytes_size];
142151
io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
143152
int index = 0;
144153
while (index < bytes_size) {

paddle/fluid/distributed/service/graph_brpc_server.cc

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -265,11 +265,10 @@ int32_t GraphBrpcService::pull_graph_list(Table *table,
265265
}
266266
int start = *(int *)(request.params(0).c_str());
267267
int size = *(int *)(request.params(1).c_str());
268-
std::vector<float> res_data;
269-
char *buffer;
268+
std::unique_ptr<char[]> buffer;
270269
int actual_size;
271270
table->pull_graph_list(start, size, buffer, actual_size);
272-
cntl->response_attachment().append(buffer, actual_size);
271+
cntl->response_attachment().append(buffer.get(), actual_size);
273272
return 0;
274273
}
275274
int32_t GraphBrpcService::graph_random_sample(Table *table,
@@ -287,19 +286,26 @@ int32_t GraphBrpcService::graph_random_sample(Table *table,
287286
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
288287
int sample_size = *(uint64_t *)(request.params(1).c_str());
289288

290-
std::vector<char*> buffers(node_num, nullptr);
289+
std::vector<std::unique_ptr<char[]>> buffers(node_num);
291290
std::vector<int> actual_sizes(node_num, 0);
292291
table->random_sample(node_data, sample_size, buffers, actual_sizes);
293292

294293
cntl->response_attachment().append(&node_num, sizeof(size_t));
295-
cntl->response_attachment().append(actual_sizes.data(), sizeof(int)*node_num);
296-
for (size_t idx = 0; idx < node_num; ++idx){
297-
cntl->response_attachment().append(buffers[idx], actual_sizes[idx]);
298-
if (buffers[idx] != nullptr){
299-
delete buffers[idx];
300-
buffers[idx] = nullptr;
301-
}
294+
cntl->response_attachment().append(actual_sizes.data(),
295+
sizeof(int) * node_num);
296+
for (size_t idx = 0; idx < node_num; ++idx) {
297+
cntl->response_attachment().append(buffers[idx].get(), actual_sizes[idx]);
298+
// if (buffers[idx] != nullptr){
299+
// delete buffers[idx];
300+
// buffers[idx] = nullptr;
301+
// }
302302
}
303+
// =======
304+
// std::unique_ptr<char[]> buffer;
305+
// int actual_size;
306+
// table->random_sample(node_id, sample_size, buffer, actual_size);
307+
// cntl->response_attachment().append(buffer.get(), actual_size);
308+
// >>>>>>> Stashed changes
303309
return 0;
304310
}
305311

paddle/fluid/distributed/table/common_graph_table.cc

Lines changed: 47 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ size_t GraphShard::get_size() {
6464
return res;
6565
}
6666

67-
std::list<GraphNode *>::iterator GraphShard::add_node(uint64_t id, std::string feature) {
67+
std::list<GraphNode *>::iterator GraphShard::add_node(uint64_t id,
68+
std::string feature) {
6869
if (node_location.find(id) != node_location.end())
6970
return node_location.find(id)->second;
7071

@@ -89,14 +90,13 @@ GraphNode *GraphShard::find_node(uint64_t id) {
8990

9091
int32_t GraphTable::load(const std::string &path, const std::string &param) {
9192
auto cmd = paddle::string::split_string<std::string>(param, "|");
92-
std::set<std::string> cmd_set(cmd.begin(), cmd.end());
93+
std::set<std::string> cmd_set(cmd.begin(), cmd.end());
9394
bool reverse_edge = cmd_set.count(std::string("reverse"));
9495
bool load_edge = cmd_set.count(std::string("edge"));
95-
if(load_edge) {
96-
return this -> load_edges(path, reverse_edge);
97-
}
98-
else {
99-
return this -> load_nodes(path);
96+
if (load_edge) {
97+
return this->load_edges(path, reverse_edge);
98+
} else {
99+
return this->load_nodes(path);
100100
}
101101
}
102102

@@ -110,33 +110,28 @@ int32_t GraphTable::load_nodes(const std::string &path) {
110110
if (values.size() < 2) continue;
111111
auto id = std::stoull(values[1]);
112112

113-
114113
size_t shard_id = id % shard_num;
115114
if (shard_id >= shard_end || shard_id < shard_start) {
116115
VLOG(4) << "will not load " << id << " from " << path
117116
<< ", please check id distribution";
118117
continue;
119-
120118
}
121119

122120
std::string node_type = values[0];
123-
std::vector<std::string > feature;
121+
std::vector<std::string> feature;
124122
feature.push_back(node_type);
125-
for(size_t slice = 2; slice < values.size(); slice ++) {
123+
for (size_t slice = 2; slice < values.size(); slice++) {
126124
feature.push_back(values[slice]);
127125
}
128126
auto feat = paddle::string::join_strings(feature, '\t');
129127
size_t index = shard_id - shard_start;
130128
shards[index].add_node(id, feat);
131-
132129
}
133130
}
134131
return 0;
135132
}
136133

137-
138134
int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
139-
140135
auto paths = paddle::string::split_string<std::string>(path, ";");
141136
int count = 0;
142137

@@ -173,7 +168,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
173168
VLOG(0) << "Load Finished Total Edge Count " << count;
174169

175170
// Build Sampler j
176-
171+
177172
for (auto &shard : shards) {
178173
auto bucket = shard.get_bucket();
179174
for (int i = 0; i < bucket.size(); i++) {
@@ -200,46 +195,49 @@ GraphNode *GraphTable::find_node(uint64_t id) {
200195
uint32_t GraphTable::get_thread_pool_index(uint64_t node_id) {
201196
return node_id % shard_num_per_table % task_pool_size_;
202197
}
203-
int GraphTable::random_sample(uint64_t* node_ids, int sample_size,
204-
std::vector<char*>& buffers, std::vector<int> &actual_sizes) {
198+
int GraphTable::random_sample(uint64_t *node_ids, int sample_size,
199+
std::vector<std::unique_ptr<char[]>> &buffers,
200+
std::vector<int> &actual_sizes) {
205201
size_t node_num = buffers.size();
206202
std::vector<std::future<int>> tasks;
207-
for (size_t idx = 0; idx < node_num; ++idx){
203+
for (size_t idx = 0; idx < node_num; ++idx) {
208204
uint64_t node_id = node_ids[idx];
209-
char* & buffer = buffers[idx];
210-
int& actual_size = actual_sizes[idx];
211-
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]
212-
->enqueue([&]() -> int {
213-
GraphNode *node = find_node(node_id);
214-
if (node == NULL) {
215-
actual_size = 0;
205+
std::unique_ptr<char[]> &buffer = buffers[idx];
206+
int &actual_size = actual_sizes[idx];
207+
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
208+
[&]() -> int {
209+
GraphNode *node = find_node(node_id);
210+
if (node == NULL) {
211+
actual_size = 0;
212+
return 0;
213+
}
214+
std::vector<GraphEdge *> res = node->sample_k(sample_size);
215+
actual_size =
216+
res.size() * (GraphNode::id_size + GraphNode::weight_size);
217+
int offset = 0;
218+
uint64_t id;
219+
float weight;
220+
char *buffer_addr = new char[actual_size];
221+
buffer.reset(buffer_addr);
222+
for (auto &x : res) {
223+
id = x->get_id();
224+
weight = x->get_weight();
225+
memcpy(buffer_addr + offset, &id, GraphNode::id_size);
226+
offset += GraphNode::id_size;
227+
memcpy(buffer_addr + offset, &weight, GraphNode::weight_size);
228+
offset += GraphNode::weight_size;
229+
return 0;
230+
}
216231
return 0;
217-
}
218-
std::vector<GraphEdge *> res = node->sample_k(sample_size);
219-
std::vector<GraphNode> node_list;
220-
actual_size =
221-
res.size() * (GraphNode::id_size + GraphNode::weight_size);
222-
buffer = new char[actual_size];
223-
int offset = 0;
224-
uint64_t id;
225-
float weight;
226-
for (auto &x : res) {
227-
id = x->get_id();
228-
weight = x->get_weight();
229-
memcpy(buffer + offset, &id, GraphNode::id_size);
230-
offset += GraphNode::id_size;
231-
memcpy(buffer + offset, &weight, GraphNode::weight_size);
232-
offset += GraphNode::weight_size;
233-
}
234-
return 0;
235-
}));
232+
}));
236233
}
237-
for (size_t idx = 0; idx < node_num; ++idx){
234+
for (size_t idx = 0; idx < node_num; ++idx) {
238235
tasks[idx].get();
239236
}
240237
return 0;
241238
}
242-
int32_t GraphTable::pull_graph_list(int start, int total_size, char *&buffer,
239+
int32_t GraphTable::pull_graph_list(int start, int total_size,
240+
std::unique_ptr<char[]> &buffer,
243241
int &actual_size) {
244242
if (start < 0) start = 0;
245243
int size = 0, cur_size;
@@ -283,11 +281,12 @@ int32_t GraphTable::pull_graph_list(int start, int total_size, char *&buffer,
283281
size += res.back()[j]->get_size();
284282
}
285283
}
286-
buffer = new char[size];
284+
char *buffer_addr = new char[size];
285+
buffer.reset(buffer_addr);
287286
int index = 0;
288287
for (size_t i = 0; i < res.size(); i++) {
289288
for (size_t j = 0; j < res[i].size(); j++) {
290-
res[i][j]->to_buffer(buffer + index);
289+
res[i][j]->to_buffer(buffer_addr + index);
291290
index += res[i][j]->get_size();
292291
}
293292
}
@@ -321,4 +320,3 @@ int32_t GraphTable::initialize() {
321320
}
322321
}
323322
};
324-

0 commit comments

Comments
 (0)