@@ -32,10 +32,25 @@ std::vector<std::string> GraphPyService::split(std::string& str,
3232 return res;
3333}
3434
35+
36+ void GraphPyService::add_table_feat_conf (std::string table_name,
37+ std::string feat_name,
38+ std::string feat_dtype,
39+ int32_t feat_shape) {
40+ if (this ->table_id_map .count (table_name)) {
41+ this ->table_feat_conf_table_name .push_back (table_name);
42+ this ->table_feat_conf_feat_name .push_back (feat_name);
43+ this ->table_feat_conf_feat_dtype .push_back (feat_dtype);
44+ this ->table_feat_conf_feat_shape .push_back (feat_shape);
45+ }
46+ }
47+
48+
3549void GraphPyService::set_up (std::string ips_str, int shard_num,
3650 std::vector<std::string> node_types,
3751 std::vector<std::string> edge_types) {
3852 set_shard_num (shard_num);
53+ set_num_node_types (node_types.size ());
3954 // set_client_Id(client_id);
4055 // set_rank(rank);
4156
@@ -121,7 +136,31 @@ ::paddle::distributed::PSParameter GraphPyServer::GetServerProto() {
121136 VLOG (0 ) << " make a new table " << tuple.second ;
122137 ::paddle::distributed::TableParameter* sparse_table_proto =
123138 downpour_server_proto->add_downpour_table_param ();
124- GetDownpourSparseTableProto (sparse_table_proto, tuple.second );
139+ std::vector<std::string > feat_name;
140+ std::vector<std::string > feat_dtype;
141+ std::vector<int32_t > feat_shape;
142+ for (size_t i=0 ; i<this ->table_feat_conf_table_name .size (); i++) {
143+ if (tuple.first == table_feat_conf_table_name[i]) {
144+ feat_name.push_back (table_feat_conf_feat_name[i]);
145+ feat_dtype.push_back (table_feat_conf_feat_dtype[i]);
146+ feat_shape.push_back (table_feat_conf_feat_shape[i]);
147+ }
148+ }
149+ std::string table_type;
150+ if (tuple.second < this ->num_node_types ) {
151+ table_type = " node" ;
152+ }
153+ else {
154+ table_type = " edge" ;
155+ }
156+
157+ GetDownpourSparseTableProto (sparse_table_proto,
158+ tuple.second ,
159+ tuple.first ,
160+ table_type,
161+ feat_name,
162+ feat_dtype,
163+ feat_shape);
125164 }
126165
127166 return server_fleet_desc;
@@ -137,11 +176,38 @@ ::paddle::distributed::PSParameter GraphPyClient::GetWorkerProto() {
137176 worker_proto->mutable_downpour_worker_param ();
138177
139178 for (auto & tuple : this ->table_id_map ) {
179+ VLOG (0 ) << " make a new table " << tuple.second ;
140180 ::paddle::distributed::TableParameter* worker_sparse_table_proto =
141181 downpour_worker_proto->add_downpour_table_param ();
142- GetDownpourSparseTableProto (worker_sparse_table_proto, tuple.second );
182+ std::vector<std::string > feat_name;
183+ std::vector<std::string > feat_dtype;
184+ std::vector<int32_t > feat_shape;
185+ for (size_t i=0 ; i<this ->table_feat_conf_table_name .size (); i++) {
186+ if (tuple.first == table_feat_conf_table_name[i]) {
187+ feat_name.push_back (table_feat_conf_feat_name[i]);
188+ feat_dtype.push_back (table_feat_conf_feat_dtype[i]);
189+ feat_shape.push_back (table_feat_conf_feat_shape[i]);
190+ }
191+ }
192+ std::string table_type;
193+ if (tuple.second < this ->num_node_types ) {
194+ table_type = " node" ;
195+ }
196+ else {
197+ table_type = " edge" ;
198+ }
199+
200+ GetDownpourSparseTableProto (worker_sparse_table_proto,
201+ tuple.second ,
202+ tuple.first ,
203+ table_type,
204+ feat_name,
205+ feat_dtype,
206+ feat_shape);
143207 }
144208
209+
210+
145211 ::paddle::distributed::ServerParameter* server_proto =
146212 worker_fleet_desc.mutable_server_param ();
147213 ::paddle::distributed::DownpourServerParameter* downpour_server_proto =
@@ -155,11 +221,38 @@ ::paddle::distributed::PSParameter GraphPyClient::GetWorkerProto() {
155221 server_service_proto->set_server_thread_num (12 );
156222
157223 for (auto & tuple : this ->table_id_map ) {
224+ VLOG (0 ) << " make a new table " << tuple.second ;
158225 ::paddle::distributed::TableParameter* sparse_table_proto =
159226 downpour_server_proto->add_downpour_table_param ();
160- GetDownpourSparseTableProto (sparse_table_proto, tuple.second );
227+ std::vector<std::string > feat_name;
228+ std::vector<std::string > feat_dtype;
229+ std::vector<int32_t > feat_shape;
230+ for (size_t i=0 ; i<this ->table_feat_conf_table_name .size (); i++) {
231+ if (tuple.first == table_feat_conf_table_name[i]) {
232+ feat_name.push_back (table_feat_conf_feat_name[i]);
233+ feat_dtype.push_back (table_feat_conf_feat_dtype[i]);
234+ feat_shape.push_back (table_feat_conf_feat_shape[i]);
235+ }
236+ }
237+ std::string table_type;
238+ if (tuple.second < this ->num_node_types ) {
239+ table_type = " node" ;
240+ }
241+ else {
242+ table_type = " edge" ;
243+ }
244+
245+ GetDownpourSparseTableProto (sparse_table_proto,
246+ tuple.second ,
247+ tuple.first ,
248+ table_type,
249+ feat_name,
250+ feat_dtype,
251+ feat_shape);
161252 }
162253
254+
255+
163256 return worker_fleet_desc;
164257}
165258void GraphPyClient::load_edge_file (std::string name, std::string filepath,
@@ -232,3 +325,4 @@ std::vector<GraphNode> GraphPyClient::pull_graph_list(std::string name,
232325}
233326}
234327}
328+
0 commit comments