Skip to content

Commit 9950d7e

Browse files
Mengdi Lin and facebook-github-bot
authored and committed
fix integer overflow issue when calculating imbalance_factor (facebookresearch#4245)
Summary: Pull Request resolved: facebookresearch#4245 When the number of clustering embeddings exceeds the int32 maximum, calculating the imbalance factor from the Python side causes a "function overload not found" error. ``` [0]:[rank0]: return faiss.imbalance_factor(len(assign), k, faiss.swig_ptr(assign)) [0]:[rank0]: NotImplementedError: Wrong number or type of arguments for overloaded function 'imbalance_factor'. [0]:[rank0]: Possible C/C++ prototypes are: [0]:[rank0]: faiss::imbalance_factor(int,int,int64_t const *) [0]:[rank0]: faiss::imbalance_factor(int,int const *) ``` Fix it by changing the function signatures on the C++ side to support int64_t. Reviewed By: bshethmeta Differential Revision: D71130612 fbshipit-source-id: becbf464a9d3979269cc7f0cecc6b80a6f9e1199
1 parent 0d7b7ea commit 9950d7e

4 files changed

Lines changed: 7 additions & 23 deletions

File tree

faiss/Clustering.cpp

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,6 @@ Clustering::Clustering(int d, int k) : d(d), k(k) {}
3333
Clustering::Clustering(int d, int k, const ClusteringParameters& cp)
3434
: ClusteringParameters(cp), d(d), k(k) {}
3535

36-
static double imbalance_factor(int n, int k, int64_t* assign) {
37-
std::vector<int> hist(k, 0);
38-
for (int i = 0; i < n; i++)
39-
hist[assign[i]]++;
40-
41-
double tot = 0, uf = 0;
42-
43-
for (int i = 0; i < k; i++) {
44-
tot += hist[i];
45-
uf += hist[i] * (double)hist[i];
46-
}
47-
uf = uf * k / (tot * tot);
48-
49-
return uf;
50-
}
51-
5236
void Clustering::post_process_centroids() {
5337
if (spherical) {
5438
fvec_renorm_L2(d, k, centroids.data());

faiss/invlists/InvertedLists.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ size_t InvertedLists::copy_subset_to(
181181
}
182182

183183
double InvertedLists::imbalance_factor() const {
184-
std::vector<int> hist(nlist);
184+
std::vector<int64_t> hist(nlist);
185185

186186
for (size_t i = 0; i < nlist; i++) {
187187
hist[i] = list_size(i);

faiss/utils/utils.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ size_t ranklist_intersection_size(
387387
return count;
388388
}
389389

390-
double imbalance_factor(int k, const int* hist) {
390+
double imbalance_factor(int k, const int64_t* hist) {
391391
double tot = 0, uf = 0;
392392

393393
for (int i = 0; i < k; i++) {
@@ -399,9 +399,9 @@ double imbalance_factor(int k, const int* hist) {
399399
return uf;
400400
}
401401

402-
double imbalance_factor(int n, int k, const int64_t* assign) {
403-
std::vector<int> hist(k, 0);
404-
for (int i = 0; i < n; i++) {
402+
double imbalance_factor(int64_t n, int k, const int64_t* assign) {
403+
std::vector<int64_t> hist(k, 0);
404+
for (int64_t i = 0; i < n; i++) {
405405
hist[assign[i]]++;
406406
}
407407

faiss/utils/utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,10 @@ size_t merge_result_table_with(
9292

9393
/// a balanced assignment has a IF of 1, a completely unbalanced assignment has
9494
/// an IF = k.
95-
double imbalance_factor(int n, int k, const int64_t* assign);
95+
double imbalance_factor(int64_t n, int k, const int64_t* assign);
9696

9797
/// same, takes a histogram as input
98-
double imbalance_factor(int k, const int* hist);
98+
double imbalance_factor(int k, const int64_t* hist);
9999

100100
/// compute histogram on v
101101
int ivec_hist(size_t n, const int* v, int vmax, int* hist);

0 commit comments

Comments
 (0)