Skip to content

Commit 0a4a081

Browse files
authored
Merge pull request #19 from Liwb5/engine2.0
fix bug of random sample k
2 parents fc74f83 + d743606 commit 0a4a081

File tree

1 file changed

+20
-34
lines changed

1 file changed

+20
-34
lines changed

paddle/fluid/distributed/table/graph/graph_weighted_sampler.cc

Lines changed: 20 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -24,47 +24,33 @@ void RandomSampler::build(GraphEdgeBlob *edges) { this->edges = edges; }
2424

2525
std::vector<int> RandomSampler::sample_k(int k, const std::shared_ptr<std::mt19937_64> rng) {
2626
int n = edges->size();
27-
if (k > n) {
27+
if (k >= n) {
2828
k = n;
29-
}
30-
std::vector<int> sample_result;
31-
for(int i = 0;i < k;i ++ ) {
29+
std::vector<int> sample_result;
30+
for (int i = 0; i < k; i++) {
3231
sample_result.push_back(i);
32+
}
33+
return sample_result;
3334
}
34-
if (k == n) {
35-
return sample_result;
36-
}
37-
38-
std::uniform_int_distribution<int> distrib(0, n - 1);
35+
std::vector<int> sample_result;
3936
std::unordered_map<int, int> replace_map;
37+
while (k--) {
38+
std::uniform_int_distribution<int> distrib(0, n - 1);
39+
int rand_int = distrib(*rng);
40+
auto iter = replace_map.find(rand_int);
41+
if (iter == replace_map.end()) {
42+
sample_result.push_back(rand_int);
43+
} else {
44+
sample_result.push_back(iter->second);
45+
}
4046

41-
for(int i = 0; i < k; i ++) {
42-
int j = distrib(*rng);
43-
if (j >= i) {
44-
// buff_nid[offset + i] = nid[j] if m.find(j) == m.end() else nid[m[j]]
45-
auto iter_j = replace_map.find(j);
46-
if(iter_j == replace_map.end()) {
47-
sample_result[i] = j;
48-
} else {
49-
sample_result[i] = iter_j -> second;
50-
}
51-
// m[j] = i if m.find(i) == m.end() else m[i]
52-
auto iter_i = replace_map.find(i);
53-
if(iter_i == replace_map.end()) {
54-
replace_map[j] = i;
55-
} else {
56-
replace_map[j] = (iter_i -> second);
57-
}
47+
iter = replace_map.find(n - 1);
48+
if (iter == replace_map.end()) {
49+
replace_map[rand_int] = n - 1;
5850
} else {
59-
sample_result[i] = sample_result[j];
60-
// buff_nid[offset + j] = nid[i] if m.find(i) == m.end() else nid[m[i]]
61-
auto iter_i = replace_map.find(i);
62-
if(iter_i == replace_map.end()) {
63-
sample_result[j] = i;
64-
} else {
65-
sample_result[j] = (iter_i -> second);
66-
}
51+
replace_map[rand_int] = iter->second;
6752
}
53+
--n;
6854
}
6955
return sample_result;
7056
}

0 commit comments

Comments
 (0)