@@ -479,6 +479,102 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
479479 closure);
480480 return fut;
481481}
482+
483+ std::future<int32_t > GraphBrpcClient::set_node_feat (
484+ const uint32_t &table_id, const std::vector<uint64_t > &node_ids,
485+ const std::vector<std::string> &feature_names,
486+ const std::vector<std::vector<std::string>> &features) {
487+ std::vector<int > request2server;
488+ std::vector<int > server2request (server_size, -1 );
489+ for (int query_idx = 0 ; query_idx < node_ids.size (); ++query_idx) {
490+ int server_index = get_server_index_by_id (node_ids[query_idx]);
491+ if (server2request[server_index] == -1 ) {
492+ server2request[server_index] = request2server.size ();
493+ request2server.push_back (server_index);
494+ }
495+ }
496+ size_t request_call_num = request2server.size ();
497+ std::vector<std::vector<uint64_t >> node_id_buckets (request_call_num);
498+ std::vector<std::vector<int >> query_idx_buckets (request_call_num);
499+ std::vector<std::vector<std::vector<std::string>>> features_idx_buckets (
500+ request_call_num);
501+ for (int query_idx = 0 ; query_idx < node_ids.size (); ++query_idx) {
502+ int server_index = get_server_index_by_id (node_ids[query_idx]);
503+ int request_idx = server2request[server_index];
504+ node_id_buckets[request_idx].push_back (node_ids[query_idx]);
505+ query_idx_buckets[request_idx].push_back (query_idx);
506+ if (features_idx_buckets[request_idx].size () == 0 ) {
507+ features_idx_buckets[request_idx].resize (feature_names.size ());
508+ }
509+ for (int feat_idx = 0 ; feat_idx < feature_names.size (); ++feat_idx) {
510+ features_idx_buckets[request_idx][feat_idx].push_back (
511+ features[feat_idx][query_idx]);
512+ }
513+ }
514+
515+ DownpourBrpcClosure *closure = new DownpourBrpcClosure (
516+ request_call_num,
517+ [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
518+ int ret = 0 ;
519+ auto *closure = (DownpourBrpcClosure *)done;
520+ size_t fail_num = 0 ;
521+ for (int request_idx = 0 ; request_idx < request_call_num;
522+ ++request_idx) {
523+ if (closure->check_response (request_idx, PS_GRAPH_SET_NODE_FEAT) !=
524+ 0 ) {
525+ ++fail_num;
526+ }
527+ if (fail_num == request_call_num) {
528+ ret = -1 ;
529+ }
530+ }
531+ closure->set_promise_value (ret);
532+ });
533+
534+ auto promise = std::make_shared<std::promise<int32_t >>();
535+ closure->add_promise (promise);
536+ std::future<int > fut = promise->get_future ();
537+
538+ for (int request_idx = 0 ; request_idx < request_call_num; ++request_idx) {
539+ int server_index = request2server[request_idx];
540+ closure->request (request_idx)->set_cmd_id (PS_GRAPH_SET_NODE_FEAT);
541+ closure->request (request_idx)->set_table_id (table_id);
542+ closure->request (request_idx)->set_client_id (_client_id);
543+ size_t node_num = node_id_buckets[request_idx].size ();
544+
545+ closure->request (request_idx)
546+ ->add_params ((char *)node_id_buckets[request_idx].data (),
547+ sizeof (uint64_t ) * node_num);
548+ std::string joint_feature_name =
549+ paddle::string::join_strings (feature_names, ' \t ' );
550+ closure->request (request_idx)
551+ ->add_params (joint_feature_name.c_str (), joint_feature_name.size ());
552+
553+ // set features
554+ std::string set_feature = " " ;
555+ for (size_t feat_idx = 0 ; feat_idx < feature_names.size (); ++feat_idx) {
556+ for (size_t node_idx = 0 ; node_idx < node_num; ++node_idx) {
557+ size_t feat_len =
558+ features_idx_buckets[request_idx][feat_idx][node_idx].size ();
559+ set_feature.append ((char *)&feat_len, sizeof (size_t ));
560+ set_feature.append (
561+ features_idx_buckets[request_idx][feat_idx][node_idx].data (),
562+ feat_len);
563+ }
564+ }
565+ closure->request (request_idx)
566+ ->add_params (set_feature.c_str (), set_feature.size ());
567+
568+ GraphPsService_Stub rpc_stub =
569+ getServiceStub (get_cmd_channel (server_index));
570+ closure->cntl (request_idx)->set_log_id (butil::gettimeofday_ms ());
571+ rpc_stub.service (closure->cntl (request_idx), closure->request (request_idx),
572+ closure->response (request_idx), closure);
573+ }
574+
575+ return fut;
576+ }
577+
482578int32_t GraphBrpcClient::initialize () {
483579 // set_shard_num(_config.shard_num());
484580 BrpcPsClient::initialize ();
0 commit comments