From f861faaadae582e6a2bb58bf55a4c0378820924b Mon Sep 17 00:00:00 2001 From: Yelrose <270018958@qq.com> Date: Mon, 22 Mar 2021 14:46:31 +0800 Subject: [PATCH] fixed pull_graph_list bug; add test for pull_graph_list by step --- .../distributed/service/graph_brpc_server.cc | 2 +- .../fluid/distributed/test/graph_node_test.cc | 28 +++++++++++++++++-- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/distributed/service/graph_brpc_server.cc b/paddle/fluid/distributed/service/graph_brpc_server.cc index 60d6bc203a0742..765c4e9254254f 100644 --- a/paddle/fluid/distributed/service/graph_brpc_server.cc +++ b/paddle/fluid/distributed/service/graph_brpc_server.cc @@ -270,7 +270,7 @@ int32_t GraphBrpcService::pull_graph_list(Table *table, int step = *(int *)(request.params(2).c_str()); std::unique_ptr buffer; int actual_size; - table->pull_graph_list(start, size, buffer, actual_size, step, true); + table->pull_graph_list(start, size, buffer, actual_size, true, step); cntl->response_attachment().append(buffer.get(), actual_size); return 0; } diff --git a/paddle/fluid/distributed/test/graph_node_test.cc b/paddle/fluid/distributed/test/graph_node_test.cc index efee0d9441ef2e..2ba5946cc443f5 100644 --- a/paddle/fluid/distributed/test/graph_node_test.cc +++ b/paddle/fluid/distributed/test/graph_node_test.cc @@ -142,9 +142,9 @@ std::string nodes[] = { std::string("item\t45\t0.21"), std::string("item\t145\t0.21"), std::string("item\t112\t0.21"), std::string("item\t48\t0.21"), std::string("item\t247\t0.21"), std::string("item\t111\t0.21"), - std::string("item\t45\t0.21"), std::string("item\t145\t0.21"), - std::string("item\t122\t0.21"), std::string("item\t48\t0.21"), - std::string("item\t247\t0.21"), std::string("item\t111\t0.21")}; + std::string("item\t46\t0.21"), std::string("item\t146\t0.21"), + std::string("item\t122\t0.21"), std::string("item\t49\t0.21"), + std::string("item\t248\t0.21"), std::string("item\t113\t0.21")}; char node_file_name[] = "nodes.txt"; void prepare_file(char file_name[], bool load_edge) { @@ -373,6 +373,7 @@ void RunBrpcPushSparse() { // client2.load_edge_file(std::string("user2item"), std::string(file_name), // 0); nodes.clear(); + nodes = client1.pull_graph_list(std::string("user"), 0, 1, 4, 1); for (auto g : nodes) { @@ -381,6 +382,27 @@ void RunBrpcPushSparse() { std::cout << "node_ids: " << nodes[0].get_id() << std::endl; ASSERT_EQ(nodes[0].get_id(), 59); nodes.clear(); + + // Test Pull by step + + + std::unordered_set count_item_nodes; + // pull by step 2 + for(int test_step=1; test_step < 4 ; test_step ++) { + count_item_nodes.clear(); + std::cout << "check pull graph list by step " << test_step << std::endl; + for(int server_id = 0; server_id < 2; server_id ++) { + for(int start_step = 0; start_step < test_step; start_step ++) { + nodes = client1.pull_graph_list(std::string("item"), server_id, start_step, 12, test_step); + for (auto g : nodes) { + count_item_nodes.insert(g.get_id()); + } + nodes.clear(); + } + } + ASSERT_EQ(count_item_nodes.size(), 12); + } + vs = client1.batch_sample_neighboors(std::string("user2item"), std::vector(1, 96), 4); ASSERT_EQ(vs[0].size(), 3);