Skip to content

Commit c1dc753

Browse files
RaBitQ implementation (facebookresearch#4235)
Summary: This is a reference implementation of the https://arxiv.org/pdf/2405.12497 > Jianyang Gao, Cheng Long, "RaBitQ: Quantizing High-Dimensional Vectors with a Theoretical Error Bound for Approximate Nearest Neighbor Search". The goal is to correctly set up the internals using Faiss. The following comments for the implementation: * The code does not include the computations for the symmetric distance, because it is absent in the original article. This can be added later, though. * The original `RaBitQ` includes random matrix rotation as a part of it, but I've decided to rely on external `faiss::IndexPreTransform` and `faiss::RandomRotationMatrix` facilities. * Certain features required internal changes in `faiss::IndexIVF`, but I did that as least invasive as possible, without breaking the backward compatibility. * Not sure about naming convensions, maybe certain classes and structures need to be renamed * `METRIC_INNER_PRODUCT` is supported as well * More unit tests are needed? * I did not bring any hardware-specific optimizations, bcz this is a reference implementation. Certain `simdlib` facilities may be added later, if needed Here's how to use IndexRaBitQ ```Python ds = datasets.SyntheticDataset(...) index_rbq = faiss.IndexRaBitQ(ds.d, faiss.METRIC_L2) index_rbq.qb = 8 # wrap with random rotations rrot = faiss.RandomRotationMatrix(ds.d, ds.d) rrot.init(rrot_seed) index_cand = faiss.IndexPreTransform(rrot, index_rbq) index_cand.train(ds.get_train()) index_cand.add(ds.get_database()) ``` Here's how to use IndexIVFRaBitQ ```Python ds = datasets.SyntheticDataset(...) index_flat = faiss.IndexFlat(ds.d, faiss.METRIC_L2) index_rbq = faiss.IndexIVFRaBitQ(index_flat, ds.d, nlist, faiss.METRIC_L2) index_rbq.qb = 8 # wrap with random rotations rrot = faiss.RandomRotationMatrix(ds.d, ds.d) rrot.init(rrot_seed) index_cand = faiss.IndexPreTransform(rrot, index_rbq) index_cand.train(ds.get_train()) index_cand.add(ds.get_database()) ``` Pull Request resolved: facebookresearch#4235 Test Plan: Imported from GitHub, without a `Test Plan:` line. buck run 'fbcode//mode/dev' fbcode//faiss/tests:test_rabitq Reviewed By: mdouze Differential Revision: D71638302 Pulled By: junjieqi fbshipit-source-id: de981a6aed91d296237d8accf337359de04a552e
1 parent 6a29514 commit c1dc753

26 files changed

Lines changed: 1743 additions & 18 deletions

faiss/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ set(FAISS_SRC
2828
IndexIVFAdditiveQuantizerFastScan.cpp
2929
IndexIVFPQFastScan.cpp
3030
IndexIVFPQR.cpp
31+
IndexIVFRaBitQ.cpp
3132
IndexIVFSpectralHash.cpp
3233
IndexLSH.cpp
3334
IndexNNDescent.cpp
@@ -39,6 +40,7 @@ set(FAISS_SRC
3940
IndexIVFIndependentQuantizer.cpp
4041
IndexPQFastScan.cpp
4142
IndexPreTransform.cpp
43+
IndexRaBitQ.cpp
4244
IndexRefine.cpp
4345
IndexReplicas.cpp
4446
IndexRowwiseMinMax.cpp
@@ -60,6 +62,7 @@ set(FAISS_SRC
6062
impl/PolysemousTraining.cpp
6163
impl/ProductQuantizer.cpp
6264
impl/AdditiveQuantizer.cpp
65+
impl/RaBitQuantizer.cpp
6366
impl/ResidualQuantizer.cpp
6467
impl/LocalSearchQuantizer.cpp
6568
impl/ProductAdditiveQuantizer.cpp
@@ -123,6 +126,7 @@ set(FAISS_HEADERS
123126
IndexIVFAdditiveQuantizerFastScan.h
124127
IndexIVFPQFastScan.h
125128
IndexIVFPQR.h
129+
IndexIVFRaBitQ.h
126130
IndexIVFSpectralHash.h
127131
IndexLSH.h
128132
IndexNeuralNetCodec.h
@@ -136,6 +140,7 @@ set(FAISS_HEADERS
136140
IndexPreTransform.h
137141
IndexRefine.h
138142
IndexReplicas.h
143+
IndexRaBitQ.h
139144
IndexRowwiseMinMax.h
140145
IndexScalarQuantizer.h
141146
IndexShards.h
@@ -164,6 +169,7 @@ set(FAISS_HEADERS
164169
impl/ProductQuantizer-inl.h
165170
impl/ProductQuantizer.h
166171
impl/Quantizer.h
172+
impl/RaBitQuantizer.h
167173
impl/ResidualQuantizer.h
168174
impl/ResultHandler.h
169175
impl/ScalarQuantizer.h

faiss/IndexIVF.cpp

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ void IndexIVF::search_preassigned(
455455
#pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis, nheap)
456456
{
457457
std::unique_ptr<InvertedListScanner> scanner(
458-
get_InvertedListScanner(store_pairs, sel));
458+
get_InvertedListScanner(store_pairs, sel, params));
459459

460460
/*****************************************************
461461
* Depending on parallel_mode, there are two possible ways
@@ -796,7 +796,7 @@ void IndexIVF::range_search_preassigned(
796796
{
797797
RangeSearchPartialResult pres(result);
798798
std::unique_ptr<InvertedListScanner> scanner(
799-
get_InvertedListScanner(store_pairs, sel));
799+
get_InvertedListScanner(store_pairs, sel, params));
800800
FAISS_THROW_IF_NOT(scanner.get());
801801
all_pres[omp_get_thread_num()] = &pres;
802802

@@ -912,7 +912,8 @@ void IndexIVF::range_search_preassigned(
912912

913913
InvertedListScanner* IndexIVF::get_InvertedListScanner(
914914
bool /*store_pairs*/,
915-
const IDSelector* /* sel */) const {
915+
const IDSelector* /* sel */,
916+
const IVFSearchParameters* /* params */) const {
916917
FAISS_THROW_MSG("get_InvertedListScanner not implemented");
917918
}
918919

@@ -1290,6 +1291,14 @@ size_t InvertedListScanner::scan_codes(
12901291

12911292
if (!keep_max) {
12921293
for (size_t j = 0; j < list_size; j++) {
1294+
if (sel != nullptr) {
1295+
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1296+
if (!sel->is_member(id)) {
1297+
codes += code_size;
1298+
continue;
1299+
}
1300+
}
1301+
12931302
float dis = distance_to_code(codes);
12941303
if (dis < simi[0]) {
12951304
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
@@ -1300,6 +1309,14 @@ size_t InvertedListScanner::scan_codes(
13001309
}
13011310
} else {
13021311
for (size_t j = 0; j < list_size; j++) {
1312+
if (sel != nullptr) {
1313+
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1314+
if (!sel->is_member(id)) {
1315+
codes += code_size;
1316+
continue;
1317+
}
1318+
}
1319+
13031320
float dis = distance_to_code(codes);
13041321
if (dis > simi[0]) {
13051322
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];

faiss/IndexIVF.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,11 +312,14 @@ struct IndexIVF : Index, IndexIVFInterface {
312312

313313
/** Get a scanner for this index (store_pairs means ignore labels)
314314
*
315-
* The default search implementation uses this to compute the distances
315+
* The default search implementation uses this to compute the distances.
316+
* Use sel instead of params->sel, because sel is initialized with
317+
* params->sel, but may get overridden by IndexIVF's internal logic.
316318
*/
317319
virtual InvertedListScanner* get_InvertedListScanner(
318320
bool store_pairs = false,
319-
const IDSelector* sel = nullptr) const;
321+
const IDSelector* sel = nullptr,
322+
const IVFSearchParameters* params = nullptr) const;
320323

321324
/** reconstruct a vector. Works only if maintain_direct_map is set to 1 or 2
322325
*/

faiss/IndexIVFAdditiveQuantizer.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,8 @@ struct AQInvertedListScannerLUT : AQInvertedListScanner {
253253

254254
InvertedListScanner* IndexIVFAdditiveQuantizer::get_InvertedListScanner(
255255
bool store_pairs,
256-
const IDSelector* sel) const {
256+
const IDSelector* sel,
257+
const IVFSearchParameters*) const {
257258
FAISS_THROW_IF_NOT(!sel);
258259
if (metric_type == METRIC_INNER_PRODUCT) {
259260
if (aq->search_type == AdditiveQuantizer::ST_decompress) {

faiss/IndexIVFAdditiveQuantizer.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ struct IndexIVFAdditiveQuantizer : IndexIVF {
5252

5353
InvertedListScanner* get_InvertedListScanner(
5454
bool store_pairs,
55-
const IDSelector* sel) const override;
55+
const IDSelector* sel,
56+
const IVFSearchParameters* params) const override;
5657

5758
void sa_decode(idx_t n, const uint8_t* codes, float* x) const override;
5859

faiss/IndexIVFFlat.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,8 @@ InvertedListScanner* get_InvertedListScanner1(
223223

224224
InvertedListScanner* IndexIVFFlat::get_InvertedListScanner(
225225
bool store_pairs,
226-
const IDSelector* sel) const {
226+
const IDSelector* sel,
227+
const IVFSearchParameters*) const {
227228
if (sel) {
228229
return get_InvertedListScanner1<true>(this, store_pairs, sel);
229230
} else {

faiss/IndexIVFFlat.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ struct IndexIVFFlat : IndexIVF {
4444

4545
InvertedListScanner* get_InvertedListScanner(
4646
bool store_pairs,
47-
const IDSelector* sel) const override;
47+
const IDSelector* sel,
48+
const IVFSearchParameters* params) const override;
4849

4950
void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
5051
const override;

faiss/IndexIVFPQ.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1321,7 +1321,8 @@ InvertedListScanner* get_InvertedListScanner2(
13211321

13221322
InvertedListScanner* IndexIVFPQ::get_InvertedListScanner(
13231323
bool store_pairs,
1324-
const IDSelector* sel) const {
1324+
const IDSelector* sel,
1325+
const IVFSearchParameters*) const {
13251326
if (sel) {
13261327
return get_InvertedListScanner2<true>(*this, store_pairs, sel);
13271328
} else {

faiss/IndexIVFPQ.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ struct IndexIVFPQ : IndexIVF {
134134

135135
InvertedListScanner* get_InvertedListScanner(
136136
bool store_pairs,
137-
const IDSelector* sel) const override;
137+
const IDSelector* sel,
138+
const IVFSearchParameters* params) const override;
138139

139140
/// build precomputed table
140141
void precompute_table();

0 commit comments

Comments
 (0)