Skip to content

Commit 1e843d1

Browse files
Merge pull request #25 from PaddlePaddle/develop
update
2 parents 7addd79 + aefec22 commit 1e843d1

File tree

387 files changed

+31935
-3213
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

387 files changed

+31935
-3213
lines changed

.pre-commit-config.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,7 @@ repos:
4949
entry: python ./tools/codestyle/copyright.hook
5050
language: system
5151
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py|sh)$
52-
exclude: (?!.*third_party)^.*$ | (?!.*book)^.*$
52+
exclude: |
53+
(?x)^(
54+
paddle/utils/.*
55+
)$

cmake/cupti.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ find_path(CUPTI_INCLUDE_DIR cupti.h
99
$ENV{CUPTI_ROOT} $ENV{CUPTI_ROOT}/include
1010
${CUDA_TOOLKIT_ROOT_DIR}/extras/CUPTI/include
1111
${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include
12+
${CUDA_TOOLKIT_ROOT_DIR}/targets/aarch64-linux/include
1213
NO_DEFAULT_PATH
1314
)
1415

paddle/fluid/distributed/service/graph_brpc_client.cc

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,102 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
479479
closure);
480480
return fut;
481481
}
482+
483+
std::future<int32_t> GraphBrpcClient::set_node_feat(
484+
const uint32_t &table_id, const std::vector<uint64_t> &node_ids,
485+
const std::vector<std::string> &feature_names,
486+
const std::vector<std::vector<std::string>> &features) {
487+
std::vector<int> request2server;
488+
std::vector<int> server2request(server_size, -1);
489+
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
490+
int server_index = get_server_index_by_id(node_ids[query_idx]);
491+
if (server2request[server_index] == -1) {
492+
server2request[server_index] = request2server.size();
493+
request2server.push_back(server_index);
494+
}
495+
}
496+
size_t request_call_num = request2server.size();
497+
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
498+
std::vector<std::vector<int>> query_idx_buckets(request_call_num);
499+
std::vector<std::vector<std::vector<std::string>>> features_idx_buckets(
500+
request_call_num);
501+
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
502+
int server_index = get_server_index_by_id(node_ids[query_idx]);
503+
int request_idx = server2request[server_index];
504+
node_id_buckets[request_idx].push_back(node_ids[query_idx]);
505+
query_idx_buckets[request_idx].push_back(query_idx);
506+
if (features_idx_buckets[request_idx].size() == 0) {
507+
features_idx_buckets[request_idx].resize(feature_names.size());
508+
}
509+
for (int feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
510+
features_idx_buckets[request_idx][feat_idx].push_back(
511+
features[feat_idx][query_idx]);
512+
}
513+
}
514+
515+
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
516+
request_call_num,
517+
[&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
518+
int ret = 0;
519+
auto *closure = (DownpourBrpcClosure *)done;
520+
size_t fail_num = 0;
521+
for (int request_idx = 0; request_idx < request_call_num;
522+
++request_idx) {
523+
if (closure->check_response(request_idx, PS_GRAPH_SET_NODE_FEAT) !=
524+
0) {
525+
++fail_num;
526+
}
527+
if (fail_num == request_call_num) {
528+
ret = -1;
529+
}
530+
}
531+
closure->set_promise_value(ret);
532+
});
533+
534+
auto promise = std::make_shared<std::promise<int32_t>>();
535+
closure->add_promise(promise);
536+
std::future<int> fut = promise->get_future();
537+
538+
for (int request_idx = 0; request_idx < request_call_num; ++request_idx) {
539+
int server_index = request2server[request_idx];
540+
closure->request(request_idx)->set_cmd_id(PS_GRAPH_SET_NODE_FEAT);
541+
closure->request(request_idx)->set_table_id(table_id);
542+
closure->request(request_idx)->set_client_id(_client_id);
543+
size_t node_num = node_id_buckets[request_idx].size();
544+
545+
closure->request(request_idx)
546+
->add_params((char *)node_id_buckets[request_idx].data(),
547+
sizeof(uint64_t) * node_num);
548+
std::string joint_feature_name =
549+
paddle::string::join_strings(feature_names, '\t');
550+
closure->request(request_idx)
551+
->add_params(joint_feature_name.c_str(), joint_feature_name.size());
552+
553+
// set features
554+
std::string set_feature = "";
555+
for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
556+
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
557+
size_t feat_len =
558+
features_idx_buckets[request_idx][feat_idx][node_idx].size();
559+
set_feature.append((char *)&feat_len, sizeof(size_t));
560+
set_feature.append(
561+
features_idx_buckets[request_idx][feat_idx][node_idx].data(),
562+
feat_len);
563+
}
564+
}
565+
closure->request(request_idx)
566+
->add_params(set_feature.c_str(), set_feature.size());
567+
568+
GraphPsService_Stub rpc_stub =
569+
getServiceStub(get_cmd_channel(server_index));
570+
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
571+
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
572+
closure->response(request_idx), closure);
573+
}
574+
575+
return fut;
576+
}
577+
482578
int32_t GraphBrpcClient::initialize() {
483579
// set_shard_num(_config.shard_num());
484580
BrpcPsClient::initialize();

paddle/fluid/distributed/service/graph_brpc_client.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@ class GraphBrpcClient : public BrpcPsClient {
7979
const std::vector<std::string>& feature_names,
8080
std::vector<std::vector<std::string>>& res);
8181

82+
virtual std::future<int32_t> set_node_feat(
83+
const uint32_t& table_id, const std::vector<uint64_t>& node_ids,
84+
const std::vector<std::string>& feature_names,
85+
const std::vector<std::vector<std::string>>& features);
86+
8287
virtual std::future<int32_t> clear_nodes(uint32_t table_id);
8388
virtual std::future<int32_t> add_graph_node(
8489
uint32_t table_id, std::vector<uint64_t>& node_id_list,

paddle/fluid/distributed/service/graph_brpc_server.cc

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
1717

1818
#include <thread> // NOLINT
19+
#include <utility>
1920
#include "butil/endpoint.h"
2021
#include "iomanip"
2122
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
@@ -157,6 +158,8 @@ int32_t GraphBrpcService::initialize() {
157158
&GraphBrpcService::add_graph_node;
158159
_service_handler_map[PS_GRAPH_REMOVE_GRAPH_NODE] =
159160
&GraphBrpcService::remove_graph_node;
161+
_service_handler_map[PS_GRAPH_SET_NODE_FEAT] =
162+
&GraphBrpcService::graph_set_node_feat;
160163
// shard初始化,server启动后才可从env获取到server_list的shard信息
161164
initialize_shard_info();
162165

@@ -400,5 +403,44 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
400403

401404
return 0;
402405
}
406+
407+
/**
 * Handles a PS_GRAPH_SET_NODE_FEAT request: decodes node ids, feature names
 * and feature values from the request params and forwards them to
 * GraphTable::set_node_feat.
 *
 * Expected params:
 *   0. raw bytes of the uint64_t node ids,
 *   1. feature names joined with '\t',
 *   2. feature values, feature-major then node-major, each encoded as a
 *      native size_t length followed by that many bytes.
 *
 * @param table     table addressed by the request; accessed as a GraphTable.
 * @param request   incoming request message.
 * @param response  response message; its code is set to -1 when the request
 *                  has fewer than 3 params or a truncated feature payload.
 * @param cntl      brpc controller (unused here).
 * @return always 0; failures are reported through the response code.
 */
int32_t GraphBrpcService::graph_set_node_feat(Table *table,
                                              const PsRequestMessage &request,
                                              PsResponseMessage &response,
                                              brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 3) {
    set_response_code(
        response, -1,
        "graph_set_node_feat request requires at least 3 arguments");
    return 0;
  }

  // Copy the id bytes into properly aligned storage instead of casting the
  // string buffer (which also dropped const) to uint64_t*.
  const std::string &node_bytes = request.params(0);
  const size_t node_num = node_bytes.size() / sizeof(uint64_t);
  std::vector<uint64_t> node_ids(node_num);
  node_bytes.copy(reinterpret_cast<char *>(node_ids.data()),
                  node_num * sizeof(uint64_t));

  std::vector<std::string> feature_names =
      paddle::string::split_string<std::string>(request.params(1), "\t");

  std::vector<std::vector<std::string>> features(
      feature_names.size(), std::vector<std::string>(node_num));

  // The payload comes off the wire, so every length prefix and value is
  // bounds-checked before it is read.
  const std::string &payload = request.params(2);
  size_t offset = 0;
  for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
    for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
      if (payload.size() - offset < sizeof(size_t)) {
        set_response_code(
            response, -1,
            "graph_set_node_feat request has a truncated feature payload");
        return 0;
      }
      // Byte-wise copy: the prefix may sit at any alignment in the payload.
      size_t feat_len = 0;
      payload.copy(reinterpret_cast<char *>(&feat_len), sizeof(size_t),
                   offset);
      offset += sizeof(size_t);

      if (payload.size() - offset < feat_len) {
        set_response_code(
            response, -1,
            "graph_set_node_feat request has a truncated feature payload");
        return 0;
      }
      features[feat_idx][node_idx].assign(payload, offset, feat_len);
      offset += feat_len;
    }
  }

  static_cast<GraphTable *>(table)->set_node_feat(node_ids, feature_names,
                                                  features);

  return 0;
}
444+
403445
} // namespace distributed
404446
} // namespace paddle

