Skip to content

Commit 5f6c168

Browse files
authored
Merge pull request #8 from Yelrose/develop
Fixed pull_graph_list bug; add test case for pull_graph_list by step
2 parents 09667d1 + f861faa commit 5f6c168

File tree

2 files changed

+26
-4
lines changed

2 files changed

+26
-4
lines changed

paddle/fluid/distributed/service/graph_brpc_server.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ int32_t GraphBrpcService::pull_graph_list(Table *table,
270270
int step = *(int *)(request.params(2).c_str());
271271
std::unique_ptr<char[]> buffer;
272272
int actual_size;
273-
table->pull_graph_list(start, size, buffer, actual_size, step, true);
273+
table->pull_graph_list(start, size, buffer, actual_size, true, step);
274274
cntl->response_attachment().append(buffer.get(), actual_size);
275275
return 0;
276276
}

paddle/fluid/distributed/test/graph_node_test.cc

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,9 @@ std::string nodes[] = {
142142
std::string("item\t45\t0.21"), std::string("item\t145\t0.21"),
143143
std::string("item\t112\t0.21"), std::string("item\t48\t0.21"),
144144
std::string("item\t247\t0.21"), std::string("item\t111\t0.21"),
145-
std::string("item\t45\t0.21"), std::string("item\t145\t0.21"),
146-
std::string("item\t122\t0.21"), std::string("item\t48\t0.21"),
147-
std::string("item\t247\t0.21"), std::string("item\t111\t0.21")};
145+
std::string("item\t46\t0.21"), std::string("item\t146\t0.21"),
146+
std::string("item\t122\t0.21"), std::string("item\t49\t0.21"),
147+
std::string("item\t248\t0.21"), std::string("item\t113\t0.21")};
148148
char node_file_name[] = "nodes.txt";
149149

150150
void prepare_file(char file_name[], bool load_edge) {
@@ -373,6 +373,7 @@ void RunBrpcPushSparse() {
373373
// client2.load_edge_file(std::string("user2item"), std::string(file_name),
374374
// 0);
375375
nodes.clear();
376+
376377
nodes = client1.pull_graph_list(std::string("user"), 0, 1, 4, 1);
377378

378379
for (auto g : nodes) {
@@ -381,6 +382,27 @@ void RunBrpcPushSparse() {
381382
std::cout << "node_ids: " << nodes[0].get_id() << std::endl;
382383
ASSERT_EQ(nodes[0].get_id(), 59);
383384
nodes.clear();
385+
386+
// Test Pull by step
387+
388+
389+
std::unordered_set<uint64_t> count_item_nodes;
390+
// pull by step 2
391+
for(int test_step=1; test_step < 4 ; test_step ++) {
392+
count_item_nodes.clear();
393+
std::cout << "check pull graph list by step " << test_step << std::endl;
394+
for(int server_id = 0; server_id < 2; server_id ++) {
395+
for(int start_step = 0; start_step < test_step; start_step ++) {
396+
nodes = client1.pull_graph_list(std::string("item"), server_id, start_step, 12, test_step);
397+
for (auto g : nodes) {
398+
count_item_nodes.insert(g.get_id());
399+
}
400+
nodes.clear();
401+
}
402+
}
403+
ASSERT_EQ(count_item_nodes.size(), 12);
404+
}
405+
384406
vs = client1.batch_sample_neighboors(std::string("user2item"),
385407
std::vector<uint64_t>(1, 96), 4);
386408
ASSERT_EQ(vs[0].size(), 3);

0 commit comments

Comments
 (0)