Skip to content

Commit 8ee43af

Browse files
refactor to extra 2 fp32 values per vector, add unit tests
Signed-off-by: Alexandr Guzhva <[email protected]>
1 parent eda7764 commit 8ee43af

File tree

6 files changed

+158
-123
lines changed

6 files changed

+158
-123
lines changed

faiss/IndexIVF.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,8 @@ struct IndexIVF : Index, IndexIVFInterface {
313313
/** Get a scanner for this index (store_pairs means ignore labels)
314314
*
315315
* The default search implementation uses this to compute the distances.
316-
* Use sel instead of params->sel, because sel can get overriden.
316+
* Use sel instead of params->sel, because sel is initialized with
317+
* params->sel, but may get overriden by IndexIVF's internal logic.
317318
*/
318319
virtual InvertedListScanner* get_InvertedListScanner(
319320
bool store_pairs = false,

faiss/IndexIVFRaBitQ.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ IndexIVFRaBitQ::IndexIVFRaBitQ(
1717
const size_t d,
1818
const size_t nlist,
1919
MetricType metric)
20-
: IndexIVF(quantizer, d, nlist, 0, metric), rabitq(d) {
20+
: IndexIVF(quantizer, d, nlist, 0, metric), rabitq(d, metric) {
2121
code_size = rabitq.code_size;
2222
invlists->code_size = code_size;
2323
is_trained = false;
@@ -120,7 +120,7 @@ struct RaBitInvertedListScanner : InvertedListScanner {
120120
std::vector<float> reconstructed_centroid;
121121
std::vector<float> query_vector;
122122

123-
std::unique_ptr<RaBitQuantizer::RaBitDistanceComputer> dc;
123+
std::unique_ptr<FlatCodesDistanceComputer> dc;
124124

125125
uint8_t qb = 0;
126126

@@ -164,7 +164,7 @@ struct RaBitInvertedListScanner : InvertedListScanner {
164164
// both query_vector and centroid are available!
165165
// set up DistanceComputer
166166
dc.reset(ivf_rabitq.rabitq.get_distance_computer(
167-
qb, ivf_rabitq.metric_type, reconstructed_centroid.data()));
167+
qb, reconstructed_centroid.data()));
168168

169169
dc->set_query(query_vector.data());
170170
}
@@ -249,9 +249,8 @@ float IVFRaBitDistanceComputer::operator()(idx_t i) {
249249
// compute the distance
250250
float distance = 0;
251251

252-
std::unique_ptr<RaBitQuantizer::RaBitDistanceComputer> dc(
253-
parent->rabitq.get_distance_computer(
254-
parent->qb, parent->metric_type, centroid.data()));
252+
std::unique_ptr<FlatCodesDistanceComputer> dc(
253+
parent->rabitq.get_distance_computer(parent->qb, centroid.data()));
255254
dc->set_query(q);
256255
distance = dc->distance_to_code(code);
257256

faiss/IndexRaBitQ.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace faiss {
88
IndexRaBitQ::IndexRaBitQ() = default;
99

1010
IndexRaBitQ::IndexRaBitQ(idx_t d, MetricType metric)
11-
: IndexFlatCodes(0, d, metric), rabitq(d) {
11+
: IndexFlatCodes(0, d, metric), rabitq(d, metric) {
1212
code_size = rabitq.code_size;
1313

1414
is_trained = false;
@@ -47,17 +47,17 @@ void IndexRaBitQ::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
4747
}
4848

4949
FlatCodesDistanceComputer* IndexRaBitQ::get_FlatCodesDistanceComputer() const {
50-
RaBitQuantizer::RaBitDistanceComputer* dc =
51-
rabitq.get_distance_computer(qb, metric_type, center.data());
50+
FlatCodesDistanceComputer* dc =
51+
rabitq.get_distance_computer(qb, center.data());
5252
dc->code_size = rabitq.code_size;
5353
dc->codes = codes.data();
5454
return dc;
5555
}
5656

5757
FlatCodesDistanceComputer* IndexRaBitQ::get_quantized_distance_computer(
5858
const uint8_t qb) const {
59-
RaBitQuantizer::RaBitDistanceComputer* dc =
60-
rabitq.get_distance_computer(qb, metric_type, center.data());
59+
FlatCodesDistanceComputer* dc =
60+
rabitq.get_distance_computer(qb, center.data());
6161
dc->code_size = rabitq.code_size;
6262
dc->codes = codes.data();
6363
return dc;

0 commit comments

Comments
 (0)