paddle/fluid/distributed/service/graph_brpc_server.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,13 @@ class GraphBrpcService : public PsBaseService {
8383
const PsRequestMessage &request,
8484
PsResponseMessage &response,
8585
brpc::Controller *cntl);
86+
8687
int32_t graph_get_node_feat(Table *table, const PsRequestMessage &request,
8788
PsResponseMessage &response,
8889
brpc::Controller *cntl);
90+
int32_t graph_set_node_feat(Table *table, const PsRequestMessage &request,
91+
PsResponseMessage &response,
92+
brpc::Controller *cntl);
8993
int32_t clear_nodes(Table *table, const PsRequestMessage &request,
9094
PsResponseMessage &response, brpc::Controller *cntl);
9195
int32_t add_graph_node(Table *table, const PsRequestMessage &request,

paddle/fluid/distributed/service/graph_py_service.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,19 @@ std::vector<std::vector<std::string>> GraphPyClient::get_node_feat(
330330
return v;
331331
}
332332

333+
void GraphPyClient::set_node_feat(
334+
std::string node_type, std::vector<uint64_t> node_ids,
335+
std::vector<std::string> feature_names,
336+
const std::vector<std::vector<std::string>> features) {
337+
if (this->table_id_map.count(node_type)) {
338+
uint32_t table_id = this->table_id_map[node_type];
339+
auto status =
340+
worker_ptr->set_node_feat(table_id, node_ids, feature_names, features);
341+
status.wait();
342+
}
343+
return;
344+
}
345+
333346
std::vector<FeatureNode> GraphPyClient::pull_graph_list(std::string name,
334347
int server_index,
335348
int start, int size,

paddle/fluid/distributed/service/graph_py_service.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,9 @@ class GraphPyClient : public GraphPyService {
155155
std::vector<std::vector<std::string>> get_node_feat(
156156
std::string node_type, std::vector<uint64_t> node_ids,
157157
std::vector<std::string> feature_names);
158+
void set_node_feat(std::string node_type, std::vector<uint64_t> node_ids,
159+
std::vector<std::string> feature_names,
160+
const std::vector<std::vector<std::string>> features);
158161
std::vector<FeatureNode> pull_graph_list(std::string name, int server_index,
159162
int start, int size, int step = 1);
160163
::paddle::distributed::PSParameter GetWorkerProto();

paddle/fluid/distributed/service/sendrecv.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ enum PsCmdID {
5555
PS_GRAPH_CLEAR = 34;
5656
PS_GRAPH_ADD_GRAPH_NODE = 35;
5757
PS_GRAPH_REMOVE_GRAPH_NODE = 36;
58+
PS_GRAPH_SET_NODE_FEAT = 37;
5859
}
5960

6061
message PsRequestMessage {

paddle/fluid/distributed/table/common_graph_table.cc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,34 @@ int32_t GraphTable::get_node_feat(const std::vector<uint64_t> &node_ids,
469469
return 0;
470470
}
471471

472+
int32_t GraphTable::set_node_feat(
473+
const std::vector<uint64_t> &node_ids,
474+
const std::vector<std::string> &feature_names,
475+
const std::vector<std::vector<std::string>> &res) {
476+
size_t node_num = node_ids.size();
477+
std::vector<std::future<int>> tasks;
478+
for (size_t idx = 0; idx < node_num; ++idx) {
479+
uint64_t node_id = node_ids[idx];
480+
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
481+
[&, idx, node_id]() -> int {
482+
size_t index = node_id % this->shard_num - this->shard_start;
483+
auto node = shards[index].add_feature_node(node_id);
484+
node->set_feature_size(this->feat_name.size());
485+
for (int feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
486+
const std::string &feature_name = feature_names[feat_idx];
487+
if (feat_id_map.find(feature_name) != feat_id_map.end()) {
488+
node->set_feature(feat_id_map[feature_name], res[feat_idx][idx]);
489+
}
490+
}
491+
return 0;
492+
}));
493+
}
494+
for (size_t idx = 0; idx < node_num; ++idx) {
495+
tasks[idx].get();
496+
}
497+
return 0;
498+
}
499+
472500
std::pair<int32_t, std::string> GraphTable::parse_feature(
473501
std::string feat_str) {
474502
// Return (feat_id, btyes) if name are in this->feat_name, else return (-1,

0 commit comments

Comments
 (0)