@@ -817,8 +817,9 @@ struct RangeSearchResults {
817817 * The scanning functions call their favorite precompute_*
818818 * function to precompute the tables they need.
819819 *****************************************************/
820- template <typename IDType, MetricType METRIC_TYPE, class PQDecoder >
820+ template <typename IDType, MetricType METRIC_TYPE, class PQCodeDistance >
821821struct IVFPQScannerT : QueryTables {
822+ using PQDecoder = typename PQCodeDistance::PQDecoder;
822823 const uint8_t * list_codes;
823824 const IDType* list_ids;
824825 size_t list_size;
@@ -894,7 +895,7 @@ struct IVFPQScannerT : QueryTables {
894895 float distance_1 = 0 ;
895896 float distance_2 = 0 ;
896897 float distance_3 = 0 ;
897- distance_four_codes<PQDecoder> (
898+ PQCodeDistance:: distance_four_codes (
898899 pq.M ,
899900 pq.nbits ,
900901 sim_table,
@@ -917,7 +918,7 @@ struct IVFPQScannerT : QueryTables {
917918
918919 if (counter >= 1 ) {
919920 float dis = dis0 +
920- distance_single_code<PQDecoder> (
921+ PQCodeDistance:: distance_single_code (
921922 pq.M ,
922923 pq.nbits ,
923924 sim_table,
@@ -926,7 +927,7 @@ struct IVFPQScannerT : QueryTables {
926927 }
927928 if (counter >= 2 ) {
928929 float dis = dis0 +
929- distance_single_code<PQDecoder> (
930+ PQCodeDistance:: distance_single_code (
930931 pq.M ,
931932 pq.nbits ,
932933 sim_table,
@@ -935,7 +936,7 @@ struct IVFPQScannerT : QueryTables {
935936 }
936937 if (counter >= 3 ) {
937938 float dis = dis0 +
938- distance_single_code<PQDecoder> (
939+ PQCodeDistance:: distance_single_code (
939940 pq.M ,
940941 pq.nbits ,
941942 sim_table,
@@ -1101,7 +1102,7 @@ struct IVFPQScannerT : QueryTables {
11011102 float distance_1 = dis0;
11021103 float distance_2 = dis0;
11031104 float distance_3 = dis0;
1104- distance_four_codes<PQDecoder> (
1105+ PQCodeDistance:: distance_four_codes (
11051106 pq.M ,
11061107 pq.nbits ,
11071108 sim_table,
@@ -1132,7 +1133,7 @@ struct IVFPQScannerT : QueryTables {
11321133 n_hamming_pass++;
11331134
11341135 float dis = dis0 +
1135- distance_single_code<PQDecoder> (
1136+ PQCodeDistance:: distance_single_code (
11361137 pq.M ,
11371138 pq.nbits ,
11381139 sim_table,
@@ -1152,7 +1153,7 @@ struct IVFPQScannerT : QueryTables {
11521153 n_hamming_pass++;
11531154
11541155 float dis = dis0 +
1155- distance_single_code<PQDecoder> (
1156+ PQCodeDistance:: distance_single_code (
11561157 pq.M ,
11571158 pq.nbits ,
11581159 sim_table,
@@ -1199,8 +1200,8 @@ struct IVFPQScannerT : QueryTables {
11991200 *
12001201 * use_sel: store or ignore the IDSelector
12011202 */
1202- template <MetricType METRIC_TYPE, class C , class PQDecoder , bool use_sel>
1203- struct IVFPQScanner : IVFPQScannerT<idx_t , METRIC_TYPE, PQDecoder >,
1203+ template <MetricType METRIC_TYPE, class C , class PQCodeDistance , bool use_sel>
1204+ struct IVFPQScanner : IVFPQScannerT<idx_t , METRIC_TYPE, PQCodeDistance >,
12041205 InvertedListScanner {
12051206 int precompute_mode;
12061207 const IDSelector* sel;
@@ -1210,7 +1211,7 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
12101211 bool store_pairs,
12111212 int precompute_mode,
12121213 const IDSelector* sel)
1213- : IVFPQScannerT<idx_t , METRIC_TYPE, PQDecoder >(ivfpq, nullptr ),
1214+ : IVFPQScannerT<idx_t , METRIC_TYPE, PQCodeDistance >(ivfpq, nullptr ),
12141215 precompute_mode (precompute_mode),
12151216 sel (sel) {
12161217 this ->store_pairs = store_pairs;
@@ -1230,7 +1231,7 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
12301231 float distance_to_code (const uint8_t * code) const override {
12311232 assert (precompute_mode == 2 );
12321233 float dis = this ->dis0 +
1233- distance_single_code<PQDecoder> (
1234+ PQCodeDistance:: distance_single_code (
12341235 this ->pq .M , this ->pq .nbits , this ->sim_table , code);
12351236 return dis;
12361237 }
@@ -1294,7 +1295,9 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
12941295 }
12951296};
12961297
1297- template <class PQDecoder , bool use_sel>
1298+ /* * follow 3 stages of template dispatching */
1299+
1300+ template <class PQCodeDistance , bool use_sel>
12981301InvertedListScanner* get_InvertedListScanner1 (
12991302 const IndexIVFPQ& index,
13001303 bool store_pairs,
@@ -1303,32 +1306,47 @@ InvertedListScanner* get_InvertedListScanner1(
13031306 return new IVFPQScanner<
13041307 METRIC_INNER_PRODUCT,
13051308 CMin<float , idx_t >,
1306- PQDecoder ,
1309+ PQCodeDistance ,
13071310 use_sel>(index, store_pairs, 2 , sel);
13081311 } else if (index.metric_type == METRIC_L2) {
13091312 return new IVFPQScanner<
13101313 METRIC_L2,
13111314 CMax<float , idx_t >,
1312- PQDecoder ,
1315+ PQCodeDistance ,
13131316 use_sel>(index, store_pairs, 2 , sel);
13141317 }
13151318 return nullptr ;
13161319}
13171320
1318- template <bool use_sel>
1321+ template <bool use_sel, SIMDLevel SL >
13191322InvertedListScanner* get_InvertedListScanner2 (
13201323 const IndexIVFPQ& index,
13211324 bool store_pairs,
13221325 const IDSelector* sel) {
13231326 if (index.pq .nbits == 8 ) {
1324- return get_InvertedListScanner1<PQDecoder8, use_sel>(
1325- index, store_pairs, sel);
1327+ return get_InvertedListScanner1<
1328+ PQCodeDistance<PQDecoder8, SL>,
1329+ use_sel>(index, store_pairs, sel);
13261330 } else if (index.pq .nbits == 16 ) {
1327- return get_InvertedListScanner1<PQDecoder16, use_sel>(
1328- index, store_pairs, sel);
1331+ return get_InvertedListScanner1<
1332+ PQCodeDistance<PQDecoder16, SL>,
1333+ use_sel>(index, store_pairs, sel);
1334+ } else {
1335+ return get_InvertedListScanner1<
1336+ PQCodeDistance<PQDecoderGeneric, SL>,
1337+ use_sel>(index, store_pairs, sel);
1338+ }
1339+ }
1340+
1341+ template <SIMDLevel SL>
1342+ InvertedListScanner* get_InvertedListScanner3 (
1343+ const IndexIVFPQ& index,
1344+ bool store_pairs,
1345+ const IDSelector* sel) {
1346+ if (sel) {
1347+ return get_InvertedListScanner2<true , SL>(index, store_pairs, sel);
13291348 } else {
1330- return get_InvertedListScanner1<PQDecoderGeneric, use_sel>(
1331- index, store_pairs, sel);
1349+ return get_InvertedListScanner2<false , SL>(index, store_pairs, sel);
13321350 }
13331351}
13341352
@@ -1338,11 +1356,7 @@ InvertedListScanner* IndexIVFPQ::get_InvertedListScanner(
13381356 bool store_pairs,
13391357 const IDSelector* sel,
13401358 const IVFSearchParameters*) const {
1341- if (sel) {
1342- return get_InvertedListScanner2<true >(*this , store_pairs, sel);
1343- } else {
1344- return get_InvertedListScanner2<false >(*this , store_pairs, sel);
1345- }
1359+ DISPATCH_SIMDLevel (get_InvertedListScanner3, *this , store_pairs, sel);
13461360 return nullptr ;
13471361}
13481362
0 commit comments