Skip to content

Commit ba57877

Browse files
authored
Merge pull request #3 from Yelrose/develop
Add LoadNode; Change add_node;
2 parents d98be69 + 2feadfe commit ba57877

File tree

2 files changed

+65
-11
lines changed

2 files changed

+65
-11
lines changed

paddle/fluid/distributed/table/common_graph_table.cc

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,21 +64,22 @@ size_t GraphShard::get_size() {
6464
return res;
6565
}
6666

67-
std::list<GraphNode *>::iterator GraphShard::add_node(GraphNode *node) {
68-
if (node_location.find(node->get_id()) != node_location.end())
69-
return node_location.find(node->get_id())->second;
67+
// Returns the bucket-list iterator for the node with the given id, creating
// the node when this shard does not hold it yet.
//
// id:      node id; also selects the bucket (id % shard_num % bucket_size).
// feature: feature string handed to the GraphNode constructor. It is only
//          used when a new node is created; if the id is already registered
//          the existing node is returned unchanged and `feature` is ignored.
//
// NOTE(review): the node is allocated with `new` and stored as a raw pointer;
// where it is released is not visible in this function -- confirm ownership.
std::list<GraphNode *>::iterator GraphShard::add_node(uint64_t id, std::string feature) {
  // One hash lookup instead of two: reuse the iterator returned by find().
  auto found = node_location.find(id);
  if (found != node_location.end()) {
    return found->second;
  }

  // The modulo result is non-negative, so keep it unsigned rather than
  // narrowing it to int.
  const size_t index = id % shard_num % bucket_size;
  GraphNode *node = new GraphNode(id, feature);

  // Append to the bucket and remember where the node lives so later calls
  // can find it by id.
  std::list<GraphNode *>::iterator iter =
      bucket[index].insert(bucket[index].end(), node);
  node_location[id] = iter;
  return iter;
}
7980

8081
void GraphShard::add_neighboor(uint64_t id, GraphEdge *edge) {
81-
(*add_node(new GraphNode(id, std::string(""))))->add_edge(edge);
82+
(*add_node(id, std::string("")))->add_edge(edge);
8283
}
8384

8485
GraphNode *GraphShard::find_node(uint64_t id) {
@@ -88,13 +89,55 @@ GraphNode *GraphShard::find_node(uint64_t id) {
8889

8990
int32_t GraphTable::load(const std::string &path, const std::string &param) {
9091
auto cmd = paddle::string::split_string<std::string>(param, "|");
91-
std::set<std::string> cmd_set(cmd.begin(), cmd.end());
92-
bool load_edge = cmd_set.count(std::string("edge"));
92+
std::set<std::string> cmd_set(cmd.begin(), cmd.end());
9393
bool reverse_edge = cmd_set.count(std::string("reverse"));
94-
VLOG(0) << "Reverse Edge " << reverse_edge;
94+
bool load_edge = cmd_set.count(std::string("edge"));
95+
if(load_edge) {
96+
return this -> load_edges(path, reverse_edge);
97+
}
98+
else {
99+
return this -> load_nodes(path);
100+
}
101+
}
102+
103+
int32_t GraphTable::load_nodes(const std::string &path) {
104+
auto paths = paddle::string::split_string<std::string>(path, ";");
105+
for (auto path : paths) {
106+
std::ifstream file(path);
107+
std::string line;
108+
while (std::getline(file, line)) {
109+
auto values = paddle::string::split_string<std::string>(line, "\t");
110+
if (values.size() < 2) continue;
111+
auto id = std::stoull(values[1]);
112+
113+
114+
size_t shard_id = id % shard_num;
115+
if (shard_id >= shard_end || shard_id < shard_start) {
116+
VLOG(0) << "will not load " << id << " from " << path
117+
<< ", please check id distribution";
118+
continue;
119+
120+
}
121+
122+
std::string node_type = values[0];
123+
std::vector<std::string > feature;
124+
feature.push_back(node_type);
125+
for(size_t slice = 2; slice < values.size(); slice ++) {
126+
feature.push_back(values[slice]);
127+
}
128+
auto feat = paddle::string::join_strings(feature, '\t');
129+
size_t index = shard_id - shard_start;
130+
shards[index].add_node(id, feat);
131+
132+
}
133+
}
134+
return 0;
135+
}
136+
137+
138+
int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
95139

96140
auto paths = paddle::string::split_string<std::string>(path, ";");
97-
VLOG(0) << paths.size();
98141
int count = 0;
99142

100143
for (auto path : paths) {
@@ -113,13 +156,15 @@ int32_t GraphTable::load(const std::string &path, const std::string &param) {
113156
if (values.size() == 3) {
114157
weight = std::stof(values[2]);
115158
}
159+
116160
size_t src_shard_id = src_id % shard_num;
117161

118162
if (src_shard_id >= shard_end || src_shard_id < shard_start) {
119163
VLOG(0) << "will not load " << src_id << " from " << path
120164
<< ", please check id distribution";
121165
continue;
122166
}
167+
123168
size_t index = src_shard_id - shard_start;
124169
GraphEdge *edge = new GraphEdge(dst_id, weight);
125170
shards[index].add_neighboor(src_id, edge);
@@ -128,6 +173,7 @@ int32_t GraphTable::load(const std::string &path, const std::string &param) {
128173
VLOG(0) << "Load Finished Total Edge Count " << count;
129174

130175
// Build Sampler
176+
131177
for (auto &shard : shards) {
132178
auto bucket = shard.get_bucket();
133179
for (int i = 0; i < bucket.size(); i++) {
@@ -141,6 +187,7 @@ int32_t GraphTable::load(const std::string &path, const std::string &param) {
141187
}
142188
return 0;
143189
}
190+
144191
GraphNode *GraphTable::find_node(uint64_t id) {
145192
size_t shard_id = id % shard_num;
146193
if (shard_id >= shard_end || shard_id < shard_start) {
@@ -264,3 +311,4 @@ int32_t GraphTable::initialize() {
264311
}
265312
}
266313
};
314+

paddle/fluid/distributed/table/common_graph_table.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class GraphShard {
5252
}
5353
return -1;
5454
}
55-
std::list<GraphNode *>::iterator add_node(GraphNode *node);
55+
std::list<GraphNode *>::iterator add_node(uint64_t id, std::string feature);
5656
GraphNode *find_node(uint64_t id);
5757
void add_neighboor(uint64_t id, GraphEdge *edge);
5858
std::unordered_map<uint64_t, std::list<GraphNode *>::iterator>
@@ -74,7 +74,13 @@ class GraphTable : public SparseTable {
7474
virtual int32_t random_sample(uint64_t node_id, int sampe_size, char *&buffer,
7575
int &actual_size);
7676
virtual int32_t initialize();
77+
7778
int32_t load(const std::string &path, const std::string &param);
79+
80+
int32_t load_edges(const std::string &path, bool reverse);
81+
82+
int32_t load_nodes(const std::string &path);
83+
7884
GraphNode *find_node(uint64_t id);
7985

8086
virtual int32_t pull_sparse(float *values, const uint64_t *keys, size_t num) {

0 commit comments

Comments
 (0)