|
15 | 15 | #include "paddle/fluid/distributed/table/common_graph_table.h" |
16 | 16 | #include <time.h> |
17 | 17 | #include <algorithm> |
| 18 | +#include <chrono> |
18 | 19 | #include <set> |
19 | 20 | #include <sstream> |
20 | 21 | #include "paddle/fluid/distributed/common/utils.h" |
21 | 22 | #include "paddle/fluid/distributed/table/graph/graph_node.h" |
| 23 | +#include "paddle/fluid/framework/generator.h" |
22 | 24 | #include "paddle/fluid/string/printf.h" |
23 | | -#include <chrono> |
24 | 25 | #include "paddle/fluid/string/string_helper.h" |
25 | | -#include "paddle/fluid/framework/generator.h" |
26 | 26 |
|
27 | 27 | namespace paddle { |
28 | 28 | namespace distributed { |
@@ -406,31 +406,30 @@ int32_t GraphTable::random_sample_neighboors( |
406 | 406 | int thread_pool_index = get_thread_pool_index(node_id); |
407 | 407 | auto rng = _shards_task_rng_pool[thread_pool_index]; |
408 | 408 |
|
409 | | - tasks.push_back(_shards_task_pool[thread_pool_index]->enqueue( |
410 | | - [&]() -> int { |
411 | | - Node *node = find_node(node_id); |
| 409 | + tasks.push_back(_shards_task_pool[thread_pool_index]->enqueue([&]() -> int { |
| 410 | + Node *node = find_node(node_id); |
412 | 411 |
|
413 | | - if (node == nullptr) { |
414 | | - actual_size = 0; |
415 | | - return 0; |
416 | | - } |
417 | | - std::vector<int> res = node->sample_k(sample_size, rng); |
418 | | - actual_size = res.size() * (Node::id_size + Node::weight_size); |
419 | | - int offset = 0; |
420 | | - uint64_t id; |
421 | | - float weight; |
422 | | - char *buffer_addr = new char[actual_size]; |
423 | | - buffer.reset(buffer_addr); |
424 | | - for (int &x : res) { |
425 | | - id = node->get_neighbor_id(x); |
426 | | - weight = node->get_neighbor_weight(x); |
427 | | - memcpy(buffer_addr + offset, &id, Node::id_size); |
428 | | - offset += Node::id_size; |
429 | | - memcpy(buffer_addr + offset, &weight, Node::weight_size); |
430 | | - offset += Node::weight_size; |
431 | | - } |
432 | | - return 0; |
433 | | - })); |
| 412 | + if (node == nullptr) { |
| 413 | + actual_size = 0; |
| 414 | + return 0; |
| 415 | + } |
| 416 | + std::vector<int> res = node->sample_k(sample_size, rng); |
| 417 | + actual_size = res.size() * (Node::id_size + Node::weight_size); |
| 418 | + int offset = 0; |
| 419 | + uint64_t id; |
| 420 | + float weight; |
| 421 | + char *buffer_addr = new char[actual_size]; |
| 422 | + buffer.reset(buffer_addr); |
| 423 | + for (int &x : res) { |
| 424 | + id = node->get_neighbor_id(x); |
| 425 | + weight = node->get_neighbor_weight(x); |
| 426 | + memcpy(buffer_addr + offset, &id, Node::id_size); |
| 427 | + offset += Node::id_size; |
| 428 | + memcpy(buffer_addr + offset, &weight, Node::weight_size); |
| 429 | + offset += Node::weight_size; |
| 430 | + } |
| 431 | + return 0; |
| 432 | + })); |
434 | 433 | } |
435 | 434 | for (size_t idx = 0; idx < node_num; ++idx) { |
436 | 435 | tasks[idx].get(); |
@@ -519,7 +518,6 @@ int32_t GraphTable::pull_graph_list(int start, int total_size, |
519 | 518 | int end = start + (count - 1) * step + 1; |
520 | 519 | tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue( |
521 | 520 | [this, i, start, end, step, size]() -> std::vector<Node *> { |
522 | | - |
523 | 521 | return this->shards[i].get_batch(start - size, end - size, step); |
524 | 522 | })); |
525 | 523 | start += count * step; |
@@ -594,5 +592,5 @@ int32_t GraphTable::initialize() { |
594 | 592 | shards = std::vector<GraphShard>(shard_num_per_table, GraphShard(shard_num)); |
595 | 593 | return 0; |
596 | 594 | } |
597 | | -} |
598 | | -}; |
| 595 | +} // namespace distributed |
| 596 | +}; // namespace paddle |
0 commit comments