2222#include < string>
2323#include < thread> // NOLINT
2424#include < vector>
25+ #include < unordered_map>
2526#include " google/protobuf/text_format.h"
2627
2728#include " gtest/gtest.h"
@@ -46,7 +47,7 @@ class GraphPyService {
4647 std::vector<int > keys;
4748 std::vector<std::string> server_list, port_list, host_sign_list;
4849 int server_size, shard_num, rank, client_id;
49- uint32_t table_id ;
50+ std::unordered_map<std::string, uint32_t > table_id_map ;
5051 std::thread *server_thread, *client_thread;
5152
5253 std::shared_ptr<paddle::distributed::PSServer> pserver_ptr;
@@ -67,7 +68,7 @@ class GraphPyService {
6768 int get_shard_num () { return shard_num; }
6869 void set_shard_num (int shard_num) { this ->shard_num = shard_num; }
6970 void GetDownpourSparseTableProto (
70- ::paddle::distributed::TableParameter* sparse_table_proto) {
71+ ::paddle::distributed::TableParameter* sparse_table_proto, uint32_t table_id ) {
7172 sparse_table_proto->set_table_id (table_id);
7273 sparse_table_proto->set_table_class (" GraphTable" );
7374 sparse_table_proto->set_shard_num (shard_num);
@@ -96,10 +97,14 @@ class GraphPyService {
9697 server_service_proto->set_start_server_port (0 );
9798 server_service_proto->set_server_thread_num (12 );
9899
99- ::paddle::distributed::TableParameter* sparse_table_proto =
100- downpour_server_proto->add_downpour_table_param ();
101- GetDownpourSparseTableProto (sparse_table_proto);
100+ for (auto & tuple : this -> table_id_map) {
101+ ::paddle::distributed::TableParameter* sparse_table_proto =
102+ downpour_server_proto->add_downpour_table_param ();
103+ GetDownpourSparseTableProto (sparse_table_proto, tuple.second );
104+ }
105+
102106 return server_fleet_desc;
107+
103108 }
104109
105110 ::paddle::distributed::PSParameter GetWorkerProto () {
@@ -111,9 +116,11 @@ class GraphPyService {
111116 ::paddle::distributed::DownpourWorkerParameter* downpour_worker_proto =
112117 worker_proto->mutable_downpour_worker_param ();
113118
114- ::paddle::distributed::TableParameter* worker_sparse_table_proto =
115- downpour_worker_proto->add_downpour_table_param ();
116- GetDownpourSparseTableProto (worker_sparse_table_proto);
119+ for (auto & tuple : this -> table_id_map) {
120+ ::paddle::distributed::TableParameter* worker_sparse_table_proto =
121+ downpour_worker_proto->add_downpour_table_param ();
122+ GetDownpourSparseTableProto (worker_sparse_table_proto, tuple.second );
123+ }
117124
118125 ::paddle::distributed::ServerParameter* server_proto =
119126 worker_fleet_desc.mutable_server_param ();
@@ -127,34 +134,59 @@ class GraphPyService {
127134 server_service_proto->set_start_server_port (0 );
128135 server_service_proto->set_server_thread_num (12 );
129136
130- ::paddle::distributed::TableParameter* server_sparse_table_proto =
131- downpour_server_proto->add_downpour_table_param ();
132- GetDownpourSparseTableProto (server_sparse_table_proto);
137+ for (auto & tuple : this -> table_id_map) {
138+ ::paddle::distributed::TableParameter* sparse_table_proto =
139+ downpour_server_proto->add_downpour_table_param ();
140+ GetDownpourSparseTableProto (sparse_table_proto, tuple.second );
141+ }
133142
134143 return worker_fleet_desc;
135144 }
136145 void set_server_size (int server_size) { this ->server_size = server_size; }
137146 int get_server_size (int server_size) { return server_size; }
138147 std::vector<std::string> split (std::string& str, const char pattern);
139148
140- void load_file (std::string filepath) {
141- auto status =
142- get_ps_client ()->load (table_id, std::string (filepath), std::string (" " ));
143- status.wait ();
149+ void load_edge_file (std::string name, std::string filepath, bool reverse) {
150+ std::string params = " edge" ;
151+ if (reverse) {
152+ params += " |reverse" ;
153+ }
154+ if (this -> table_id_map.count (name)) {
155+ uint32_t table_id = this -> table_id_map[name];
156+ auto status =
157+ get_ps_client ()->load (table_id, std::string (filepath), params);
158+ status.wait ();
159+ }
160+ }
161+
162+ void load_node_file (std::string name, std::string filepath) {
163+ std::string params = " node" ;
164+ if (this -> table_id_map.count (name)) {
165+ uint32_t table_id = this -> table_id_map[name];
166+ auto status =
167+ get_ps_client ()->load (table_id, std::string (filepath), params);
168+ status.wait ();
169+ }
144170 }
145171
146- std::vector<GraphNode> sample_k (uint64_t node_id, int sample_size) {
172+ std::vector<GraphNode> sample_k (std::string name, uint64_t node_id, int sample_size) {
147173 std::vector<GraphNode> v;
148- auto status = worker_ptr->sample (table_id, node_id, sample_size, v);
149- status.wait ();
174+ if (this -> table_id_map.count (name)) {
175+ uint32_t table_id = this -> table_id_map[name];
176+ auto status = worker_ptr->sample (table_id, node_id, sample_size, v);
177+ status.wait ();
178+ }
150179 return v;
151180 }
152- std::vector<GraphNode> pull_graph_list (int server_index, int start,
181+ std::vector<GraphNode> pull_graph_list (std::string name, int server_index, int start,
153182 int size) {
154183 std::vector<GraphNode> res;
155- auto status =
156- worker_ptr->pull_graph_list (table_id, server_index, start, size, res);
157- status.wait ();
184+ if (this -> table_id_map.count (name)) {
185+ uint32_t table_id = this -> table_id_map[name];
186+ auto status =
187+ worker_ptr->pull_graph_list (table_id, server_index, start, size, res);
188+ status.wait ();
189+ }
158190 return res;
159191 }
160192 void start_server (std::string ip, uint32_t port) {
@@ -197,7 +229,7 @@ class GraphPyService {
197229 worker_ptr->configure (worker_proto, dense_regions, _ps_env, client_id);
198230 }
199231 void set_up (std::string ips_str, int shard_num, int rank, int client_id,
200- uint32_t table_id );
232+ std::vector<std::string> edge_types );
201233 void set_keys (std::vector<int > keys) { // just for test
202234 this ->keys = keys;
203235 }
0 commit comments