Skip to content

Commit d85fda7

Browse files
mdouzefacebook-github-bot
authored andcommitted
Allow k and M suffixes in IVF indexes (#3812)
Summary: Pull Request resolved: #3812. Allows factory strings like `IVF3k,Flat` as a shorthand for 3072 centroids. The main question is whether the k and M suffixes should be metric (k=1000) or powers of 2 (k=1024): * pro-metric: it is the standard convention; * pro-power-of-2: in practice we use powers of 2 most often. Strictly speaking, the suffixes ki and Mi should be used for powers of 2, but this makes the notation heavier (which is what we wanted to avoid in the first place). So I picked powers of 2. Reviewed By: mnorris11 Differential Revision: D62019941 fbshipit-source-id: f547962625123ecdfaa406067781c77386017793
1 parent 6fe4640 commit d85fda7

File tree

2 files changed

+33
-8
lines changed

2 files changed

+33
-8
lines changed

faiss/index_factory.cpp

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,19 @@ VectorTransform* parse_VectorTransform(const std::string& description, int d) {
226226
* Parse IndexIVF
227227
*/
228228

229+
/**
 * Parse the list-count part of an IVF factory string, e.g. "456", "3k", "1M".
 *
 * A single trailing 'k' or 'M' scales the number by 1024 or 1024 * 1024
 * respectively (binary multipliers, so "3k" is 3072, not 3000).
 *
 * @param s  decimal digits optionally followed by one 'k' or 'M' suffix;
 *           taken by value because the suffix is stripped in place.
 * @return   the parsed number times the suffix multiplier.
 * @throws   whatever std::stoull throws: std::invalid_argument when no
 *           digits are left to parse (empty string or bare suffix),
 *           std::out_of_range when the number does not fit.
 */
size_t parse_nlist(std::string s) {
    size_t multiplier = 1;
    // Guard against empty input: calling back() on an empty string is
    // undefined behavior. At most one suffix is consumed.
    if (!s.empty()) {
        const char suffix = s.back();
        if (suffix == 'k') {
            multiplier = 1024;
            s.pop_back();
        } else if (suffix == 'M') {
            multiplier = 1024 * 1024;
            s.pop_back();
        }
    }
    // Parse as an unsigned 64-bit value rather than int so that large list
    // counts are not rejected before the multiplication.
    return static_cast<size_t>(std::stoull(s)) * multiplier;
}
241+
229242
// parsing guard + function
230243
Index* parse_coarse_quantizer(
231244
const std::string& description,
@@ -240,8 +253,8 @@ Index* parse_coarse_quantizer(
240253
};
241254
use_2layer = false;
242255

243-
if (match("IVF([0-9]+)")) {
244-
nlist = std::stoi(sm[1].str());
256+
if (match("IVF([0-9]+[kM]?)")) {
257+
nlist = parse_nlist(sm[1].str());
245258
return new IndexFlat(d, mt);
246259
}
247260
if (match("IMI2x([0-9]+)")) {
@@ -252,18 +265,18 @@ Index* parse_coarse_quantizer(
252265
nlist = (size_t)1 << (2 * nbit);
253266
return new MultiIndexQuantizer(d, 2, nbit);
254267
}
255-
if (match("IVF([0-9]+)_HNSW([0-9]*)")) {
256-
nlist = std::stoi(sm[1].str());
268+
if (match("IVF([0-9]+[kM]?)_HNSW([0-9]*)")) {
269+
nlist = parse_nlist(sm[1].str());
257270
int hnsw_M = sm[2].length() > 0 ? std::stoi(sm[2]) : 32;
258271
return new IndexHNSWFlat(d, hnsw_M, mt);
259272
}
260-
if (match("IVF([0-9]+)_NSG([0-9]+)")) {
261-
nlist = std::stoi(sm[1].str());
273+
if (match("IVF([0-9]+[kM]?)_NSG([0-9]+)")) {
274+
nlist = parse_nlist(sm[1].str());
262275
int R = std::stoi(sm[2]);
263276
return new IndexNSGFlat(d, R, mt);
264277
}
265-
if (match("IVF([0-9]+)\\(Index([0-9])\\)")) {
266-
nlist = std::stoi(sm[1].str());
278+
if (match("IVF([0-9]+[kM]?)\\(Index([0-9])\\)")) {
279+
nlist = parse_nlist(sm[1].str());
267280
int no = std::stoi(sm[2].str());
268281
FAISS_ASSERT(no >= 0 && no < parenthesis_indexes.size());
269282
return parenthesis_indexes[no].release();

tests/test_factory.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,18 @@ def test_ivf(self):
238238
index = faiss.index_factory(123, "IVF456,Flat")
239239
self.assertEqual(index.__class__, faiss.IndexIVFFlat)
240240

241+
def test_ivf_suffix_k(self):
    """The `k` suffix is a binary multiplier: IVF3k means 3 * 1024 lists."""
    expected_nlist = 3 * 1024
    ivf_index = faiss.index_factory(123, "IVF3k,Flat")
    self.assertEqual(ivf_index.nlist, expected_nlist)
244+
245+
def test_ivf_suffix_M(self):
    """The `M` suffix is a binary multiplier: IVF1M means 1024 * 1024 lists."""
    expected_nlist = 1024 * 1024
    ivf_index = faiss.index_factory(123, "IVF1M,Flat")
    self.assertEqual(ivf_index.nlist, expected_nlist)
248+
249+
def test_ivf_suffix_HNSW_M(self):
    """The `M` suffix is also accepted when followed by `_HNSW`."""
    expected_nlist = 1024 * 1024
    ivf_index = faiss.index_factory(123, "IVF1M_HNSW,Flat")
    self.assertEqual(ivf_index.nlist, expected_nlist)
252+
241253
def test_idmap(self):
242254
index = faiss.index_factory(123, "Flat,IDMap")
243255
self.assertEqual(index.__class__, faiss.IndexIDMap)

0 commit comments

Comments
 (0)