diff --git a/paddle/fluid/distributed/service/graph_py_service.cc b/paddle/fluid/distributed/service/graph_py_service.cc index db5aaa947577a4..4a7d2706fef395 100644 --- a/paddle/fluid/distributed/service/graph_py_service.cc +++ b/paddle/fluid/distributed/service/graph_py_service.cc @@ -33,15 +33,17 @@ std::vector GraphPyService::split(std::string& str, } void GraphPyService::set_up(std::string ips_str, int shard_num, + std::vector node_types, std::vector edge_types) { set_shard_num(shard_num); // set_client_Id(client_id); // set_rank(rank); - this->table_id_map[std::string("")] = 0; - // Table 0 are for nodes + for (size_t table_id = 0; table_id < node_types.size(); table_id++) { + this->table_id_map[node_types[table_id]] = this->table_id_map.size(); + } for (size_t table_id = 0; table_id < edge_types.size(); table_id++) { - this->table_id_map[edge_types[table_id]] = int(table_id + 1); + this->table_id_map[edge_types[table_id]] = this->table_id_map.size(); } std::istringstream stream(ips_str); std::string ip; @@ -162,9 +164,15 @@ ::paddle::distributed::PSParameter GraphPyClient::GetWorkerProto() { } void GraphPyClient::load_edge_file(std::string name, std::string filepath, bool reverse) { - std::string params = "edge"; + // 'e' means load edge + std::string params = "e"; if (reverse) { - params += "|reverse"; + // 'e<' means load edges from $2 to $1 + params += "<"; + } + else { + // 'e>' means load edges from $1 to $2 + params += ">"; } if (this->table_id_map.count(name)) { uint32_t table_id = this->table_id_map[name]; @@ -175,7 +183,8 @@ void GraphPyClient::load_edge_file(std::string name, std::string filepath, } void GraphPyClient::load_node_file(std::string name, std::string filepath) { - std::string params = "node"; + // 'n' means load nodes and 'node_type' follows + std::string params = "n" + name; if (this->table_id_map.count(name)) { uint32_t table_id = this->table_id_map[name]; auto status = diff --git a/paddle/fluid/distributed/service/graph_py_service.h b/paddle/fluid/distributed/service/graph_py_service.h index 8f6c9f0ad0b64a..137f854b39e11b 100644 --- a/paddle/fluid/distributed/service/graph_py_service.h +++ b/paddle/fluid/distributed/service/graph_py_service.h @@ -82,14 +82,16 @@ class GraphPyService { int get_server_size(int server_size) { return server_size; } std::vector split(std::string& str, const char pattern); void set_up(std::string ips_str, int shard_num, + std::vector node_types, std::vector edge_types); }; class GraphPyServer : public GraphPyService { public: void set_up(std::string ips_str, int shard_num, + std::vector node_types, std::vector edge_types, int rank) { set_rank(rank); - GraphPyService::set_up(ips_str, shard_num, edge_types); + GraphPyService::set_up(ips_str, shard_num, node_types, edge_types); } int get_rank() { return rank; } void set_rank(int rank) { this->rank = rank; } @@ -107,9 +109,9 @@ class GraphPyServer : public GraphPyService { class GraphPyClient : public GraphPyService { public: void set_up(std::string ips_str, int shard_num, - std::vector edge_types, int client_id) { + std::vector node_types, std::vector edge_types, int client_id) { set_client_id(client_id); - GraphPyService::set_up(ips_str, shard_num, edge_types); + GraphPyService::set_up(ips_str, shard_num, node_types, edge_types); } std::shared_ptr get_ps_client() { return worker_ptr; diff --git a/paddle/fluid/distributed/table/common_graph_table.cc b/paddle/fluid/distributed/table/common_graph_table.cc index 5bd5ee268bbb5d..1094c0104cbe63 100644 --- a/paddle/fluid/distributed/table/common_graph_table.cc +++ b/paddle/fluid/distributed/table/common_graph_table.cc @@ -52,14 +52,16 @@ GraphNode *GraphShard::find_node(uint64_t id) { } int32_t GraphTable::load(const std::string &path, const std::string ¶m) { - auto cmd = paddle::string::split_string(param, "|"); - std::set cmd_set(cmd.begin(), cmd.end()); - bool reverse_edge = cmd_set.count(std::string("reverse")); - bool load_edge = cmd_set.count(std::string("edge")); + + bool load_edge = (param[0] == 'e'); + bool load_node = (param[0] == 'n'); if (load_edge) { + bool reverse_edge = (param[1] == '<'); return this->load_edges(path, reverse_edge); - } else { - return this->load_nodes(path); + } + if (load_node){ + std::string node_type = param.substr(1); + return this->load_nodes(path, node_type); } } @@ -104,7 +106,7 @@ int32_t GraphTable::get_nodes_ids_by_ranges( } return 0; } -int32_t GraphTable::load_nodes(const std::string &path) { +int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) { auto paths = paddle::string::split_string(path, ";"); for (auto path : paths) { std::ifstream file(path); @@ -116,20 +118,26 @@ int32_t GraphTable::load_nodes(const std::string &path) { size_t shard_id = id % shard_num; if (shard_id >= shard_end || shard_id < shard_start) { - VLOG(0) << "will not load " << id << " from " << path + VLOG(4) << "will not load " << id << " from " << path << ", please check id distribution"; continue; } - std::string node_type = values[0]; + std::string nt = values[0]; + if (nt != node_type) { + continue; + } std::vector feature; - feature.push_back(node_type); for (size_t slice = 2; slice < values.size(); slice++) { feature.push_back(values[slice]); } - auto feat = paddle::string::join_strings(feature, '\t'); size_t index = shard_id - shard_start; - shards[index].add_node(id, feat); + if(feature.size() > 0) { + shards[index].add_node(id, paddle::string::join_strings(feature, '\t')); + } + else { + shards[index].add_node(id, std::string("")); + } } } return 0; @@ -159,7 +167,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) { size_t src_shard_id = src_id % shard_num; if (src_shard_id >= shard_end || src_shard_id < shard_start) { - VLOG(0) << "will not load " << src_id << " from " << path + VLOG(4) << "will not load " << src_id << " from " << path << ", please check id distribution"; continue; } diff --git a/paddle/fluid/distributed/table/common_graph_table.h b/paddle/fluid/distributed/table/common_graph_table.h index 7134f53c075b36..0aa67f08b5a38b 100644 --- a/paddle/fluid/distributed/table/common_graph_table.h +++ b/paddle/fluid/distributed/table/common_graph_table.h @@ -96,7 +96,7 @@ class GraphTable : public SparseTable { int32_t load_edges(const std::string &path, bool reverse); - int32_t load_nodes(const std::string &path); + int32_t load_nodes(const std::string &path, std::string node_type); GraphNode *find_node(uint64_t id); diff --git a/paddle/fluid/distributed/test/graph_node_test.cc b/paddle/fluid/distributed/test/graph_node_test.cc index 7b7412298da87f..56c3359a20857f 100644 --- a/paddle/fluid/distributed/test/graph_node_test.cc +++ b/paddle/fluid/distributed/test/graph_node_test.cc @@ -127,20 +127,46 @@ void testGraphToBuffer(); // std::string("59\ttreat\t45;0.34\t145;0.31\t112;0.21"), // std::string("97\tfood\t48;1.4\t247;0.31\t111;1.21")}; -std::string nodes[] = { +std::string edges[] = { std::string("37\t45\t0.34"), std::string("37\t145\t0.31"), std::string("37\t112\t0.21"), std::string("96\t48\t1.4"), std::string("96\t247\t0.31"), std::string("96\t111\t1.21"), std::string("59\t45\t0.34"), std::string("59\t145\t0.31"), std::string("59\t122\t0.21"), std::string("97\t48\t0.34"), - std::string("97\t247\t0.31"), std::string("97\t111\t0.21"), -}; -char file_name[] = "nodes.txt"; -void prepare_file(char file_name[]) { + std::string("97\t247\t0.31"), std::string("97\t111\t0.21")}; +char edge_file_name[] = "edges.txt"; + +std::string nodes[] = { + std::string("user\t37\t0.34"), + std::string("user\t96\t0.31"), + std::string("user\t59\t0.11"), + std::string("user\t97\t0.11"), + std::string("item\t45\t0.21"), + std::string("item\t145\t0.21"), + std::string("item\t112\t0.21"), + std::string("item\t48\t0.21"), + std::string("item\t247\t0.21"), + std::string("item\t111\t0.21"), + std::string("item\t45\t0.21"), + std::string("item\t145\t0.21"), + std::string("item\t122\t0.21"), + std::string("item\t48\t0.21"), + std::string("item\t247\t0.21"), + std::string("item\t111\t0.21")}; +char node_file_name[] = "nodes.txt"; + + +void prepare_file(char file_name[], bool load_edge) { std::ofstream ofile; ofile.open(file_name); - for (auto x : nodes) { - ofile << x << std::endl; + if(load_edge) { + for (auto x : edges) { + ofile << x << std::endl; + } + } else { + for (auto x : nodes) { + ofile << x << std::endl; + } } // for(int i = 0;i < 10;i++){ // for(int j = 0;j < 10;j++){ @@ -272,7 +298,8 @@ void RunClient(std::map>& void RunBrpcPushSparse() { setenv("http_proxy", "", 1); setenv("https_proxy", "", 1); - prepare_file(file_name); + prepare_file(edge_file_name, 1); + prepare_file(node_file_name, 0); auto ph_host = paddle::distributed::PSHost(ip_, port_, 0); host_sign_list_.push_back(ph_host.serialize_to_string()); @@ -294,7 +321,7 @@ void RunBrpcPushSparse() { /*-----------------------Test Server Init----------------------------------*/ auto pull_status = - worker_ptr_->load(0, std::string(file_name), std::string("edge")); + worker_ptr_->load(0, std::string(edge_file_name), std::string("e>")); srand(time(0)); pull_status.wait(); std::vector>> vs; @@ -333,10 +360,11 @@ void RunBrpcPushSparse() { distributed::GraphPyClient client1, client2; std::string ips_str = "127.0.0.1:4211;127.0.0.1:4212"; std::vector edge_types = {std::string("user2item")}; - server1.set_up(ips_str, 127, edge_types, 0); - server2.set_up(ips_str, 127, edge_types, 1); - client1.set_up(ips_str, 127, edge_types, 0); - client2.set_up(ips_str, 127, edge_types, 1); + std::vector node_types = {std::string("user"), std::string("item")}; + server1.set_up(ips_str, 127, node_types, edge_types, 0); + server2.set_up(ips_str, 127, node_types, edge_types, 1); + client1.set_up(ips_str, 127, node_types, edge_types, 0); + client2.set_up(ips_str, 127, node_types, edge_types, 1); server1.start_server(); std::cout << "first server done" << std::endl; server2.start_server(); @@ -346,11 +374,18 @@ void RunBrpcPushSparse() { client2.start_client(); std::cout << "first client done" << std::endl; std::cout << "started" << std::endl; - client1.load_edge_file(std::string("user2item"), std::string(file_name), 0); + client1.load_node_file(std::string("user"), std::string(node_file_name)); + client1.load_node_file(std::string("item"), std::string(node_file_name)); + client1.load_edge_file(std::string("user2item"), std::string(edge_file_name), 0); // client2.load_edge_file(std::string("user2item"), std::string(file_name), // 0); nodes.clear(); - nodes = client2.pull_graph_list(std::string("user2item"), 0, 1, 4); + nodes = client1.pull_graph_list(std::string("user"), 0, 1, 4); + + for (auto g : nodes) { + std::cout << "node_ids: " << g.get_id() << std::endl; + } + std::cout << "node_ids: " << nodes[0].get_id() << std::endl; ASSERT_EQ(nodes[0].get_id(), 59); nodes.clear(); vs = client1.batch_sample_k(std::string("user2item"), @@ -382,7 +417,8 @@ void RunBrpcPushSparse() { // for x in list: // print(x.get_id()) - std::remove(file_name); + std::remove(edge_file_name); + std::remove(node_file_name); LOG(INFO) << "Run stop_server"; worker_ptr_->stop_server(); LOG(INFO) << "Run finalize_worker";