Skip to content
21 changes: 15 additions & 6 deletions paddle/fluid/distributed/service/graph_py_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,17 @@ std::vector<std::string> GraphPyService::split(std::string& str,
}

void GraphPyService::set_up(std::string ips_str, int shard_num,
std::vector<std::string> node_types,
std::vector<std::string> edge_types) {
set_shard_num(shard_num);
// set_client_Id(client_id);
// set_rank(rank);

this->table_id_map[std::string("")] = 0;
// Table 0 are for nodes
for (size_t table_id = 0; table_id < node_types.size(); table_id++) {
this->table_id_map[node_types[table_id]] = this->table_id_map.size();
}
for (size_t table_id = 0; table_id < edge_types.size(); table_id++) {
this->table_id_map[edge_types[table_id]] = int(table_id + 1);
this->table_id_map[edge_types[table_id]] = this->table_id_map.size();
}
std::istringstream stream(ips_str);
std::string ip;
Expand Down Expand Up @@ -162,9 +164,15 @@ ::paddle::distributed::PSParameter GraphPyClient::GetWorkerProto() {
}
void GraphPyClient::load_edge_file(std::string name, std::string filepath,
bool reverse) {
std::string params = "edge";
// 'e' means load edge
std::string params = "e";
if (reverse) {
params += "|reverse";
// 'e<' means load edges from $2 to $1
params += "<";
}
else {
// 'e>' means load edges from $1 to $2
params += ">";
}
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
Expand All @@ -175,7 +183,8 @@ void GraphPyClient::load_edge_file(std::string name, std::string filepath,
}

void GraphPyClient::load_node_file(std::string name, std::string filepath) {
std::string params = "node";
// 'n' means load nodes and 'node_type' follows
std::string params = "n" + name;
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status =
Expand Down
8 changes: 5 additions & 3 deletions paddle/fluid/distributed/service/graph_py_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,16 @@ class GraphPyService {
// Hands back the value it was given.
// NOTE(review): this returns the argument, not any stored member; confirm
// whether a member lookup was intended.
int get_server_size(int server_size) {
  const int echoed = server_size;
  return echoed;
}
std::vector<std::string> split(std::string& str, const char pattern);
// Shared initialisation used by both servers and clients.
// The definition stores `shard_num`, maps the empty name to table id 0,
// then assigns table ids to every entry of `node_types` followed by every
// entry of `edge_types`, and finally reads host entries out of `ips_str`.
// NOTE(review): the test passes `ips_str` as "ip:port;ip:port"; confirm
// that ';' is the only supported separator.
void set_up(std::string ips_str, int shard_num,
std::vector<std::string> node_types,
std::vector<std::string> edge_types);
};
class GraphPyServer : public GraphPyService {
public:
// Configures a graph server.
//   ips_str    - host list forwarded unchanged to GraphPyService::set_up.
//   shard_num  - shard count forwarded unchanged to the base set-up.
//   node_types - node type names; each one gets a table id in the base.
//   edge_types - edge type names; each one gets a table id in the base.
//   rank       - this server's rank, stored via set_rank().
// The base set-up is invoked exactly once, with the four-argument signature
// declared on GraphPyService.
void set_up(std::string ips_str, int shard_num,
            std::vector<std::string> node_types,
            std::vector<std::string> edge_types, int rank) {
  set_rank(rank);
  GraphPyService::set_up(ips_str, shard_num, node_types, edge_types);
}
int get_rank() { return rank; }
// Stores `rank` in the member of the same name.
void set_rank(int rank) {
  this->rank = rank;
}
Expand All @@ -107,9 +109,9 @@ class GraphPyServer : public GraphPyService {
class GraphPyClient : public GraphPyService {
public:
void set_up(std::string ips_str, int shard_num,
std::vector<std::string> edge_types, int client_id) {
std::vector<std::string> node_types, std::vector<std::string> edge_types, int client_id) {
set_client_id(client_id);
GraphPyService::set_up(ips_str, shard_num, edge_types);
GraphPyService::set_up(ips_str, shard_num, node_types, edge_types);
}
std::shared_ptr<paddle::distributed::PSClient> get_ps_client() {
return worker_ptr;
Expand Down
34 changes: 21 additions & 13 deletions paddle/fluid/distributed/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,16 @@ GraphNode *GraphShard::find_node(uint64_t id) {
}

// Loads graph data from `path`; `param` is a compact command string:
//   "e>"           load edges in file order ($1 -> $2)
//   "e<"           load edges reversed ($2 -> $1)
//   "n<node_type>" load only nodes whose type equals <node_type>
// A bare "e" is treated like "e>" (not reversed).
//
// Returns whatever load_edges()/load_nodes() returns. An empty or
// unrecognised command returns -1 instead of running off the end of the
// function, which would be undefined behaviour for a non-void function.
int32_t GraphTable::load(const std::string &path, const std::string &param) {
  if (param.empty()) {
    return -1;
  }
  switch (param[0]) {
    case 'e': {
      // Only an explicit '<' in second position requests reversed edges.
      const bool reverse_edge = param.size() > 1 && param[1] == '<';
      return this->load_edges(path, reverse_edge);
    }
    case 'n':
      // Everything after the leading 'n' is the node type to keep.
      return this->load_nodes(path, param.substr(1));
    default:
      return -1;
  }
}

Expand Down Expand Up @@ -104,7 +106,7 @@ int32_t GraphTable::get_nodes_ids_by_ranges(
}
return 0;
}
int32_t GraphTable::load_nodes(const std::string &path) {
int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
auto paths = paddle::string::split_string<std::string>(path, ";");
for (auto path : paths) {
std::ifstream file(path);
Expand All @@ -116,20 +118,26 @@ int32_t GraphTable::load_nodes(const std::string &path) {

size_t shard_id = id % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
VLOG(0) << "will not load " << id << " from " << path
VLOG(4) << "will not load " << id << " from " << path
<< ", please check id distribution";
continue;
}

std::string node_type = values[0];
std::string nt = values[0];
if (nt != node_type) {
continue;
}
std::vector<std::string> feature;
feature.push_back(node_type);
for (size_t slice = 2; slice < values.size(); slice++) {
feature.push_back(values[slice]);
}
auto feat = paddle::string::join_strings(feature, '\t');
size_t index = shard_id - shard_start;
shards[index].add_node(id, feat);
if(feature.size() > 0) {
shards[index].add_node(id, paddle::string::join_strings(feature, '\t'));
}
else {
shards[index].add_node(id, std::string(""));
}
}
}
return 0;
Expand Down Expand Up @@ -159,7 +167,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
size_t src_shard_id = src_id % shard_num;

if (src_shard_id >= shard_end || src_shard_id < shard_start) {
VLOG(0) << "will not load " << src_id << " from " << path
VLOG(4) << "will not load " << src_id << " from " << path
<< ", please check id distribution";
continue;
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/table/common_graph_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class GraphTable : public SparseTable {

int32_t load_edges(const std::string &path, bool reverse);

int32_t load_nodes(const std::string &path);
int32_t load_nodes(const std::string &path, std::string node_type);

GraphNode *find_node(uint64_t id);

Expand Down
68 changes: 52 additions & 16 deletions paddle/fluid/distributed/test/graph_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,20 +127,46 @@ void testGraphToBuffer();
// std::string("59\ttreat\t45;0.34\t145;0.31\t112;0.21"),
// std::string("97\tfood\t48;1.4\t247;0.31\t111;1.21")};

// Edge fixture written to the edge file by the test. Each entry is one
// tab-separated line with three fields.
// NOTE(review): the fields look like source id, destination id and weight;
// confirm against the edge loader.
std::string edges[] = {
    std::string("37\t45\t0.34"),  std::string("37\t145\t0.31"),
    std::string("37\t112\t0.21"), std::string("96\t48\t1.4"),
    std::string("96\t247\t0.31"), std::string("96\t111\t1.21"),
    std::string("59\t45\t0.34"),  std::string("59\t145\t0.31"),
    std::string("59\t122\t0.21"), std::string("97\t48\t0.34"),
    std::string("97\t247\t0.31"), std::string("97\t111\t0.21")};
char edge_file_name[] = "edges.txt";

// Node fixture written to the node file by the test. Each entry is one
// tab-separated line: the first field is the node type ("user" or "item"),
// followed by two more fields.
// NOTE(review): the second field is presumably the node id and the third a
// single feature value; confirm against the node loader.
// NOTE(review): item ids 45, 145, 48, 247 and 111 each appear twice below;
// confirm the duplicates are intentional.
std::string nodes[] = {
std::string("user\t37\t0.34"),
std::string("user\t96\t0.31"),
std::string("user\t59\t0.11"),
std::string("user\t97\t0.11"),
std::string("item\t45\t0.21"),
std::string("item\t145\t0.21"),
std::string("item\t112\t0.21"),
std::string("item\t48\t0.21"),
std::string("item\t247\t0.21"),
std::string("item\t111\t0.21"),
std::string("item\t45\t0.21"),
std::string("item\t145\t0.21"),
std::string("item\t122\t0.21"),
std::string("item\t48\t0.21"),
std::string("item\t247\t0.21"),
std::string("item\t111\t0.21")};
char node_file_name[] = "nodes.txt";


void prepare_file(char file_name[], bool load_edge) {
std::ofstream ofile;
ofile.open(file_name);
for (auto x : nodes) {
ofile << x << std::endl;
if(load_edge) {
for (auto x : edges) {
ofile << x << std::endl;
}
} else {
for (auto x : nodes) {
ofile << x << std::endl;
}
}
// for(int i = 0;i < 10;i++){
// for(int j = 0;j < 10;j++){
Expand Down Expand Up @@ -272,7 +298,8 @@ void RunClient(std::map<uint64_t, std::vector<paddle::distributed::Region>>&
void RunBrpcPushSparse() {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
prepare_file(file_name);
prepare_file(edge_file_name, 1);
prepare_file(node_file_name, 0);
auto ph_host = paddle::distributed::PSHost(ip_, port_, 0);
host_sign_list_.push_back(ph_host.serialize_to_string());

Expand All @@ -294,7 +321,7 @@ void RunBrpcPushSparse() {

/*-----------------------Test Server Init----------------------------------*/
auto pull_status =
worker_ptr_->load(0, std::string(file_name), std::string("edge"));
worker_ptr_->load(0, std::string(edge_file_name), std::string("e>"));
srand(time(0));
pull_status.wait();
std::vector<std::vector<std::pair<uint64_t, float>>> vs;
Expand Down Expand Up @@ -333,10 +360,11 @@ void RunBrpcPushSparse() {
distributed::GraphPyClient client1, client2;
std::string ips_str = "127.0.0.1:4211;127.0.0.1:4212";
std::vector<std::string> edge_types = {std::string("user2item")};
server1.set_up(ips_str, 127, edge_types, 0);
server2.set_up(ips_str, 127, edge_types, 1);
client1.set_up(ips_str, 127, edge_types, 0);
client2.set_up(ips_str, 127, edge_types, 1);
std::vector<std::string> node_types = {std::string("user"), std::string("item")};
server1.set_up(ips_str, 127, node_types, edge_types, 0);
server2.set_up(ips_str, 127, node_types, edge_types, 1);
client1.set_up(ips_str, 127, node_types, edge_types, 0);
client2.set_up(ips_str, 127, node_types, edge_types, 1);
server1.start_server();
std::cout << "first server done" << std::endl;
server2.start_server();
Expand All @@ -346,11 +374,18 @@ void RunBrpcPushSparse() {
client2.start_client();
std::cout << "first client done" << std::endl;
std::cout << "started" << std::endl;
client1.load_edge_file(std::string("user2item"), std::string(file_name), 0);
client1.load_node_file(std::string("user"), std::string(node_file_name));
client1.load_node_file(std::string("item"), std::string(node_file_name));
client1.load_edge_file(std::string("user2item"), std::string(edge_file_name), 0);
// client2.load_edge_file(std::string("user2item"), std::string(file_name),
// 0);
nodes.clear();
nodes = client2.pull_graph_list(std::string("user2item"), 0, 1, 4);
nodes = client1.pull_graph_list(std::string("user"), 0, 1, 4);

for (auto g : nodes) {
std::cout << "node_ids: " << g.get_id() << std::endl;
}
std::cout << "node_ids: " << nodes[0].get_id() << std::endl;
ASSERT_EQ(nodes[0].get_id(), 59);
nodes.clear();
vs = client1.batch_sample_k(std::string("user2item"),
Expand Down Expand Up @@ -382,7 +417,8 @@ void RunBrpcPushSparse() {
// for x in list:
// print(x.get_id())

std::remove(file_name);
std::remove(edge_file_name);
std::remove(node_file_name);
LOG(INFO) << "Run stop_server";
worker_ptr_->stop_server();
LOG(INFO) << "Run finalize_worker";
Expand Down