Skip to content

Commit 46dc17f

Browse files
authored
Merge pull request #5 from WeiyueSu/sample
sample with srand
2 parents 6627904 + ec2555a commit 46dc17f

File tree

5 files changed

+112
-18
lines changed

5 files changed

+112
-18
lines changed

paddle/fluid/distributed/table/common_graph_table.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
146146
int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
147147
auto paths = paddle::string::split_string<std::string>(path, ";");
148148
int count = 0;
149+
std::string sample_type = "random";
149150

150151
for (auto path : paths) {
151152
std::ifstream file(path);
@@ -159,9 +160,10 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
159160
if (reverse_edge) {
160161
std::swap(src_id, dst_id);
161162
}
162-
float weight = 0;
163+
float weight = 1;
163164
if (values.size() == 3) {
164165
weight = std::stof(values[2]);
166+
sample_type = "weighted";
165167
}
166168

167169
size_t src_shard_id = src_id % shard_num;
@@ -184,8 +186,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
184186
for (auto &shard : shards) {
185187
auto bucket = shard.get_bucket();
186188
for (int i = 0; i < bucket.size(); i++) {
187-
bucket[i]->build_sampler();
188-
}
189+
bucket[i]->build_sampler(sample_type); }
189190
}
190191
return 0;
191192
}

paddle/fluid/distributed/table/graph_node.cc

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,15 @@ int GraphNode::int_size = sizeof(int);
2222
int GraphNode::get_size(bool need_feature) {
2323
return id_size + int_size + (need_feature ? feature.size() : 0);
2424
}
25-
void GraphNode::build_sampler() {
26-
sampler = new WeightedSampler();
27-
GraphEdge** arr = edges.data();
28-
sampler->build((WeightedObject**)arr, 0, edges.size());
25+
void GraphNode::build_sampler(std::string sample_type) {
26+
if (sample_type == "random"){
27+
sampler = new RandomSampler();
28+
} else if (sample_type == "weighted"){
29+
sampler = new WeightedSampler();
30+
}
31+
//GraphEdge** arr = edges.data();
32+
//sampler->build((WeightedObject**)arr, 0, edges.size());
33+
sampler->build((std::vector<WeightedObject*>*)&edges);
2934
}
3035
void GraphNode::to_buffer(char* buffer, bool need_feature) {
3136
int size = get_size(need_feature);
@@ -51,4 +56,4 @@ void GraphNode::recover_from_buffer(char* buffer) {
5156
// type = GraphNodeType(int_state);
5257
}
5358
}
54-
}
59+
}

paddle/fluid/distributed/table/graph_node.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class GraphNode {
4040
void set_feature(std::string feature) { this->feature = feature; }
4141
std::string get_feature() { return feature; }
4242
virtual int get_size(bool need_feature);
43-
virtual void build_sampler();
43+
virtual void build_sampler(std::string sample_type);
4444
virtual void to_buffer(char *buffer, bool need_feature);
4545
virtual void recover_from_buffer(char *buffer);
4646
virtual void add_edge(GraphEdge *edge) { edges.push_back(edge); }
@@ -58,7 +58,7 @@ class GraphNode {
5858
protected:
5959
uint64_t id;
6060
std::string feature;
61-
WeightedSampler *sampler;
61+
Sampler *sampler;
6262
std::vector<GraphEdge *> edges;
6363
};
6464
}

paddle/fluid/distributed/table/weighted_sampler.cc

Lines changed: 74 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,87 @@
1414

