@@ -64,21 +64,22 @@ size_t GraphShard::get_size() {
6464 return res;
6565}
6666
67- std::list<GraphNode *>::iterator GraphShard::add_node (GraphNode *node ) {
68- if (node_location.find (node-> get_id () ) != node_location.end ())
69- return node_location.find (node-> get_id () )->second ;
67+ std::list<GraphNode *>::iterator GraphShard::add_node (uint64_t id, std::string feature ) {
68+ if (node_location.find (id ) != node_location.end ())
69+ return node_location.find (id )->second ;
7070
71- int index = node->get_id () % shard_num % bucket_size;
71+ int index = id % shard_num % bucket_size;
72+ GraphNode *node = new GraphNode (id, feature);
7273
7374 std::list<GraphNode *>::iterator iter =
7475 bucket[index].insert (bucket[index].end (), node);
7576
76- node_location[node-> get_id () ] = iter;
77+ node_location[id ] = iter;
7778 return iter;
7879}
7980
8081void GraphShard::add_neighboor (uint64_t id, GraphEdge *edge) {
81- (*add_node (new GraphNode ( id, std::string (" " ) )))->add_edge (edge);
82+ (*add_node (id, std::string (" " )))->add_edge (edge);
8283}
8384
8485GraphNode *GraphShard::find_node (uint64_t id) {
@@ -88,13 +89,55 @@ GraphNode *GraphShard::find_node(uint64_t id) {
8889
8990int32_t GraphTable::load (const std::string &path, const std::string ¶m) {
9091 auto cmd = paddle::string::split_string<std::string>(param, " |" );
91- std::set<std::string> cmd_set (cmd.begin (), cmd.end ());
92- bool load_edge = cmd_set.count (std::string (" edge" ));
92+ std::set<std::string> cmd_set (cmd.begin (), cmd.end ());
9393 bool reverse_edge = cmd_set.count (std::string (" reverse" ));
94- VLOG (0 ) << " Reverse Edge " << reverse_edge;
94+ bool load_edge = cmd_set.count (std::string (" edge" ));
95+ if (load_edge) {
96+ return this -> load_edges (path, reverse_edge);
97+ }
98+ else {
99+ return this -> load_nodes (path);
100+ }
101+ }
102+
103+ int32_t GraphTable::load_nodes (const std::string &path) {
104+ auto paths = paddle::string::split_string<std::string>(path, " ;" );
105+ for (auto path : paths) {
106+ std::ifstream file (path);
107+ std::string line;
108+ while (std::getline (file, line)) {
109+ auto values = paddle::string::split_string<std::string>(line, " \t " );
110+ if (values.size () < 2 ) continue ;
111+ auto id = std::stoull (values[1 ]);
112+
113+
114+ size_t shard_id = id % shard_num;
115+ if (shard_id >= shard_end || shard_id < shard_start) {
116+ VLOG (0 ) << " will not load " << id << " from " << path
117+ << " , please check id distribution" ;
118+ continue ;
119+
120+ }
121+
122+ std::string node_type = values[0 ];
123+ std::vector<std::string > feature;
124+ feature.push_back (node_type);
125+ for (size_t slice = 2 ; slice < values.size (); slice ++) {
126+ feature.push_back (values[slice]);
127+ }
128+ auto feat = paddle::string::join_strings (feature, ' \t ' );
129+ size_t index = shard_id - shard_start;
130+ shards[index].add_node (id, feat);
131+
132+ }
133+ }
134+ return 0 ;
135+ }
136+
137+
138+ int32_t GraphTable::load_edges (const std::string &path, bool reverse_edge) {
95139
96140 auto paths = paddle::string::split_string<std::string>(path, " ;" );
97- VLOG (0 ) << paths.size ();
98141 int count = 0 ;
99142
100143 for (auto path : paths) {
@@ -113,13 +156,15 @@ int32_t GraphTable::load(const std::string &path, const std::string ¶m) {
113156 if (values.size () == 3 ) {
114157 weight = std::stof (values[2 ]);
115158 }
159+
116160 size_t src_shard_id = src_id % shard_num;
117161
118162 if (src_shard_id >= shard_end || src_shard_id < shard_start) {
119163 VLOG (0 ) << " will not load " << src_id << " from " << path
120164 << " , please check id distribution" ;
121165 continue ;
122166 }
167+
123168 size_t index = src_shard_id - shard_start;
124169 GraphEdge *edge = new GraphEdge (dst_id, weight);
125170 shards[index].add_neighboor (src_id, edge);
@@ -128,6 +173,7 @@ int32_t GraphTable::load(const std::string &path, const std::string ¶m) {
128173 VLOG (0 ) << " Load Finished Total Edge Count " << count;
129174
130175 // Build Sampler j
176+
131177 for (auto &shard : shards) {
132178 auto bucket = shard.get_bucket ();
133179 for (int i = 0 ; i < bucket.size (); i++) {
@@ -141,6 +187,7 @@ int32_t GraphTable::load(const std::string &path, const std::string ¶m) {
141187 }
142188 return 0 ;
143189}
190+
144191GraphNode *GraphTable::find_node (uint64_t id) {
145192 size_t shard_id = id % shard_num;
146193 if (shard_id >= shard_end || shard_id < shard_start) {
@@ -264,3 +311,4 @@ int32_t GraphTable::initialize() {
264311}
265312}
266313};
314+
0 commit comments