Skip to content

Commit 94025b9

Browse files
authored
Merge pull request #1 from Yelrose/develop
Add Simple Loading for Direct Graph
2 parents 3f32bf1 + 7320994 commit 94025b9

File tree

5 files changed

+117
-61
lines changed

5 files changed

+117
-61
lines changed

paddle/fluid/distributed/service/graph_py_service.cc

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,16 @@ std::vector<std::string> GraphPyService::split(std::string &str,
2727
}
2828

2929
void GraphPyService::set_up(std::string ips_str, int shard_num, int rank,
30-
int client_id, uint32_t table_id) {
30+
int client_id, std::vector<std::string> edge_types) {
3131
set_shard_num(shard_num);
3232
set_client_Id(client_id);
3333
set_rank(rank);
34-
this->table_id = table_id;
34+
35+
this -> table_id_map[std::string("")] = 0;
36+
// Table 0 is for nodes; edge types are assigned table ids starting at 1
37+
for(size_t table_id = 0; table_id < edge_types.size(); table_id ++ ) {
38+
this -> table_id_map[edge_types[table_id]] = int(table_id + 1);
39+
}
3540
server_thread = client_thread = NULL;
3641
std::istringstream stream(ips_str);
3742
std::string ip;
@@ -47,10 +52,10 @@ void GraphPyService::set_up(std::string ips_str, int shard_num, int rank,
4752
host_sign_list.push_back(ph_host.serialize_to_string());
4853
index++;
4954
}
50-
VLOG(0) << "IN set up rank = " << rank;
55+
//VLOG(0) << "IN set up rank = " << rank;
5156
start_client();
5257
start_server(server_list[rank], std::stoul(port_list[rank]));
5358
sleep(1);
5459
}
5560
}
56-
}
61+
}

paddle/fluid/distributed/service/graph_py_service.h

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <string>
2323
#include <thread> // NOLINT
2424
#include <vector>
25+
#include <unordered_map>
2526
#include "google/protobuf/text_format.h"
2627