1515
#include "paddle/fluid/distributed/table/weighted_sampler.h"
1616
#include <iostream>
17+
#include<unordered_map>
1718
namespace paddle {
1819
namespace distributed {
19-
void WeightedSampler::build(WeightedObject **v, int start, int end) {
20+
21+
void RandomSampler::build(std::vector<WeightedObject*>* edges) {
22+
this->edges = edges;
23+
}
24+
25+
std::vector<WeightedObject *> RandomSampler::sample_k(int k) {
26+
int n = edges->size();
27+
if (k > n){
28+
k = n;
29+
}
30+
struct timespec tn;
31+
clock_gettime(CLOCK_REALTIME, &tn);
32+
srand(tn.tv_nsec);
33+
std::vector<WeightedObject *> sample_result;
34+
std::unordered_map<int, int> replace_map;
35+
while(k--){
36+
int rand_int = rand() % n;
37+
auto iter = replace_map.find(rand_int);
38+
if(iter == replace_map.end()){
39+
sample_result.push_back(edges->at(rand_int));
40+
}else{
41+
sample_result.push_back(edges->at(iter->second));
42+
}
43+
44+
iter = replace_map.find(n - 1);
45+
if(iter == replace_map.end()){
46+
replace_map[rand_int] = n - 1;
47+
}else{
48+
replace_map[rand_int] = iter->second;
49+
}
50+
--n;
51+
}
52+
return sample_result;
53+
}
54+
55+
WeightedSampler::WeightedSampler(){
56+
left = nullptr;
57+
right = nullptr;
58+
object = nullptr;
59+
}
60+
61+
WeightedSampler::~WeightedSampler() {
62+
if(left != nullptr){
63+
delete left;
64+
left = nullptr;
65+
}
66+
if(right != nullptr){
67+
delete right;
68+
right = nullptr;
69+
}
70+
}
71+
72+
void WeightedSampler::build(std::vector<WeightedObject*>* edges) {
73+
if(left != nullptr){
74+
delete left;
75+
left = nullptr;
76+
}
77+
if(right != nullptr){
78+
delete right;
79+
right = nullptr;
80+
}
81+
WeightedObject** v = edges->data();
82+
return build_one(v, 0, edges->size());
83+
}
84+
85+
void WeightedSampler::build_one(WeightedObject **v, int start, int end) {
2086
count = 0;
2187
if (start + 1 == end) {
22-
left = right = NULL;
88+
left = right = nullptr;
2389
weight = v[start]->get_weight();
2490
object = v[start];
2591
count = 1;
2692

2793
} else {
2894
left = new WeightedSampler();
2995
right = new WeightedSampler();
30-
left->build(v, start, start + (end - start) / 2);
31-
right->build(v, start + (end - start) / 2, end);
96+
left->build_one(v, start, start + (end - start) / 2);
97+
right->build_one(v, start + (end - start) / 2, end);
3298
weight = left->weight + right->weight;
3399
count = left->count + right->count;
34100
}
@@ -41,6 +107,9 @@ std::vector<WeightedObject *> WeightedSampler::sample_k(int k) {
41107
float subtract;
42108
std::unordered_map<WeightedSampler *, float> subtract_weight_map;
43109
std::unordered_map<WeightedSampler *, int> subtract_count_map;
110+
struct timespec tn;
111+
clock_gettime(CLOCK_REALTIME, &tn);
112+
srand(tn.tv_nsec);
44113
while (k--) {
45114
float query_weight = rand() % 100000 / 100000.0;
46115
query_weight *= weight - subtract_weight_map[this];
@@ -54,7 +123,7 @@ WeightedObject *WeightedSampler::sample(
54123
std::unordered_map<WeightedSampler *, float> &subtract_weight_map,
55124
std::unordered_map<WeightedSampler *, int> &subtract_count_map,
56125
float &subtract) {
57-
if (left == NULL) {
126+
if (left == nullptr) {
58127
subtract_weight_map[this] = weight;
59128
subtract = weight;
60129
subtract_count_map[this] = 1;

paddle/fluid/distributed/table/weighted_sampler.h

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <vector>
1919
namespace paddle {
2020
namespace distributed {
21+
2122
class WeightedObject {
2223
public:
2324
WeightedObject() {}
@@ -26,14 +27,32 @@ class WeightedObject {
2627
virtual float get_weight() = 0;
2728
};
2829

29-
class WeightedSampler {
30+
class Sampler {
31+
public:
32+
virtual ~Sampler() {}
33+
virtual void build(std::vector<WeightedObject*>* edges) = 0;
34+
virtual std::vector<WeightedObject *> sample_k(int k) = 0;
35+
};
36+
37+
class RandomSampler: public Sampler {
38+
public:
39+
virtual ~RandomSampler() {}
40+
virtual void build(std::vector<WeightedObject*>* edges);
41+
virtual std::vector<WeightedObject *> sample_k(int k);
42+
std::vector<WeightedObject*>* edges;
43+
};
44+
45+
class WeightedSampler: public Sampler {
3046
public:
47+
WeightedSampler();
48+
virtual ~WeightedSampler();
3149
WeightedSampler *left, *right;
3250
WeightedObject *object;
3351
int count;
3452
float weight;
35-
void build(WeightedObject **v, int start, int end);
36-
std::vector<WeightedObject *> sample_k(int k);
53+
virtual void build(std::vector<WeightedObject*>* edges);
54+
virtual void build_one(WeightedObject **v, int start, int end);
55+
virtual std::vector<WeightedObject *> sample_k(int k);
3756

3857
private:
3958
WeightedObject *sample(

0 commit comments

Comments
 (0)