Skip to content

Commit 578e305

Browse files
authored
Merge pull request #14 from WeiyueSu/FeatureNode
get_node_feat return py:bytes
2 parents bb48ece + 6f4223c commit 578e305

File tree

3 files changed

+30
-7
lines changed

3 files changed

+30
-7
lines changed

paddle/fluid/distributed/table/common_graph_table.cc

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,17 +142,14 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
142142

143143
auto node = shards[index].add_feature_node(id);
144144

145-
//auto mutable_feature = node->get_mutable_feature();
146-
147-
//mutable_feature.clear();
148-
//mutable_feature.resize(this->feat_name.size());
149145
node->set_feature_size(feat_name.size());
150146

151147
for (size_t slice = 2; slice < values.size(); slice++) {
152148
auto feat = this->parse_feature(values[slice]);
153-
if(feat.first > 0) {
154-
//mutable_feature[feat.first] = feat.second;
149+
if (feat.first >= 0) {
155150
node->set_feature(feat.first, feat.second);
151+
} else{
152+
VLOG(4) << "Node feature: " << values[slice] << " not in feature_map.";
156153
}
157154
}
158155
}

paddle/fluid/distributed/test/graph_node_test.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,22 @@ void RunBrpcPushSparse() {
526526
std::cout << "get_node_feat: " << node_feat[1][0] << std::endl;
527527
std::cout << "get_node_feat: " << node_feat[1][1] << std::endl;
528528

529+
// Test string
530+
node_ids.clear();
531+
node_ids.push_back(37);
532+
node_ids.push_back(96);
533+
//std::vector<std::string> feature_names;
534+
feature_names.clear();
535+
feature_names.push_back(std::string("a"));
536+
feature_names.push_back(std::string("b"));
537+
node_feat = client1.get_node_feat(std::string("user"), node_ids, feature_names);
538+
ASSERT_EQ(node_feat.size(), 2);
539+
ASSERT_EQ(node_feat[0].size(), 2);
540+
std::cout << "get_node_feat: " << node_feat[0][0].size() << std::endl;
541+
std::cout << "get_node_feat: " << node_feat[0][1].size() << std::endl;
542+
std::cout << "get_node_feat: " << node_feat[1][0].size() << std::endl;
543+
std::cout << "get_node_feat: " << node_feat[1][1].size() << std::endl;
544+
529545
std::remove(edge_file_name);
530546
std::remove(node_file_name);
531547
LOG(INFO) << "Run stop_server";

paddle/fluid/pybind/fleet_py.cc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,17 @@ void BindGraphPyClient(py::module* m) {
191191
.def("start_client", &GraphPyClient::start_client)
192192
.def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighboors)
193193
.def("random_sample_nodes", &GraphPyClient::random_sample_nodes)
194-
.def("get_node_feat", &GraphPyClient::get_node_feat)
194+
.def("get_node_feat", [](GraphPyClient& self, std::string node_type, std::vector<uint64_t> node_ids,
195+
std::vector<std::string> feature_names){
196+
auto feats = self.get_node_feat(node_type, node_ids, feature_names);
197+
std::vector<std::vector<py::bytes> > bytes_feats(feats.size());
198+
for (int i = 0; i < feats.size(); ++i ){
199+
for (int j = 0; j < feats[i].size(); ++j ){
200+
bytes_feats[i].push_back(py::bytes(feats[i][j]));
201+
}
202+
}
203+
return bytes_feats;
204+
})
195205
.def("bind_local_server", &GraphPyClient::bind_local_server);
196206
}
197207

0 commit comments

Comments
 (0)