2728
#include "gtest/gtest.h"
@@ -46,7 +47,7 @@ class GraphPyService {
4647
std::vector<int> keys;
4748
std::vector<std::string> server_list, port_list, host_sign_list;
4849
int server_size, shard_num, rank, client_id;
49-
uint32_t table_id;
50+
std::unordered_map<std::string, uint32_t > table_id_map;
5051
std::thread *server_thread, *client_thread;
5152

5253
std::shared_ptr<paddle::distributed::PSServer> pserver_ptr;
@@ -67,7 +68,7 @@ class GraphPyService {
6768
int get_shard_num() { return shard_num; }
6869
void set_shard_num(int shard_num) { this->shard_num = shard_num; }
6970
void GetDownpourSparseTableProto(
70-
::paddle::distributed::TableParameter* sparse_table_proto) {
71+
::paddle::distributed::TableParameter* sparse_table_proto, uint32_t table_id) {
7172
sparse_table_proto->set_table_id(table_id);
7273
sparse_table_proto->set_table_class("GraphTable");
7374
sparse_table_proto->set_shard_num(shard_num);
@@ -96,10 +97,14 @@ class GraphPyService {
9697
server_service_proto->set_start_server_port(0);
9798
server_service_proto->set_server_thread_num(12);
9899

99-
::paddle::distributed::TableParameter* sparse_table_proto =
100-
downpour_server_proto->add_downpour_table_param();
101-
GetDownpourSparseTableProto(sparse_table_proto);
100+
for(auto& tuple : this -> table_id_map) {
101+
::paddle::distributed::TableParameter* sparse_table_proto =
102+
downpour_server_proto->add_downpour_table_param();
103+
GetDownpourSparseTableProto(sparse_table_proto, tuple.second);
104+
}
105+
102106
return server_fleet_desc;
107+
103108
}
104109

105110
::paddle::distributed::PSParameter GetWorkerProto() {
@@ -111,9 +116,11 @@ class GraphPyService {
111116
::paddle::distributed::DownpourWorkerParameter* downpour_worker_proto =
112117
worker_proto->mutable_downpour_worker_param();
113118

114-
::paddle::distributed::TableParameter* worker_sparse_table_proto =
115-
downpour_worker_proto->add_downpour_table_param();
116-
GetDownpourSparseTableProto(worker_sparse_table_proto);
119+
for(auto& tuple : this -> table_id_map) {
120+
::paddle::distributed::TableParameter* worker_sparse_table_proto =
121+
downpour_worker_proto->add_downpour_table_param();
122+
GetDownpourSparseTableProto(worker_sparse_table_proto, tuple.second);
123+
}
117124

118125
::paddle::distributed::ServerParameter* server_proto =
119126
worker_fleet_desc.mutable_server_param();
@@ -127,34 +134,59 @@ class GraphPyService {
127134
server_service_proto->set_start_server_port(0);
128135
server_service_proto->set_server_thread_num(12);
129136

130-
::paddle::distributed::TableParameter* server_sparse_table_proto =
131-
downpour_server_proto->add_downpour_table_param();
132-
GetDownpourSparseTableProto(server_sparse_table_proto);
137+
for(auto& tuple : this -> table_id_map) {
138+
::paddle::distributed::TableParameter* sparse_table_proto =
139+
downpour_server_proto->add_downpour_table_param();
140+
GetDownpourSparseTableProto(sparse_table_proto, tuple.second);
141+
}
133142

134143
return worker_fleet_desc;
135144
}
136145
void set_server_size(int server_size) { this->server_size = server_size; }
137146
int get_server_size(int server_size) { return server_size; }
138147
std::vector<std::string> split(std::string& str, const char pattern);
139148

140-
void load_file(std::string filepath) {
141-
auto status =
142-
get_ps_client()->load(table_id, std::string(filepath), std::string(""));
143-
status.wait();
149+
void load_edge_file(std::string name, std::string filepath, bool reverse) {
150+
std::string params = "edge";
151+
if(reverse) {
152+
params += "|reverse";
153+
}
154+
if (this -> table_id_map.count(name)) {
155+
uint32_t table_id = this -> table_id_map[name];
156+
auto status =
157+
get_ps_client()->load(table_id, std::string(filepath), params);
158+
status.wait();
159+
}
160+
}
161+
162+
void load_node_file(std::string name, std::string filepath) {
163+
std::string params = "node";
164+
if (this -> table_id_map.count(name)) {
165+
uint32_t table_id = this -> table_id_map[name];
166+
auto status =
167+
get_ps_client()->load(table_id, std::string(filepath), params);
168+
status.wait();
169+
}
144170
}
145171

146-
std::vector<GraphNode> sample_k(uint64_t node_id, int sample_size) {
172+
std::vector<GraphNode> sample_k(std::string name, uint64_t node_id, int sample_size) {
147173
std::vector<GraphNode> v;
148-
auto status = worker_ptr->sample(table_id, node_id, sample_size, v);
149-
status.wait();
174+
if (this -> table_id_map.count(name)) {
175+
uint32_t table_id = this -> table_id_map[name];
176+
auto status = worker_ptr->sample(table_id, node_id, sample_size, v);
177+
status.wait();
178+
}
150179
return v;
151180
}
152-
std::vector<GraphNode> pull_graph_list(int server_index, int start,
181+
std::vector<GraphNode> pull_graph_list(std::string name, int server_index, int start,
153182
int size) {
154183
std::vector<GraphNode> res;
155-
auto status =
156-
worker_ptr->pull_graph_list(table_id, server_index, start, size, res);
157-
status.wait();
184+
if (this -> table_id_map.count(name)) {
185+
uint32_t table_id = this -> table_id_map[name];
186+
auto status =
187+
worker_ptr->pull_graph_list(table_id, server_index, start, size, res);
188+
status.wait();
189+
}
158190
return res;
159191
}
160192
void start_server(std::string ip, uint32_t port) {
@@ -197,7 +229,7 @@ class GraphPyService {
197229
worker_ptr->configure(worker_proto, dense_regions, _ps_env, client_id);
198230
}
199231
void set_up(std::string ips_str, int shard_num, int rank, int client_id,
200-
uint32_t table_id);
232+
std::vector<std::string> edge_types);
201233
void set_keys(std::vector<int> keys) { // just for test
202234
this->keys = keys;
203235
}

paddle/fluid/distributed/table/common_graph_table.cc

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,16 @@
1515
#include "paddle/fluid/distributed/table/common_graph_table.h"
1616
#include <algorithm>
1717
#include <sstream>
18+
#include <time.h>
19+
#include <set>
1820
#include "paddle/fluid/distributed/common/utils.h"
1921
#include "paddle/fluid/string/printf.h"
2022
#include "paddle/fluid/string/string_helper.h"
2123
namespace paddle {
2224
namespace distributed {
25+
2326
int GraphShard::bucket_low_bound = 11;
27+
2428
std::vector<GraphNode *> GraphShard::get_batch(int start, int total_size) {
2529
if (start < 0) start = 0;
2630
int size = 0, cur_size;
@@ -51,68 +55,81 @@ std::vector<GraphNode *> GraphShard::get_batch(int start, int total_size) {
5155
}
5256
return res;
5357
}
58+
5459
size_t GraphShard::get_size() {
5560
size_t res = 0;
5661
for (int i = 0; i < bucket_size; i++) {
5762
res += bucket[i].size();
5863
}
5964
return res;
6065
}
66+
6167
std::list<GraphNode *>::iterator GraphShard::add_node(GraphNode *node) {
6268
if (node_location.find(node->get_id()) != node_location.end())
6369
return node_location.find(node->get_id())->second;
70+
6471
int index = node->get_id() % shard_num % bucket_size;
72+
6573
std::list<GraphNode *>::iterator iter =
6674
bucket[index].insert(bucket[index].end(), node);
75+
6776
node_location[node->get_id()] = iter;
6877
return iter;
6978
}
79+
7080
void GraphShard::add_neighboor(uint64_t id, GraphEdge *edge) {
7181
(*add_node(new GraphNode(id, std::string(""))))->add_edge(edge);
7282
}
83+
7384
GraphNode *GraphShard::find_node(uint64_t id) {
7485
if (node_location.find(id) == node_location.end()) return NULL;
7586
return *(node_location[id]);
7687
}
88+
7789
int32_t GraphTable::load(const std::string &path, const std::string &param) {
90+
auto cmd = paddle::string::split_string<std::string>(param, "|");
91+
std::set<std::string> cmd_set(cmd.begin(), cmd.end());
92+
bool load_edge = cmd_set.count(std::string("edge"));
93+
bool reverse_edge = cmd_set.count(std::string("reverse"));
94+
VLOG(0) << "Reverse Edge " << reverse_edge;
95+
7896
auto paths = paddle::string::split_string<std::string>(path, ";");
7997
VLOG(0) << paths.size();
98+
int count = 0;
99+
80100
for (auto path : paths) {
81101
std::ifstream file(path);
82102
std::string line;
83103
while (std::getline(file, line)) {
84104
auto values = paddle::string::split_string<std::string>(line, "\t");
105+
count ++;
85106
if (values.size() < 2) continue;
86-
auto id = std::stoull(values[0]);
87-
size_t shard_id = id % shard_num;
88-
if (shard_id >= shard_end || shard_id < shard_start) {
89-
VLOG(0) << "will not load " << id << " from " << path
107+
auto src_id = std::stoull(values[0]);
108+
auto dst_id = std::stoull(values[1]);
109+
if(reverse_edge) {
110+
std::swap(src_id, dst_id);
111+
}
112+
double weight = 0;
113+
if (values.size() == 3) {
114+
weight = std::stod(values[2]);
115+
}
116+
size_t src_shard_id = src_id % shard_num;
117+
118+
if (src_shard_id >= shard_end || src_shard_id < shard_start) {
119+
VLOG(0) << "will not load " << src_id << " from " << path
90120
<< ", please check id distribution";
91121
continue;
122+
92123
}
93-
size_t index = shard_id - shard_start;
94-
// GraphNodeType type = GraphNode::get_graph_node_type(values[1]);
95-
// VLOG(0)<<"shards's size = "<<shards.size()<<" values' size =
96-
// "<<values.size();
97-
// VLOG(0)<<"add to index "<<index<<" table rank = "<<_shard_idx;
98-
shards[index].add_node(new GraphNode(id, values[1]));
99-
// VLOG(0)<<"checking added of rank "<<_shard_idx<<" shard "<<index<<"
100-
// "<<cc->get_id();
101-
for (size_t i = 2; i < values.size(); i++) {
102-
auto edge_arr =
103-
paddle::string::split_string<std::string>(values[i], ";");
104-
if (edge_arr.size() == 2) {
105-
// VLOG(0)<<"edge content "<<edge_arr[0]<<" "<<edge_arr[1]<<"
106-
// "<<edge_arr[2];
107-
auto edge_id = std::stoull(edge_arr[0]);
108-
auto weight = std::stod(edge_arr[1]);
109-
// VLOG(0)<<"edge_id "<<edge_id<<" weight "<<weight;
110-
GraphEdge *edge = new GraphEdge(edge_id, weight);
111-
shards[index].add_neighboor(id, edge);
112-
}
113-
}
124+
size_t index = src_shard_id - shard_start;
125+
GraphEdge *edge = new GraphEdge(dst_id, weight);
126+
shards[index].add_neighboor(src_id, edge);
114127
}
115-
for (auto &shard : shards) {
128+
}
129+
VLOG(0) << "Load Finished Total Edge Count " << count;
130+
131+
// Build Sampler
132+
for (auto &shard : shards) {
116133
auto bucket = shard.get_bucket();
117134
for (int i = 0; i < bucket.size(); i++) {
118135
std::list<GraphNode *>::iterator iter = bucket[i].begin();
@@ -122,7 +139,6 @@ int32_t GraphTable::load(const std::string &path, const std::string &param) {
122139
iter++;
123140
}
124141
}
125-
}
126142
}
127143
return 0;
128144
}
@@ -144,6 +160,7 @@ int32_t GraphTable::random_sample(uint64_t node_id, int sample_size,
144160
char *&buffer, int &actual_size) {
145161
return _shards_task_pool[get_thread_pool_index(node_id)]
146162
->enqueue([&]() -> int {
163+
147164
GraphNode *node = find_node(node_id);
148165
if (node == NULL) {
149166
actual_size = 0;
@@ -275,4 +292,4 @@ int32_t GraphTable::initialize() {
275292
return 0;
276293
}
277294
}
278-
};
295+
};

paddle/fluid/distributed/test/graph_node_test.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -244,14 +244,15 @@ void RunBrpcPushSparse() {
244244

245245
distributed::GraphPyService gps1, gps2;
246246
std::string ips_str = "127.0.0.1:4211;127.0.0.1:4212";
247-
gps1.set_up(ips_str, 127, 0, 0, 0);
248-
gps2.set_up(ips_str, 127, 1, 1, 0);
249-
gps1.load_file(std::string(file_name));
247+
std::vector<std::string> edge_types = { std::string("user2item")};
248+
gps1.set_up(ips_str, 127, 0, 0, edge_types);
249+
gps2.set_up(ips_str, 127, 1, 1, edge_types);
250+
gps1.load_edge_file(std::string("user2item"), std::string(file_name), 0);
250251
v.clear();
251-
v = gps2.pull_graph_list(0, 1, 4);
252+
v = gps2.pull_graph_list(std::string("user2item"), 0, 1, 4);
252253
ASSERT_EQ(v[0].get_id(), 59);
253254
v.clear();
254-
v = gps2.sample_k(96, 4);
255+
v = gps2.sample_k(std::string("user2item"), 96, 4);
255256
ASSERT_EQ(v.size(), 3);
256257
// to test in python,try this:
257258
// from paddle.fluid.core import GraphPyService

paddle/fluid/pybind/fleet_py.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,8 @@ void BindGraphNode(py::module* m) {
165165
void BindGraphService(py::module* m) {
166166
py::class_<GraphPyService>(*m, "GraphPyService")
167167
.def(py::init<>())
168-
.def("load_file", &GraphPyService::load_file)
168+
.def("load_edge_file", &GraphPyService::load_edge_file)
169+
.def("load_node_file", &GraphPyService::load_node_file)
169170
.def("set_up", &GraphPyService::set_up)
170171
.def("pull_graph_list", &GraphPyService::pull_graph_list)
171172
.def("sample_k", &GraphPyService::sample_k);

0 commit comments

Comments
 (0)