diff --git a/c_api/IndexIVF_c_ex.cpp b/c_api/IndexIVF_c_ex.cpp index 7a3e79aeea..634bd499fb 100644 --- a/c_api/IndexIVF_c_ex.cpp +++ b/c_api/IndexIVF_c_ex.cpp @@ -88,10 +88,23 @@ int faiss_IndexIVF_compute_distance_to_codes_for_list( const float* x, idx_t n, const uint8_t* codes, - float* dists) { + float* dists, + float* dist_table) { try { reinterpret_cast(index)->compute_distance_to_codes_for_list( - list_no, x, n, codes, dists); + list_no, x, n, codes, dists, dist_table); + return 0; + } + CATCH_AND_HANDLE +} + +int faiss_IndexIVF_compute_distance_table( + FaissIndexIVF* index, + const float* x, + float* dist_table) { + try { + reinterpret_cast(index)->compute_distance_table( + x, dist_table); return 0; } CATCH_AND_HANDLE diff --git a/c_api/IndexIVF_c_ex.h b/c_api/IndexIVF_c_ex.h index 98d77bf836..76ac10b800 100644 --- a/c_api/IndexIVF_c_ex.h +++ b/c_api/IndexIVF_c_ex.h @@ -81,6 +81,7 @@ int faiss_IndexIVF_search_preassigned_with_params( @param n - number of codes @param codes - input codes @param dists - output computed distances + @param dist_table - input precomputed distance table for PQ */ int faiss_IndexIVF_compute_distance_to_codes_for_list( @@ -89,7 +90,8 @@ int faiss_IndexIVF_compute_distance_to_codes_for_list( const float* x, idx_t n, const uint8_t* codes, - float* dists); + float* dists, + float* dist_table); /* Given multiple vector IDs, retrieve the corresponding list (cluster) IDs @@ -108,6 +110,20 @@ int faiss_get_lists_for_keys( size_t n_keys, idx_t* lists); +/* + Given a query vector x, compute the distance table and + return it to the caller. 
+ + @param x - input query vector + @param dist_table - output buffer that receives the distance table for PQ + +*/ + +int faiss_IndexIVF_compute_distance_table( + FaissIndexIVF* index, + const float* x, + float* dist_table); + #ifdef __cplusplus } #endif diff --git a/faiss/IndexIVF.h b/faiss/IndexIVF.h index 306521b878..ebee506f22 100644 --- a/faiss/IndexIVF.h +++ b/faiss/IndexIVF.h @@ -459,6 +459,7 @@ struct IndexIVF : Index, IndexIVFInterface { * @param n - number of codes * @param codes - input codes * @param dists - output computed distances + * @param dist_table - input precomputed distance table for PQ */ virtual void compute_distance_to_codes_for_list( @@ -466,7 +467,20 @@ struct IndexIVF : Index, IndexIVFInterface { const float* x, idx_t n, const uint8_t* codes, - float* dists) const {}; + float* dists, + float* dist_table) const {}; + + /** Given a query vector x, compute the distance table and + * return it to the caller. The default implementation here is empty. + * + * @param x - input query vector + * @param dist_table - output buffer that receives the distance table for PQ + * + */ + + virtual void compute_distance_table( + const float* x, + float* dist_table) const {}; IndexIVF(); diff --git a/faiss/IndexIVFPQ.cpp b/faiss/IndexIVFPQ.cpp index 1b7e27563c..7b8ad7a545 100644 --- a/faiss/IndexIVFPQ.cpp +++ b/faiss/IndexIVFPQ.cpp @@ -16,6 +16,7 @@ #include #include +#include #include #include @@ -565,6 +566,7 @@ struct QueryTables { } } + /***************************************************** * When inverted list is known: prepare computations *****************************************************/ @@ -748,6 +750,33 @@ struct QueryTables { return dis0; } + + /* Stores qi in this->qi and copies pq.ksub * pq.M floats from table into sim_table, or into sim_table_2 when the metric is not inner product and by_residual is set. */ + void init_sim_table(const float* qi, const float* table) { + this->qi = qi; + + if (metric_type == METRIC_INNER_PRODUCT) { + memcpy(sim_table, table, pq.ksub * pq.M * sizeof(float)); + } else { + if (!by_residual) { + memcpy(sim_table, table, pq.ksub * pq.M * sizeof(float)); + } else { + memcpy(sim_table_2, table, pq.ksub * pq.M * sizeof(float)); + } + } + } + /* Copies pq.ksub * pq.M floats into table from sim_table, or from sim_table_2 when the metric is not inner product and by_residual is set. */ + void 
copy_sim_table(float* table) const { + if (metric_type == METRIC_INNER_PRODUCT) { + memcpy(table, sim_table, pq.ksub * pq.M * sizeof(float)); + } else { + if (!by_residual) { + memcpy(table, sim_table, pq.ksub * pq.M * sizeof(float)); + } else { + memcpy(table, sim_table_2, pq.ksub * pq.M * sizeof(float)); + } + } + } }; // This way of handling the selector is not optimal since all distances @@ -1387,4 +1416,76 @@ size_t IndexIVFPQ::find_duplicates(idx_t* dup_ids, size_t* lims) const { return ngroup; } +void IndexIVFPQ::compute_distance_to_codes_for_list( + const idx_t list_no, + const float* x, + idx_t n, + const uint8_t* codes, + float* dists, + float* dist_table) const { + + std::unique_ptr scanner( + get_InvertedListScanner(true, nullptr)); + + /* NOTE(review): when dist_table is non-null but the cast below fails, neither init_sim_table nor set_query is called before set_list and scan_codes; confirm this cannot happen. */ + if (dist_table) { + if (auto* pqscanner = dynamic_cast(scanner.get())) { + pqscanner->init_sim_table(x, dist_table); + } + } else { + scanner->set_query(x); + } + + // Initialize distances with default values + std::vector dist_out(n, metric_type == METRIC_L2 ? HUGE_VAL : -HUGE_VAL); + + + // Find the centroid corresponding to the input list_no + // and compute its distance from the query vector. + std::vector centroid(d); + quantizer->reconstruct(list_no, centroid.data()); + + float coarse_dis = quantizer->metric_type == faiss::METRIC_L2 + ? faiss::fvec_L2sqr(x, centroid.data(), d) + : faiss::fvec_inner_product(x, centroid.data(), d); + + + scanner->set_list(list_no, coarse_dis); + + // Initialize ids_in as sequential numbers to allow mapping with output distances. + std::vector ids_in(n); + std::iota(ids_in.begin(), ids_in.end(), 0); + + // ids_out contains the order of distances in dist_out after scan_codes returns. + std::vector ids_out(n, 0); + + scanner->scan_codes(n, codes, ids_in.data(), dist_out.data(), ids_out.data(), n); + + // Reorder the returned distances in dist_out based on ids_out. + // This function needs to return the distances in the same order as input codes. 
+ for (int j = 0; j < n; j++) { /* NOTE(review): j and k are int while n is idx_t; confirm the values cannot exceed the int range. */ + int k = ids_out[j]; + dists[k] = dist_out[j]; + } + + return; +} + +// Computes the distance table for the input vector x and writes it to dist_table. +void IndexIVFPQ::compute_distance_table( + const float* x, + float* dist_table) const { + + std::unique_ptr scanner( + get_InvertedListScanner(true, nullptr)); + + scanner->set_query(x); + /* NOTE(review): dist_table is left unwritten when the cast below fails; confirm callers can tolerate that. */ + if (auto* pqscanner = dynamic_cast(scanner.get())) { + pqscanner->copy_sim_table(dist_table); + } + + return; +} + } // namespace faiss diff --git a/faiss/IndexIVFPQ.h b/faiss/IndexIVFPQ.h index 7bf97ec0f3..b13c43b116 100644 --- a/faiss/IndexIVFPQ.h +++ b/faiss/IndexIVFPQ.h @@ -139,6 +139,18 @@ struct IndexIVFPQ : IndexIVF { /// build precomputed table void precompute_table(); + void compute_distance_to_codes_for_list( + const idx_t list_no, + const float* x, + idx_t n, + const uint8_t* codes, + float* dists, + float* dist_table) const override; + + void compute_distance_table( + const float* x, + float* dist_table) const override; + IndexIVFPQ(); }; diff --git a/faiss/IndexScalarQuantizer.cpp b/faiss/IndexScalarQuantizer.cpp index bd661e9765..8c013d0287 100644 --- a/faiss/IndexScalarQuantizer.cpp +++ b/faiss/IndexScalarQuantizer.cpp @@ -287,7 +287,8 @@ void IndexIVFScalarQuantizer::compute_distance_to_codes_for_list( const float* x, idx_t n, const uint8_t* codes, - float* dists) const { + float* dists, + float* dist_table) const { std::unique_ptr dc( sq.get_distance_computer(metric_type)); diff --git a/faiss/IndexScalarQuantizer.h b/faiss/IndexScalarQuantizer.h index c4ca5865c5..fe73536f6a 100644 --- a/faiss/IndexScalarQuantizer.h +++ b/faiss/IndexScalarQuantizer.h @@ -110,7 +110,8 @@ struct IndexIVFScalarQuantizer : IndexIVF { const float* x, idx_t n, const uint8_t* codes, - float* dists) const override; + float* dists, + float* dist_table) const override; };