Skip to content

Commit 6087b28

Browse files
committed
fix code style
1 parent 4cfea1f commit 6087b28

File tree

3 files changed

+40
-36
lines changed

3 files changed

+40
-36
lines changed

paddle/fluid/distributed/table/common_graph_table.cc

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515
#include "paddle/fluid/distributed/table/common_graph_table.h"
1616
#include <time.h>
1717
#include <algorithm>
18+
#include <chrono>
1819
#include <set>
1920
#include <sstream>
2021
#include "paddle/fluid/distributed/common/utils.h"
2122
#include "paddle/fluid/distributed/table/graph/graph_node.h"
23+
#include "paddle/fluid/framework/generator.h"
2224
#include "paddle/fluid/string/printf.h"
23-
#include <chrono>
2425
#include "paddle/fluid/string/string_helper.h"
25-
#include "paddle/fluid/framework/generator.h"
2626

2727
namespace paddle {
2828
namespace distributed {
@@ -406,31 +406,30 @@ int32_t GraphTable::random_sample_neighboors(
406406
int thread_pool_index = get_thread_pool_index(node_id);
407407
auto rng = _shards_task_rng_pool[thread_pool_index];
408408

409-
tasks.push_back(_shards_task_pool[thread_pool_index]->enqueue(
410-
[&]() -> int {
411-
Node *node = find_node(node_id);
409+
tasks.push_back(_shards_task_pool[thread_pool_index]->enqueue([&]() -> int {
410+
Node *node = find_node(node_id);
412411

413-
if (node == nullptr) {
414-
actual_size = 0;
415-
return 0;
416-
}
417-
std::vector<int> res = node->sample_k(sample_size, rng);
418-
actual_size = res.size() * (Node::id_size + Node::weight_size);
419-
int offset = 0;
420-
uint64_t id;
421-
float weight;
422-
char *buffer_addr = new char[actual_size];
423-
buffer.reset(buffer_addr);
424-
for (int &x : res) {
425-
id = node->get_neighbor_id(x);
426-
weight = node->get_neighbor_weight(x);
427-
memcpy(buffer_addr + offset, &id, Node::id_size);
428-
offset += Node::id_size;
429-
memcpy(buffer_addr + offset, &weight, Node::weight_size);
430-
offset += Node::weight_size;
431-
}
432-
return 0;
433-
}));
412+
if (node == nullptr) {
413+
actual_size = 0;
414+
return 0;
415+
}
416+
std::vector<int> res = node->sample_k(sample_size, rng);
417+
actual_size = res.size() * (Node::id_size + Node::weight_size);
418+
int offset = 0;
419+
uint64_t id;
420+
float weight;
421+
char *buffer_addr = new char[actual_size];
422+
buffer.reset(buffer_addr);
423+
for (int &x : res) {
424+
id = node->get_neighbor_id(x);
425+
weight = node->get_neighbor_weight(x);
426+
memcpy(buffer_addr + offset, &id, Node::id_size);
427+
offset += Node::id_size;
428+
memcpy(buffer_addr + offset, &weight, Node::weight_size);
429+
offset += Node::weight_size;
430+
}
431+
return 0;
432+
}));
434433
}
435434
for (size_t idx = 0; idx < node_num; ++idx) {
436435
tasks[idx].get();
@@ -519,7 +518,6 @@ int32_t GraphTable::pull_graph_list(int start, int total_size,
519518
int end = start + (count - 1) * step + 1;
520519
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
521520
[this, i, start, end, step, size]() -> std::vector<Node *> {
522-
523521
return this->shards[i].get_batch(start - size, end - size, step);
524522
}));
525523
start += count * step;
@@ -594,5 +592,5 @@ int32_t GraphTable::initialize() {
594592
shards = std::vector<GraphShard>(shard_num_per_table, GraphShard(shard_num));
595593
return 0;
596594
}
597-
}
598-
};
595+
} // namespace distributed
596+
}; // namespace paddle

paddle/fluid/distributed/table/graph/graph_node.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,5 +113,5 @@ void FeatureNode::recover_from_buffer(char* buffer) {
113113
feature.push_back(std::string(str));
114114
}
115115
}
116-
}
117-
}
116+
} // namespace distributed
117+
} // namespace paddle

paddle/fluid/distributed/table/graph/graph_node.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
#pragma once
1616
#include <cstring>
1717
#include <iostream>
18+
#include <memory>
1819
#include <sstream>
1920
#include <vector>
2021
#include "paddle/fluid/distributed/table/graph/graph_weighted_sampler.h"
21-
#include <memory>
2222
namespace paddle {
2323
namespace distributed {
2424

@@ -34,7 +34,10 @@ class Node {
3434
virtual void build_edges(bool is_weighted) {}
3535
virtual void build_sampler(std::string sample_type) {}
3636
virtual void add_edge(uint64_t id, float weight) {}
37-
virtual std::vector<int> sample_k(int k, const std::shared_ptr<std::mt19937_64> rng) { return std::vector<int>(); }
37+
virtual std::vector<int> sample_k(
38+
int k, const std::shared_ptr<std::mt19937_64> rng) {
39+
return std::vector<int>();
40+
}
3841
virtual uint64_t get_neighbor_id(int idx) { return 0; }
3942
virtual float get_neighbor_weight(int idx) { return 1.; }
4043

@@ -60,7 +63,10 @@ class GraphNode : public Node {
6063
virtual void add_edge(uint64_t id, float weight) {
6164
edges->add_edge(id, weight);
6265
}
63-
virtual std::vector<int> sample_k(int k, const std::shared_ptr<std::mt19937_64> rng) { return sampler->sample_k(k, rng); }
66+
virtual std::vector<int> sample_k(
67+
int k, const std::shared_ptr<std::mt19937_64> rng) {
68+
return sampler->sample_k(k, rng);
69+
}
6470
virtual uint64_t get_neighbor_id(int idx) { return edges->get_id(idx); }
6571
virtual float get_neighbor_weight(int idx) { return edges->get_weight(idx); }
6672

@@ -124,5 +130,5 @@ class FeatureNode : public Node {
124130
protected:
125131
std::vector<std::string> feature;
126132
};
127-
}
128-
}
133+
} // namespace distributed
134+
} // namespace paddle

0 commit comments

Comments
 (0)