@@ -35,90 +35,99 @@ int GraphBrpcClient::get_server_index_by_id(uint64_t id) {
3535 return id % shard_num / shard_per_server;
3636}
3737// char* &buffer,int &actual_size
38- std::future<int32_t > GraphBrpcClient::batch_sample (uint32_t table_id,
39- std::vector<uint64_t > node_ids, int sample_size,
40- std::vector<std::vector<std::pair<uint64_t , float > > > &res) {
41-
38+ std::future<int32_t > GraphBrpcClient::batch_sample (
39+ uint32_t table_id, std::vector<uint64_t > node_ids, int sample_size,
40+ std::vector<std::vector<std::pair<uint64_t , float >>> &res) {
4241 std::vector<int > request2server;
4342 std::vector<int > server2request (server_size, -1 );
4443 res.clear ();
45- for (int query_idx = 0 ; query_idx < node_ids.size (); ++query_idx){
44+ for (int query_idx = 0 ; query_idx < node_ids.size (); ++query_idx) {
4645 int server_index = get_server_index_by_id (node_ids[query_idx]);
47- if (server2request[server_index] == -1 ){
46+ if (server2request[server_index] == -1 ) {
4847 server2request[server_index] = request2server.size ();
4948 request2server.push_back (server_index);
5049 }
51- // res.push_back(std::vector<GraphNode>());
50+ // res.push_back(std::vector<GraphNode>());
5251 res.push_back (std::vector<std::pair<uint64_t , float >>());
5352 }
5453 size_t request_call_num = request2server.size ();
55- std::vector<std::vector<uint64_t > > node_id_buckets (request_call_num);
56- std::vector<std::vector<int > > query_idx_buckets (request_call_num);
57- for (int query_idx = 0 ; query_idx < node_ids.size (); ++query_idx){
54+ std::vector<std::vector<uint64_t >> node_id_buckets (request_call_num);
55+ std::vector<std::vector<int >> query_idx_buckets (request_call_num);
56+ for (int query_idx = 0 ; query_idx < node_ids.size (); ++query_idx) {
5857 int server_index = get_server_index_by_id (node_ids[query_idx]);
5958 int request_idx = server2request[server_index];
6059 node_id_buckets[request_idx].push_back (node_ids[query_idx]);
6160 query_idx_buckets[request_idx].push_back (query_idx);
6261 }
6362
64- DownpourBrpcClosure *closure = new DownpourBrpcClosure (request_call_num, [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
65- int ret = 0 ;
66- auto *closure = (DownpourBrpcClosure *)done;
67- int fail_num = 0 ;
68- for (int request_idx = 0 ; request_idx < request_call_num; ++request_idx){
69- if (closure->check_response (request_idx, PS_GRAPH_SAMPLE) != 0 ) {
70- ++fail_num;
71- } else {
72- VLOG (0 ) << " check sample response: "
73- << " " << closure->check_response (request_idx, PS_GRAPH_SAMPLE);
74- auto &res_io_buffer = closure->cntl (request_idx)->response_attachment ();
75- butil::IOBufBytesIterator io_buffer_itr (res_io_buffer);
76- size_t bytes_size = io_buffer_itr.bytes_left ();
77- char *buffer = new char [bytes_size];
78- io_buffer_itr.copy_and_forward ((void *)(buffer), bytes_size);
63+ DownpourBrpcClosure *closure = new DownpourBrpcClosure (
64+ request_call_num,
65+ [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
66+ int ret = 0 ;
67+ auto *closure = (DownpourBrpcClosure *)done;
68+ int fail_num = 0 ;
69+ for (int request_idx = 0 ; request_idx < request_call_num;
70+ ++request_idx) {
71+ if (closure->check_response (request_idx, PS_GRAPH_SAMPLE) != 0 ) {
72+ ++fail_num;
73+ } else {
74+ auto &res_io_buffer =
75+ closure->cntl (request_idx)->response_attachment ();
76+ butil::IOBufBytesIterator io_buffer_itr (res_io_buffer);
77+ size_t bytes_size = io_buffer_itr.bytes_left ();
78+ // char buffer[bytes_size];
79+ std::unique_ptr<char []> buffer_wrapper (new char [bytes_size]);
80+ char *buffer = buffer_wrapper.get ();
81+ io_buffer_itr.copy_and_forward ((void *)(buffer), bytes_size);
7982
80- size_t node_num = *(size_t *)buffer;
81- int *actual_sizes = (int *)(buffer + sizeof (size_t ));
82- char *node_buffer = buffer + sizeof (size_t ) + sizeof (int ) * node_num;
83-
84- int offset = 0 ;
85- for (size_t node_idx = 0 ; node_idx < node_num; ++node_idx){
86- int query_idx = query_idx_buckets.at (request_idx).at (node_idx);
87- int actual_size = actual_sizes[node_idx];
88- int start = 0 ;
89- while (start < actual_size) {
90- res[query_idx].push_back ({*(uint64_t *)(node_buffer + offset + start),
91- *(float *)(node_buffer + offset + start + GraphNode::id_size)});
92- start += GraphNode::id_size + GraphNode::weight_size;
83+ size_t node_num = *(size_t *)buffer;
84+ int *actual_sizes = (int *)(buffer + sizeof (size_t ));
85+ char *node_buffer =
86+ buffer + sizeof (size_t ) + sizeof (int ) * node_num;
87+
88+ int offset = 0 ;
89+ for (size_t node_idx = 0 ; node_idx < node_num; ++node_idx) {
90+ int query_idx = query_idx_buckets.at (request_idx).at (node_idx);
91+ int actual_size = actual_sizes[node_idx];
92+ int start = 0 ;
93+ while (start < actual_size) {
94+ res[query_idx].push_back (
95+ {*(uint64_t *)(node_buffer + offset + start),
96+ *(float *)(node_buffer + offset + start +
97+ GraphNode::id_size)});
98+ start += GraphNode::id_size + GraphNode::weight_size;
99+ }
100+ offset += actual_size;
101+ }
102+ }
103+ if (fail_num == request_call_num) {
104+ ret = -1 ;
93105 }
94- offset += actual_size;
95106 }
96- }
97- if (fail_num == request_call_num){
98- ret = -1 ;
99- }
100- }
101- closure->set_promise_value (ret);
102- });
107+ closure->set_promise_value (ret);
108+ });
103109
104110 auto promise = std::make_shared<std::promise<int32_t >>();
105111 closure->add_promise (promise);
106112 std::future<int > fut = promise->get_future ();
107-
108- for (int request_idx = 0 ; request_idx < request_call_num; ++request_idx){
113+
114+ for (int request_idx = 0 ; request_idx < request_call_num; ++request_idx) {
109115 int server_index = request2server[request_idx];
110116 closure->request (request_idx)->set_cmd_id (PS_GRAPH_SAMPLE);
111117 closure->request (request_idx)->set_table_id (table_id);
112118 closure->request (request_idx)->set_client_id (_client_id);
113119 // std::string type_str = GraphNode::node_type_to_string(type);
114120 size_t node_num = node_id_buckets[request_idx].size ();
115-
116- closure->request (request_idx)->add_params ((char *)node_id_buckets[request_idx].data (), sizeof (uint64_t )*node_num);
117- closure->request (request_idx)->add_params ((char *)&sample_size, sizeof (int ));
121+
122+ closure->request (request_idx)
123+ ->add_params ((char *)node_id_buckets[request_idx].data (),
124+ sizeof (uint64_t ) * node_num);
125+ closure->request (request_idx)
126+ ->add_params ((char *)&sample_size, sizeof (int ));
118127 PsService_Stub rpc_stub (get_cmd_channel (server_index));
119128 closure->cntl (request_idx)->set_log_id (butil::gettimeofday_ms ());
120- rpc_stub.service (closure->cntl (request_idx), closure->request (request_idx), closure-> response (request_idx),
121- closure);
129+ rpc_stub.service (closure->cntl (request_idx), closure->request (request_idx),
130+ closure-> response (request_idx), closure );
122131 }
123132
124133 return fut;
@@ -133,12 +142,12 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
133142 if (closure->check_response (0 , PS_PULL_GRAPH_LIST) != 0 ) {
134143 ret = -1 ;
135144 } else {
136- VLOG (0 ) << " check sample response: "
137- << " " << closure->check_response (0 , PS_PULL_GRAPH_LIST);
145+ // VLOG(0) << "check sample response: "
146+ // << " " << closure->check_response(0, PS_PULL_GRAPH_LIST);
138147 auto &res_io_buffer = closure->cntl (0 )->response_attachment ();
139148 butil::IOBufBytesIterator io_buffer_itr (res_io_buffer);
140149 size_t bytes_size = io_buffer_itr.bytes_left ();
141- char * buffer = new char [bytes_size];
150+ char buffer[bytes_size];
142151 io_buffer_itr.copy_and_forward ((void *)(buffer), bytes_size);
143152 int index = 0 ;
144153 while (index < bytes_size) {
0 commit comments