Skip to content

Commit 32fb3de

Browse files
mdouzemeta-codesync[bot]
authored andcommitted
moved IndexIVFPQ and IndexPQ to dynamic dispatch (facebookresearch#4291)
Summary: Pull Request resolved: facebookresearch#4291 moved IndexIVFPQ and IndexPQ to dynamic dispatch. Since the code was already quite modular (thanks Alex!), this boils down to make independent cpp files for the different SIMD versions. Differential Revision: D72937709
1 parent 1b11fa7 commit 32fb3de

17 files changed

Lines changed: 896 additions & 1094 deletions

faiss/IndexIVFPQ.cpp

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -817,8 +817,9 @@ struct RangeSearchResults {
817817
* The scanning functions call their favorite precompute_*
818818
* function to precompute the tables they need.
819819
*****************************************************/
820-
template <typename IDType, MetricType METRIC_TYPE, class PQDecoder>
820+
template <typename IDType, MetricType METRIC_TYPE, class PQCodeDistance>
821821
struct IVFPQScannerT : QueryTables {
822+
using PQDecoder = typename PQCodeDistance::PQDecoder;
822823
const uint8_t* list_codes;
823824
const IDType* list_ids;
824825
size_t list_size;
@@ -894,7 +895,7 @@ struct IVFPQScannerT : QueryTables {
894895
float distance_1 = 0;
895896
float distance_2 = 0;
896897
float distance_3 = 0;
897-
distance_four_codes<PQDecoder>(
898+
PQCodeDistance::distance_four_codes(
898899
pq.M,
899900
pq.nbits,
900901
sim_table,
@@ -917,7 +918,7 @@ struct IVFPQScannerT : QueryTables {
917918

918919
if (counter >= 1) {
919920
float dis = dis0 +
920-
distance_single_code<PQDecoder>(
921+
PQCodeDistance::distance_single_code(
921922
pq.M,
922923
pq.nbits,
923924
sim_table,
@@ -926,7 +927,7 @@ struct IVFPQScannerT : QueryTables {
926927
}
927928
if (counter >= 2) {
928929
float dis = dis0 +
929-
distance_single_code<PQDecoder>(
930+
PQCodeDistance::distance_single_code(
930931
pq.M,
931932
pq.nbits,
932933
sim_table,
@@ -935,7 +936,7 @@ struct IVFPQScannerT : QueryTables {
935936
}
936937
if (counter >= 3) {
937938
float dis = dis0 +
938-
distance_single_code<PQDecoder>(
939+
PQCodeDistance::distance_single_code(
939940
pq.M,
940941
pq.nbits,
941942
sim_table,
@@ -1101,7 +1102,7 @@ struct IVFPQScannerT : QueryTables {
11011102
float distance_1 = dis0;
11021103
float distance_2 = dis0;
11031104
float distance_3 = dis0;
1104-
distance_four_codes<PQDecoder>(
1105+
PQCodeDistance::distance_four_codes(
11051106
pq.M,
11061107
pq.nbits,
11071108
sim_table,
@@ -1132,7 +1133,7 @@ struct IVFPQScannerT : QueryTables {
11321133
n_hamming_pass++;
11331134

11341135
float dis = dis0 +
1135-
distance_single_code<PQDecoder>(
1136+
PQCodeDistance::distance_single_code(
11361137
pq.M,
11371138
pq.nbits,
11381139
sim_table,
@@ -1152,7 +1153,7 @@ struct IVFPQScannerT : QueryTables {
11521153
n_hamming_pass++;
11531154

11541155
float dis = dis0 +
1155-
distance_single_code<PQDecoder>(
1156+
PQCodeDistance::distance_single_code(
11561157
pq.M,
11571158
pq.nbits,
11581159
sim_table,
@@ -1199,8 +1200,8 @@ struct IVFPQScannerT : QueryTables {
11991200
*
12001201
* use_sel: store or ignore the IDSelector
12011202
*/
1202-
template <MetricType METRIC_TYPE, class C, class PQDecoder, bool use_sel>
1203-
struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
1203+
template <MetricType METRIC_TYPE, class C, class PQCodeDistance, bool use_sel>
1204+
struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQCodeDistance>,
12041205
InvertedListScanner {
12051206
int precompute_mode;
12061207
const IDSelector* sel;
@@ -1210,7 +1211,7 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
12101211
bool store_pairs,
12111212
int precompute_mode,
12121213
const IDSelector* sel)
1213-
: IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>(ivfpq, nullptr),
1214+
: IVFPQScannerT<idx_t, METRIC_TYPE, PQCodeDistance>(ivfpq, nullptr),
12141215
precompute_mode(precompute_mode),
12151216
sel(sel) {
12161217
this->store_pairs = store_pairs;
@@ -1230,7 +1231,7 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
12301231
float distance_to_code(const uint8_t* code) const override {
12311232
assert(precompute_mode == 2);
12321233
float dis = this->dis0 +
1233-
distance_single_code<PQDecoder>(
1234+
PQCodeDistance::distance_single_code(
12341235
this->pq.M, this->pq.nbits, this->sim_table, code);
12351236
return dis;
12361237
}
@@ -1294,7 +1295,9 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
12941295
}
12951296
};
12961297

1297-
template <class PQDecoder, bool use_sel>
1298+
/** follow 3 stages of template dispatching */
1299+
1300+
template <class PQCodeDistance, bool use_sel>
12981301
InvertedListScanner* get_InvertedListScanner1(
12991302
const IndexIVFPQ& index,
13001303
bool store_pairs,
@@ -1303,32 +1306,47 @@ InvertedListScanner* get_InvertedListScanner1(
13031306
return new IVFPQScanner<
13041307
METRIC_INNER_PRODUCT,
13051308
CMin<float, idx_t>,
1306-
PQDecoder,
1309+
PQCodeDistance,
13071310
use_sel>(index, store_pairs, 2, sel);
13081311
} else if (index.metric_type == METRIC_L2) {
13091312
return new IVFPQScanner<
13101313
METRIC_L2,
13111314
CMax<float, idx_t>,
1312-
PQDecoder,
1315+
PQCodeDistance,
13131316
use_sel>(index, store_pairs, 2, sel);
13141317
}
13151318
return nullptr;
13161319
}
13171320

1318-
template <bool use_sel>
1321+
template <bool use_sel, SIMDLevel SL>
13191322
InvertedListScanner* get_InvertedListScanner2(
13201323
const IndexIVFPQ& index,
13211324
bool store_pairs,
13221325
const IDSelector* sel) {
13231326
if (index.pq.nbits == 8) {
1324-
return get_InvertedListScanner1<PQDecoder8, use_sel>(
1325-
index, store_pairs, sel);
1327+
return get_InvertedListScanner1<
1328+
PQCodeDistance<PQDecoder8, SL>,
1329+
use_sel>(index, store_pairs, sel);
13261330
} else if (index.pq.nbits == 16) {
1327-
return get_InvertedListScanner1<PQDecoder16, use_sel>(
1328-
index, store_pairs, sel);
1331+
return get_InvertedListScanner1<
1332+
PQCodeDistance<PQDecoder16, SL>,
1333+
use_sel>(index, store_pairs, sel);
1334+
} else {
1335+
return get_InvertedListScanner1<
1336+
PQCodeDistance<PQDecoderGeneric, SL>,
1337+
use_sel>(index, store_pairs, sel);
1338+
}
1339+
}
1340+
1341+
template <SIMDLevel SL>
1342+
InvertedListScanner* get_InvertedListScanner3(
1343+
const IndexIVFPQ& index,
1344+
bool store_pairs,
1345+
const IDSelector* sel) {
1346+
if (sel) {
1347+
return get_InvertedListScanner2<true, SL>(index, store_pairs, sel);
13291348
} else {
1330-
return get_InvertedListScanner1<PQDecoderGeneric, use_sel>(
1331-
index, store_pairs, sel);
1349+
return get_InvertedListScanner2<false, SL>(index, store_pairs, sel);
13321350
}
13331351
}
13341352

@@ -1338,11 +1356,7 @@ InvertedListScanner* IndexIVFPQ::get_InvertedListScanner(
13381356
bool store_pairs,
13391357
const IDSelector* sel,
13401358
const IVFSearchParameters*) const {
1341-
if (sel) {
1342-
return get_InvertedListScanner2<true>(*this, store_pairs, sel);
1343-
} else {
1344-
return get_InvertedListScanner2<false>(*this, store_pairs, sel);
1345-
}
1359+
DISPATCH_SIMDLevel(get_InvertedListScanner3, *this, store_pairs, sel);
13461360
return nullptr;
13471361
}
13481362

faiss/IndexPQ.cpp

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ void IndexPQ::train(idx_t n, const float* x) {
7272

7373
namespace {
7474

75-
template <class PQDecoder>
75+
template <class PQCodeDistance>
7676
struct PQDistanceComputer : FlatCodesDistanceComputer {
7777
size_t d;
7878
MetricType metric;
@@ -86,7 +86,7 @@ struct PQDistanceComputer : FlatCodesDistanceComputer {
8686
float distance_to_code(const uint8_t* code) final {
8787
ndis++;
8888

89-
float dis = distance_single_code<PQDecoder>(
89+
float dis = PQCodeDistance::distance_single_code(
9090
pq.M, pq.nbits, precomputed_table.data(), code);
9191
return dis;
9292
}
@@ -95,8 +95,10 @@ struct PQDistanceComputer : FlatCodesDistanceComputer {
9595
FAISS_THROW_IF_NOT(sdc);
9696
const float* sdci = sdc;
9797
float accu = 0;
98-
PQDecoder codei(codes + i * code_size, pq.nbits);
99-
PQDecoder codej(codes + j * code_size, pq.nbits);
98+
typename PQCodeDistance::PQDecoder codei(
99+
codes + i * code_size, pq.nbits);
100+
typename PQCodeDistance::PQDecoder codej(
101+
codes + j * code_size, pq.nbits);
100102

101103
for (int l = 0; l < pq.M; l++) {
102104
accu += sdci[codei.decode() + (codej.decode() << codei.nbits)];
@@ -134,16 +136,24 @@ struct PQDistanceComputer : FlatCodesDistanceComputer {
134136
}
135137
};
136138

139+
template <SIMDLevel SL>
140+
FlatCodesDistanceComputer* get_FlatCodesDistanceComputer1(
141+
const IndexPQ& index) {
142+
int nbits = index.pq.nbits;
143+
if (nbits == 8) {
144+
return new PQDistanceComputer<PQCodeDistance<PQDecoder8, SL>>(index);
145+
} else if (nbits == 16) {
146+
return new PQDistanceComputer<PQCodeDistance<PQDecoder16, SL>>(index);
147+
} else {
148+
return new PQDistanceComputer<PQCodeDistance<PQDecoderGeneric, SL>>(
149+
index);
150+
}
151+
}
152+
137153
} // namespace
138154

139155
FlatCodesDistanceComputer* IndexPQ::get_FlatCodesDistanceComputer() const {
140-
if (pq.nbits == 8) {
141-
return new PQDistanceComputer<PQDecoder8>(*this);
142-
} else if (pq.nbits == 16) {
143-
return new PQDistanceComputer<PQDecoder16>(*this);
144-
} else {
145-
return new PQDistanceComputer<PQDecoderGeneric>(*this);
146-
}
156+
DISPATCH_SIMDLevel(get_FlatCodesDistanceComputer1, *this);
147157
}
148158

149159
/*****************************************

0 commit comments

Comments
 (0)